diff --git a/.devcontainer/post_start_command.sh b/.devcontainer/post_start_command.sh index e3d5a6d59d746e..56e87614babf90 100755 --- a/.devcontainer/post_start_command.sh +++ b/.devcontainer/post_start_command.sh @@ -1,3 +1,3 @@ #!/bin/bash -poetry install -C api \ No newline at end of file +cd api && poetry install \ No newline at end of file diff --git a/.github/actions/setup-poetry/action.yml b/.github/actions/setup-poetry/action.yml new file mode 100644 index 00000000000000..5feab33d1daec3 --- /dev/null +++ b/.github/actions/setup-poetry/action.yml @@ -0,0 +1,36 @@ +name: Setup Poetry and Python + +inputs: + python-version: + description: Python version to use and the Poetry installed with + required: true + default: '3.10' + poetry-version: + description: Poetry version to set up + required: true + default: '1.8.4' + poetry-lockfile: + description: Path to the Poetry lockfile to restore cache from + required: true + default: '' + +runs: + using: composite + steps: + - name: Set up Python ${{ inputs.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + cache: pip + + - name: Install Poetry + shell: bash + run: pip install poetry==${{ inputs.poetry-version }} + + - name: Restore Poetry cache + if: ${{ inputs.poetry-lockfile != '' }} + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + cache: poetry + cache-dependency-path: ${{ inputs.poetry-lockfile }} diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 7c632f8a34d56a..76e844aaad8832 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -7,6 +7,7 @@ on: paths: - api/** - docker/** + - .github/workflows/api-tests.yml concurrency: group: api-tests-${{ github.head_ref || github.run_id }} @@ -27,19 +28,13 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - name: Install Poetry - uses: abatilo/actions-poetry@v3 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + - name: Setup Poetry and Python ${{ matrix.python-version }} + uses: ./.github/actions/setup-poetry with: python-version: ${{ matrix.python-version }} - cache: 'poetry' - cache-dependency-path: | - api/pyproject.toml - api/poetry.lock + poetry-lockfile: api/poetry.lock - - name: Poetry check + - name: Check Poetry lockfile run: | poetry check -C api --lock poetry show -C api @@ -47,6 +42,9 @@ jobs: - name: Install dependencies run: poetry install -C api --with dev + - name: Check dependencies in pyproject.toml + run: poetry run -C api bash dev/pytest/pytest_artifacts.sh + - name: Run Unit tests run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh @@ -65,7 +63,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@v2.0.0 + uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: | docker/docker-compose.middleware.yaml @@ -75,21 +73,3 @@ jobs: - name: Run Workflow run: poetry run -C api bash dev/pytest/pytest_workflow.sh - - - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch) - uses: hoverkraft-tech/compose-action@v2.0.0 - with: - compose-file: | - docker/docker-compose.yaml - services: | - weaviate - qdrant - etcd - minio - milvus-standalone - pgvecto-rs - pgvector - chroma - elasticsearch - - name: Test Vector Stores - run: poetry run -C api bash dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/build-push.yml 
b/.github/workflows/build-push.yml index 407bd47d9b0f8f..8e5279fb67659b 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -49,7 +49,7 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} @@ -114,7 +114,7 @@ jobs: merge-multiple: true - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} @@ -125,7 +125,7 @@ jobs: with: images: ${{ env[matrix.image_name_env] }} tags: | - type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }} + type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }} type=ref,event=branch type=sha,enable=true,priority=100,prefix=,suffix=,format=long type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }} diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index 67d1558dbcaaff..f4eb0f8e33e515 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -6,6 +6,7 @@ on: - main paths: - api/migrations/** + - .github/workflows/db-migration-test.yml concurrency: group: db-migration-test-${{ github.ref }} @@ -14,26 +15,15 @@ concurrency: jobs: db-migration-test: runs-on: ubuntu-latest - strategy: - matrix: - python-version: - - "3.10" steps: - name: Checkout code uses: actions/checkout@v4 - - name: Install Poetry - uses: abatilo/actions-poetry@v3 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + - name: Setup Poetry and Python + uses: ./.github/actions/setup-poetry with: - python-version: ${{ matrix.python-version }} - cache: 'poetry' - cache-dependency-path: | - api/pyproject.toml - api/poetry.lock + poetry-lockfile: api/poetry.lock - name: Install dependencies run: poetry install -C api @@ -44,7 +34,7 @@ jobs: cp middleware.env.example middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@v2.0.0 + uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: | docker/docker-compose.middleware.yaml diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index ae3e0ee69d8cfb..bc65c19a913fcf 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -7,5 +7,7 @@ yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/dock yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml +yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml +yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch" \ No newline at end of file +echo "Ports exposed for sandbox, weaviate, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase" diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index d681dc66276dd1..282afefe74243a 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -20,35 
+20,30 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v44 + uses: tj-actions/changed-files@v45 with: - files: api/** - - - name: Install Poetry - uses: abatilo/actions-poetry@v3 + files: | + api/** + .github/workflows/style.yml - - name: Set up Python - uses: actions/setup-python@v5 + - name: Setup Poetry and Python if: steps.changed-files.outputs.any_changed == 'true' - with: - python-version: '3.10' + uses: ./.github/actions/setup-poetry - - name: Python dependencies + - name: Install dependencies if: steps.changed-files.outputs.any_changed == 'true' run: poetry install -C api --only lint - name: Ruff check if: steps.changed-files.outputs.any_changed == 'true' - run: poetry run -C api ruff check ./api + run: | + poetry run -C api ruff check ./api + poetry run -C api ruff format --check ./api - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example - - name: Ruff formatter check - if: steps.changed-files.outputs.any_changed == 'true' - run: poetry run -C api ruff format --check ./api - - name: Lint hints if: failure() run: echo "Please run 'dev/reformat' to fix the fixable linting errors." @@ -66,7 +61,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v44 + uses: tj-actions/changed-files@v45 with: files: web/** @@ -97,7 +92,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v44 + uses: tj-actions/changed-files@v45 with: files: | **.sh @@ -107,7 +102,7 @@ jobs: dev/** - name: Super-linter - uses: super-linter/super-linter/slim@v6 + uses: super-linter/super-linter/slim@v7 if: steps.changed-files.outputs.any_changed == 'true' env: BASH_SEVERITY: warning diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml new file mode 100644 index 00000000000000..3f51b3b2c79946 --- /dev/null +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -0,0 +1,54 @@ +name: Check i18n Files and Create PR + +on: + pull_request: + types: [closed] + branches: [main] + +jobs: + check-and-update: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + defaults: + run: + working-directory: web + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 2 # last 2 commits + + - name: Check for file changes in i18n/en-US + id: check_files + run: | + recent_commit_sha=$(git rev-parse HEAD) + second_recent_commit_sha=$(git rev-parse HEAD~1) + changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts') + echo "Changed files: $changed_files" + if [ -n "$changed_files" ]; then + echo "FILES_CHANGED=true" >> $GITHUB_ENV + else + echo "FILES_CHANGED=false" >> $GITHUB_ENV + fi + + - name: Set up Node.js + if: env.FILES_CHANGED == 'true' + uses: actions/setup-node@v2 + with: + node-version: 'lts/*' + + - name: Install dependencies + if: env.FILES_CHANGED == 'true' + run: yarn install --frozen-lockfile + + - name: Run npm script + if: env.FILES_CHANGED == 'true' + run: npm run auto-gen-i18n + + - name: Create Pull Request + if: env.FILES_CHANGED == 'true' + uses: peter-evans/create-pull-request@v6 + with: + commit-message: Update i18n files based on en-US changes + title: 'chore: translate i18n files' + body: This PR was automatically created to update i18n files based on changes in en-US locale. 
+ branch: chore/automated-i18n-updates diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml new file mode 100644 index 00000000000000..caddd23bab70d3 --- /dev/null +++ b/.github/workflows/vdb-tests.yml @@ -0,0 +1,71 @@ +name: Run VDB Tests + +on: + pull_request: + branches: + - main + paths: + - api/core/rag/datasource/** + - docker/** + - .github/workflows/vdb-tests.yml + +concurrency: + group: vdb-tests-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + test: + name: VDB Tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - "3.10" + - "3.11" + - "3.12" + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Poetry and Python ${{ matrix.python-version }} + uses: ./.github/actions/setup-poetry + with: + python-version: ${{ matrix.python-version }} + poetry-lockfile: api/poetry.lock + + - name: Check Poetry lockfile + run: | + poetry check -C api --lock + poetry show -C api + + - name: Install dependencies + run: poetry install -C api --with dev + + - name: Set up dotenvs + run: | + cp docker/.env.example docker/.env + cp docker/middleware.env.example docker/middleware.env + + - name: Expose Service Ports + run: sh .github/workflows/expose_service_ports.sh + + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase) + uses: hoverkraft-tech/compose-action@v2.0.2 + with: + compose-file: | + docker/docker-compose.yaml + services: | + weaviate + qdrant + couchbase-server + etcd + minio + milvus-standalone + pgvecto-rs + pgvector + chroma + elasticsearch + + - name: Test Vector Stores + run: poetry run -C api bash dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml new file mode 100644 index 00000000000000..5aee64b8e6da02 --- /dev/null +++ b/.github/workflows/web-tests.yml @@ -0,0 +1,46 @@ +name: Web Tests + +on: + pull_request: + branches: + - main + paths: + - web/** + +concurrency: + group: web-tests-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + test: + name: Web Tests + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./web + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Check changed files + id: changed-files + uses: tj-actions/changed-files@v45 + with: + files: web/** + + - name: Setup Node.js + uses: actions/setup-node@v4 + if: steps.changed-files.outputs.any_changed == 'true' + with: + node-version: 20 + cache: yarn + cache-dependency-path: ./web/package.json + + - name: Install dependencies + if: steps.changed-files.outputs.any_changed == 'true' + run: yarn install --frozen-lockfile + + - name: Run tests + if: steps.changed-files.outputs.any_changed == 'true' + run: yarn test diff --git a/.gitignore b/.gitignore index c52b9d8bbf5bca..1423bfee56e922 100644 --- a/.gitignore +++ b/.gitignore @@ -153,6 +153,9 @@ docker-legacy/volumes/etcd/* docker-legacy/volumes/minio/* docker-legacy/volumes/milvus/* docker-legacy/volumes/chroma/* +docker-legacy/volumes/opensearch/data/* +docker-legacy/volumes/pgvectors/data/* +docker-legacy/volumes/pgvector/data/* docker/volumes/app/storage/* docker/volumes/certbot/* @@ -164,8 +167,19 @@ docker/volumes/etcd/* docker/volumes/minio/* docker/volumes/milvus/* docker/volumes/chroma/* +docker/volumes/opensearch/data/* +docker/volumes/myscale/data/* +docker/volumes/myscale/log/* +docker/volumes/unstructured/* +docker/volumes/pgvector/data/* +docker/volumes/pgvecto_rs/data/* 
+docker/volumes/couchbase/* +docker/volumes/oceanbase/* +!docker/volumes/oceanbase/init.d docker/nginx/conf.d/default.conf +docker/nginx/ssl/* +!docker/nginx/ssl/.gitkeep docker/middleware.env sdks/python-client/build @@ -178,3 +192,4 @@ pyrightconfig.json api/.vscode .idea/ +.vscode diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index e4eb6aef932faf..00000000000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,54 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "Python: Flask", - "type": "debugpy", - "request": "launch", - "python": "${workspaceFolder}/api/.venv/bin/python", - "cwd": "${workspaceFolder}/api", - "envFile": ".env", - "module": "flask", - "justMyCode": true, - "jinja": true, - "env": { - "FLASK_APP": "app.py", - "GEVENT_SUPPORT": "True" - }, - "args": [ - "run", - "--host=0.0.0.0", - "--port=5001", - ] - }, - { - "name": "Python: Celery", - "type": "debugpy", - "request": "launch", - "python": "${workspaceFolder}/api/.venv/bin/python", - "cwd": "${workspaceFolder}/api", - "module": "celery", - "justMyCode": true, - "envFile": ".env", - "console": "integratedTerminal", - "env": { - "FLASK_APP": "app.py", - "FLASK_DEBUG": "1", - "GEVENT_SUPPORT": "True" - }, - "args": [ - "-A", - "app.celery", - "worker", - "-P", - "gevent", - "-c", - "1", - "--loglevel", - "info", - "-Q", - "dataset,generation,mail,ops_trace,app_deletion" - ] - }, - ] -} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f810584f24115c..da2928d18926b3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr ## Before you jump in -[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: +[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: ### Feature requests: @@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install. -Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot. +Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and steps to troubleshoot. ### 5. 
Visit dify in your browser diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 303c2513f53b9b..310c55090ae82a 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -8,7 +8,7 @@ ## 在开始之前 -[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: +[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: ### 功能请求: @@ -36,7 +36,7 @@ | 被团队成员标记为高优先级的功能 | 高优先级 | | 在 [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) 内反馈的常见功能请求 | 中等优先级 | | 非核心功能和小幅改进 | 低优先级 | - | 有价值当不紧急 | 未来功能 | + | 有价值但不紧急 | 未来功能 | ### 其他任何事情(例如 bug 报告、性能优化、拼写错误更正): * 立即开始编码。 @@ -138,7 +138,7 @@ Dify 的后端使用 Python 编写,使用 [Flask](https://flask.palletsproject ├── models // 描述数据模型和 API 响应的形状 ├── public // 如 favicon 等元资源 ├── service // 定义 API 操作的形状 -├── test +├── test ├── types // 函数参数和返回值的描述 └── utils // 共享的实用函数 ``` diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md index 6d5bfb205c31dc..a68bdeddbc830f 100644 --- a/CONTRIBUTING_JA.md +++ b/CONTRIBUTING_JA.md @@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは ## 飛び込む前に -[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 +[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 ### 機能リクエスト diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md new file mode 100644 index 00000000000000..a77239ff38420f --- /dev/null +++ b/CONTRIBUTING_VI.md @@ -0,0 +1,156 @@ +Thật tuyệt vời khi bạn muốn đóng góp cho Dify! Chúng tôi rất mong chờ được thấy những gì bạn sẽ làm. Là một startup với nguồn nhân lực và tài chính hạn chế, chúng tôi có tham vọng lớn là thiết kế quy trình trực quan nhất để xây dựng và quản lý các ứng dụng LLM. Mọi sự giúp đỡ từ cộng đồng đều rất quý giá đối với chúng tôi. + +Chúng tôi cần linh hoạt và làm việc nhanh chóng, nhưng đồng thời cũng muốn đảm bảo các cộng tác viên như bạn có trải nghiệm đóng góp thuận lợi nhất có thể. Chúng tôi đã tạo ra hướng dẫn đóng góp này nhằm giúp bạn làm quen với codebase và cách chúng tôi làm việc với các cộng tác viên, để bạn có thể nhanh chóng bắt tay vào phần thú vị. + +Hướng dẫn này, cũng như bản thân Dify, đang trong quá trình cải tiến liên tục. Chúng tôi rất cảm kích sự thông cảm của bạn nếu đôi khi nó không theo kịp dự án thực tế, và chúng tôi luôn hoan nghênh mọi phản hồi để cải thiện. + +Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [Thỏa thuận Cấp phép và Đóng góp](./LICENSE) ngắn gọn của chúng tôi. Cộng đồng cũng tuân thủ [quy tắc ứng xử](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md). + +## Trước khi bắt đầu + +[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại: + +### Yêu cầu tính năng: + +* Nếu bạn đang tạo một yêu cầu tính năng mới, chúng tôi muốn bạn giải thích tính năng đề xuất sẽ đạt được điều gì và cung cấp càng nhiều thông tin chi tiết càng tốt. [@perzeusss](https://github.com/perzeuss) đã tạo một [Trợ lý Yêu cầu Tính năng](https://udify.app/chat/MK2kVSnw1gakVwMX) rất hữu ích để giúp bạn soạn thảo nhu cầu của mình. 
Hãy thử dùng nó nhé. + +* Nếu bạn muốn chọn một vấn đề từ danh sách hiện có, chỉ cần để lại bình luận dưới vấn đề đó nói rằng bạn sẽ làm. + + Một thành viên trong nhóm làm việc trong lĩnh vực liên quan sẽ được thông báo. Nếu mọi thứ ổn, họ sẽ cho phép bạn bắt đầu code. Chúng tôi yêu cầu bạn chờ đợi cho đến lúc đó trước khi bắt tay vào làm tính năng, để không lãng phí công sức của bạn nếu chúng tôi đề xuất thay đổi. + + Tùy thuộc vào lĩnh vực mà tính năng đề xuất thuộc về, bạn có thể nói chuyện với các thành viên khác nhau trong nhóm. Dưới đây là danh sách các lĩnh vực mà các thành viên trong nhóm chúng tôi đang làm việc hiện tại: + + | Thành viên | Phạm vi | + | ------------------------------------------------------------ | ---------------------------------------------------- | + | [@yeuoly](https://github.com/Yeuoly) | Thiết kế kiến trúc Agents | + | [@jyong](https://github.com/JohnJyong) | Thiết kế quy trình RAG | + | [@GarfieldDai](https://github.com/GarfieldDai) | Xây dựng quy trình làm việc | + | [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Làm cho giao diện người dùng dễ sử dụng | + | [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Trải nghiệm nhà phát triển, đầu mối liên hệ cho mọi vấn đề | + | [@takatost](https://github.com/takatost) | Định hướng và kiến trúc tổng thể sản phẩm | + + Cách chúng tôi ưu tiên: + + | Loại tính năng | Mức độ ưu tiên | + | ------------------------------------------------------------ | -------------- | + | Tính năng ưu tiên cao được gắn nhãn bởi thành viên trong nhóm | Ưu tiên cao | + | Yêu cầu tính năng phổ biến từ [bảng phản hồi cộng đồng](https://github.com/langgenius/dify/discussions/categories/feedbacks) của chúng tôi | Ưu tiên trung bình | + | Tính năng không quan trọng và cải tiến nhỏ | Ưu tiên thấp | + | Có giá trị nhưng không cấp bách | Tính năng tương lai | + +### Những vấn đề khác (ví dụ: báo cáo lỗi, tối ưu hiệu suất, sửa lỗi chính tả): + +* Bắt đầu code ngay lập tức. + + Cách chúng tôi ưu tiên: + + | Loại vấn đề | Mức độ ưu tiên | + | ------------------------------------------------------------ | -------------- | + | Lỗi trong các chức năng chính (không thể đăng nhập, ứng dụng không hoạt động, lỗ hổng bảo mật) | Nghiêm trọng | + | Lỗi không quan trọng, cải thiện hiệu suất | Ưu tiên trung bình | + | Sửa lỗi nhỏ (lỗi chính tả, giao diện người dùng gây nhầm lẫn nhưng vẫn hoạt động) | Ưu tiên thấp | + + +## Cài đặt + +Dưới đây là các bước để thiết lập Dify cho việc phát triển: + +### 1. Fork repository này + +### 2. Clone repository + + Clone repository đã fork từ terminal của bạn: + +``` +git clone git@github.com:/dify.git +``` + +### 3. Kiểm tra các phụ thuộc + +Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đã được cài đặt trên hệ thống của bạn: + +- [Docker](https://www.docker.com/) +- [Docker Compose](https://docs.docker.com/compose/install/) +- [Node.js v18.x (LTS)](http://nodejs.org) +- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/) +- [Python](https://www.python.org/) phiên bản 3.10.x + +### 4. Cài đặt + +Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt. 
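A minimal sketch of that two-terminal flow, assuming the defaults from api/README.md and web/README.md — the API host/port mirror the removed .vscode/launch.json earlier in this diff, and `yarn dev` assumes the usual Next.js scripts in web/package.json; the authoritative steps live in the two READMEs:

```bash
# Terminal 1 – backend; host/port taken from the deleted .vscode/launch.json
cd api
poetry install
poetry run flask run --host 0.0.0.0 --port 5001 --debug

# Terminal 2 – frontend; assumes the standard Next.js dev script
cd web
yarn install --frozen-lockfile
yarn dev
```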
+ +Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/install-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục. + +### 5. Truy cập Dify trong trình duyệt của bạn + +Để xác nhận cài đặt của bạn, hãy truy cập [http://localhost:3000](http://localhost:3000) (địa chỉ mặc định, hoặc URL và cổng bạn đã cấu hình) trong trình duyệt. Bạn sẽ thấy Dify đang chạy. + +## Phát triển + +Nếu bạn đang thêm một nhà cung cấp mô hình, [hướng dẫn này](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) dành cho bạn. + +Nếu bạn đang thêm một nhà cung cấp công cụ cho Agent hoặc Workflow, [hướng dẫn này](./api/core/tools/README.md) dành cho bạn. + +Để giúp bạn nhanh chóng định hướng phần đóng góp của mình, dưới đây là một bản phác thảo ngắn gọn về cấu trúc backend & frontend của Dify: + +### Backend + +Backend của Dify được viết bằng Python sử dụng [Flask](https://flask.palletsprojects.com/en/3.0.x/). Nó sử dụng [SQLAlchemy](https://www.sqlalchemy.org/) cho ORM và [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) cho hàng đợi tác vụ. Logic xác thực được thực hiện thông qua Flask-login. + +``` +[api/] +├── constants // Các cài đặt hằng số được sử dụng trong toàn bộ codebase. +├── controllers // Định nghĩa các route API và logic xử lý yêu cầu. +├── core // Điều phối ứng dụng cốt lõi, tích hợp mô hình và công cụ. +├── docker // Cấu hình liên quan đến Docker & containerization. +├── events // Xử lý và xử lý sự kiện +├── extensions // Mở rộng với các framework/nền tảng bên thứ 3. +├── fields // Định nghĩa trường cho serialization/marshalling. +├── libs // Thư viện và tiện ích có thể tái sử dụng. +├── migrations // Script cho việc di chuyển cơ sở dữ liệu. +├── models // Mô hình cơ sở dữ liệu & định nghĩa schema. +├── services // Xác định logic nghiệp vụ. +├── storage // Lưu trữ khóa riêng tư. +├── tasks // Xử lý các tác vụ bất đồng bộ và công việc nền. +└── tests +``` + +### Frontend + +Website được khởi tạo trên boilerplate [Next.js](https://nextjs.org/) bằng Typescript và sử dụng [Tailwind CSS](https://tailwindcss.com/) cho styling. [React-i18next](https://react.i18next.com/) được sử dụng cho việc quốc tế hóa. + +``` +[web/] +├── app // layouts, pages và components +│ ├── (commonLayout) // layout chung được sử dụng trong toàn bộ ứng dụng +│ ├── (shareLayout) // layouts được chia sẻ cụ thể cho các phiên dựa trên token +│ ├── activate // trang kích hoạt +│ ├── components // được chia sẻ bởi các trang và layouts +│ ├── install // trang cài đặt +│ ├── signin // trang đăng nhập +│ └── styles // styles được chia sẻ toàn cục +├── assets // Tài nguyên tĩnh +├── bin // scripts chạy ở bước build +├── config // cài đặt và tùy chọn có thể điều chỉnh +├── context // contexts được chia sẻ bởi các phần khác nhau của ứng dụng +├── dictionaries // File dịch cho từng ngôn ngữ +├── docker // cấu hình container +├── hooks // Hooks có thể tái sử dụng +├── i18n // Cấu hình quốc tế hóa +├── models // mô tả các mô hình dữ liệu & hình dạng của phản hồi API +├── public // tài nguyên meta như favicon +├── service // xác định hình dạng của các hành động API +├── test +├── types // mô tả các tham số hàm và giá trị trả về +└── utils // Các hàm tiện ích được chia sẻ +``` + +## Gửi PR của bạn + +Cuối cùng, đã đến lúc mở một pull request (PR) đến repository của chúng tôi. Đối với các tính năng lớn, chúng tôi sẽ merge chúng vào nhánh `deploy/dev` để kiểm tra trước khi đưa vào nhánh `main`. 
Nếu bạn gặp vấn đề như xung đột merge hoặc không biết cách mở pull request, hãy xem [hướng dẫn về pull request của GitHub](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests). + +Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giới thiệu là một người đóng góp trong [README](https://github.com/langgenius/dify/blob/main/README.md) của chúng tôi. + +## Nhận trợ giúp + +Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng. \ No newline at end of file diff --git a/LICENSE b/LICENSE index 071ef42bdada67..d7b8373839e1e6 100644 --- a/LICENSE +++ b/LICENSE @@ -4,10 +4,11 @@ Dify is licensed under the Apache License 2.0, with the following additional con 1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer: -a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment. +a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment. - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations. - -b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components. + +b. LOGO and copyright information: In the process of using Dify's frontend, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend. + - Frontend Definition: For the purposes of this license, the "frontend" of Dify includes all components located in the `web/` directory when running Dify from the raw source code, or the "web" image when running Dify with Docker. Please contact business@dify.ai by email to inquire about licensing matters. diff --git a/README.md b/README.md index 1c49c415fe09a9..d42b3b13a68204 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ ![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +

+ 📌 Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast +

+

Dify Cloud · Self-hosting · @@ -17,7 +21,7 @@ alt="chat on Discord"> follow on Twitter + alt="follow on X(Twitter)"> Docker Pulls @@ -42,9 +46,33 @@

-Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features: -

+Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. + +## Quick start +> Before installing Dify, make sure your machine meets the following minimum system requirements: +> +>- CPU >= 2 Core +>- RAM >= 4 GiB +
+ +The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: + +```bash +cd dify +cd docker +cp .env.example .env +docker compose up -d +``` + +After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. + +#### Seeking help +Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) if you encounter problems setting up Dify. Reach out to [the community and us](#community--contact) if you are still having issues. + +> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) + +## Key features **1. Workflow**: Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. @@ -75,73 +103,6 @@ Dify is an open-source LLM app development platform. Its intuitive interface com All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. -## Feature comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
| Feature | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
| --- | --- | --- | --- | --- |
| Programming Approach | API + App-oriented | Python Code | App-oriented | API-oriented |
| Supported LLMs | Rich Variety | Rich Variety | Rich Variety | OpenAI-only |
| RAG Engine | ✅ | ✅ | ✅ | ✅ |
| Agent | ✅ | ✅ | ❌ | ✅ |
| Workflow | ✅ | ❌ | ✅ | ❌ |
| Observability | ✅ | ✅ | ❌ | ❌ |
| Enterprise Features (SSO/Access control) | ✅ | ❌ | ❌ | ❌ |
| Local Deployment | ✅ | ✅ | ✅ | ❌ |
- ## Using Dify - **Cloud
** @@ -163,28 +124,7 @@ Star Dify on GitHub and be instantly notified of new releases. ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - -## Quick start -> Before installing Dify, make sure your machine meets the following minimum system requirements: -> ->- CPU >= 2 Core ->- RAM >= 4GB - -
- -The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: - -```bash -cd docker -cp .env.example .env -docker compose up -d -``` - -After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. - -> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) - -## Next steps +## Advanced Setup If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). @@ -196,10 +136,14 @@ If you'd like to configure a highly-available setup, there are community-contrib #### Using Terraform for Deployment +Deploy Dify to Cloud Platform with a single click using [terraform](https://www.terraform.io/) + ##### Azure Global -Deploy Dify to Azure with a single click using [terraform](https://www.terraform.io/). - [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) + ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). @@ -208,18 +152,18 @@ At the same time, please consider supporting Dify by sharing it on social media > We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). -**Contributors** - -
- - - ## Community & contact * [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. * [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). * [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. -* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. +* [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. + +**Contributors** + + + + ## Star history @@ -233,3 +177,4 @@ To protect your privacy, please avoid posting security issues on GitHub. Instead ## License This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. + diff --git a/README_AR.md b/README_AR.md index 10d572cc49a83b..e46ba7373840c9 100644 --- a/README_AR.md +++ b/README_AR.md @@ -17,7 +17,7 @@ alt="chat on Discord"> follow on Twitter + alt="follow on X(Twitter)"> Docker Pulls @@ -179,10 +179,13 @@ docker compose up -d #### استخدام Terraform للتوزيع +انشر Dify إلى منصة السحابة بنقرة واحدة باستخدام [terraform](https://www.terraform.io/) + ##### Azure Global -استخدم [terraform](https://www.terraform.io/) لنشر Dify على Azure بنقرة واحدة. - [Azure Terraform بواسطة @nikawang](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [Google Cloud Terraform بواسطة @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) ## المساهمة diff --git a/README_CN.md b/README_CN.md index 32551fcc313932..070951699a85ba 100644 --- a/README_CN.md +++ b/README_CN.md @@ -17,7 +17,7 @@ alt="chat on Discord"> follow on Twitter + alt="follow on X(Twitter)"> Docker Pulls @@ -154,7 +154,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。 - **自托管 Dify 社区版
** -使用这个[入门指南](#quick-start)快速在您的环境中运行 Dify。 +使用这个[入门指南](#快速启动)快速在您的环境中运行 Dify。 使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。 - **面向企业/组织的 Dify
** @@ -174,7 +174,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 在安装 Dify 之前,请确保您的机器满足以下最低系统要求: - CPU >= 2 Core -- RAM >= 4GB +- RAM >= 4 GiB ### 快速启动 @@ -202,10 +202,14 @@ docker compose up -d #### 使用 Terraform 部署 +使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台 + ##### Azure Global -使用 [terraform](https://www.terraform.io/) 一键部署 Dify 到 Azure。 - [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) + ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) @@ -232,7 +236,7 @@ docker compose up -d - [GitHub Issues](https://github.com/langgenius/dify/issues)。👉:使用 Dify.AI 时遇到的错误和问题,请参阅[贡献指南](CONTRIBUTING.md)。 - [电子邮件支持](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify)。👉:关于使用 Dify.AI 的问题。 - [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。 -- [Twitter](https://twitter.com/dify_ai)。👉:分享您的应用程序并与社区交流。 +- [X(Twitter)](https://twitter.com/dify_ai)。👉:分享您的应用程序并与社区交流。 - [商业许可](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)。👉:有关商业用途许可 Dify.AI 的商业咨询。 - [微信]() 👉:扫描下方二维码,添加微信好友,备注 Dify,我们将邀请您加入 Dify 社区。 wechat diff --git a/README_ES.md b/README_ES.md index 2ae044b32883b9..7da5ac7b61fb19 100644 --- a/README_ES.md +++ b/README_ES.md @@ -17,7 +17,7 @@ alt="chat en Discord">
seguir en Twitter + alt="seguir en X(Twitter)"> Descargas de Docker @@ -204,10 +204,13 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop #### Uso de Terraform para el despliegue +Despliega Dify en una plataforma en la nube con un solo clic utilizando [terraform](https://www.terraform.io/) + ##### Azure Global -Utiliza [terraform](https://www.terraform.io/) para desplegar Dify en Azure con un solo clic. - [Azure Terraform por @nikawang](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) ## Contribuir @@ -228,7 +231,7 @@ Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en * [Discusión en GitHub](https://github.com/langgenius/dify/discussions). Lo mejor para: compartir comentarios y hacer preguntas. * [Reporte de problemas en GitHub](https://github.com/langgenius/dify/issues). Lo mejor para: errores que encuentres usando Dify.AI y propuestas de características. Consulta nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). * [Discord](https://discord.gg/FngNHpbcY7). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. -* [Twitter](https://twitter.com/dify_ai). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. +* [X(Twitter)](https://twitter.com/dify_ai). Lo mejor para: compartir tus aplicaciones y pasar el rato con la comunidad. ## Historial de Estrellas diff --git a/README_FR.md b/README_FR.md index 681d596749c9e7..15f6f2650f8f2b 100644 --- a/README_FR.md +++ b/README_FR.md @@ -17,7 +17,7 @@ alt="chat sur Discord"> suivre sur Twitter + alt="suivre sur X(Twitter)"> Tirages Docker @@ -202,10 +202,13 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau #### Utilisation de Terraform pour le déploiement +Déployez Dify sur une plateforme cloud en un clic en utilisant [terraform](https://www.terraform.io/) + ##### Azure Global -Utilisez [terraform](https://www.terraform.io/) pour déployer Dify sur Azure en un clic. - [Azure Terraform par @nikawang](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [Google Cloud Terraform par @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) ## Contribuer @@ -226,7 +229,7 @@ Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur le * [Discussion GitHub](https://github.com/langgenius/dify/discussions). Meilleur pour: partager des commentaires et poser des questions. * [Problèmes GitHub](https://github.com/langgenius/dify/issues). Meilleur pour: les bogues que vous rencontrez en utilisant Dify.AI et les propositions de fonctionnalités. Consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). * [Discord](https://discord.gg/FngNHpbcY7). Meilleur pour: partager vos applications et passer du temps avec la communauté. -* [Twitter](https://twitter.com/dify_ai). Meilleur pour: partager vos applications et passer du temps avec la communauté. +* [X(Twitter)](https://twitter.com/dify_ai). Meilleur pour: partager vos applications et passer du temps avec la communauté. 
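Each of the community Terraform modules referenced in these READMEs follows the standard Terraform workflow; a minimal sketch using the Google Cloud module linked above, assuming its required variables (project, region, and so on) are filled in first:

```bash
# Generic flow for the linked community modules (sketch, not an officially supported path)
git clone https://github.com/DeNA/dify-google-cloud-terraform.git
cd dify-google-cloud-terraform
terraform init    # fetch providers and modules
terraform plan    # review the resources before creating anything
terraform apply   # provision Dify on the target cloud
```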
## Historique des étoiles diff --git a/README_JA.md b/README_JA.md index e6a8621e7baae5..a2e6b173f5bfcd 100644 --- a/README_JA.md +++ b/README_JA.md @@ -17,7 +17,7 @@ alt="Discordでチャット"> Twitterでフォロー + alt="X(Twitter)でフォロー"> Docker Pulls @@ -68,7 +68,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ プロンプトの作成、モデルパフォーマンスの比較が行え、チャットベースのアプリに音声合成などの機能も追加できます。 **4. RAGパイプライン**: - ドキュメントの取り込みから検索までをカバーする広範なRAG機能ができます。ほかにもPDF、PPT、その他の一般的なドキュメントフォーマットからのテキスト抽出のサーポイントも提供します。 + ドキュメントの取り込みから検索までをカバーする広範なRAG機能ができます。ほかにもPDF、PPT、その他の一般的なドキュメントフォーマットからのテキスト抽出のサポートも提供します。 **5. エージェント機能**: LLM Function CallingやReActに基づくエージェントの定義が可能で、AIエージェント用のプリビルトまたはカスタムツールを追加できます。Difyには、Google検索、DALL·E、Stable Diffusion、WolframAlphaなどのAIエージェント用の50以上の組み込みツールが提供します。 @@ -201,10 +201,13 @@ docker compose up -d #### Terraformを使用したデプロイ +[terraform](https://www.terraform.io/) を使用して、ワンクリックでDifyをクラウドプラットフォームにデプロイします + ##### Azure Global -[terraform](https://www.terraform.io/) を使用して、AzureにDifyをワンクリックでデプロイします。 -- [nikawangのAzure Terraform](https://github.com/nikawang/dify-azure-terraform) +- [@nikawangによるAzure Terraform](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [@sotazumによるGoogle Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform) ## 貢献 @@ -225,7 +228,7 @@ docker compose up -d * [Github Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。 * [GitHub Issues](https://github.com/langgenius/dify/issues). 主に: Dify.AIを使用する際に発生するエラーや問題については、[貢献ガイド](CONTRIBUTING_JA.md)を参照してください * [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。 -* [Twitter](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。 +* [X(Twitter)](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。 diff --git a/README_KL.md b/README_KL.md index 04620d42bbec8a..8f2affdce5ae59 100644 --- a/README_KL.md +++ b/README_KL.md @@ -17,7 +17,7 @@ alt="chat on Discord"> follow on Twitter + alt="follow on X(Twitter)"> Docker Pulls @@ -202,10 +202,13 @@ If you'd like to configure a highly-available setup, there are community-contrib #### Terraform atorlugu pilersitsineq +wa'logh nIqHom neH ghun deployment toy'wI' [terraform](https://www.terraform.io/) lo'laH. + ##### Azure Global -Atoruk [terraform](https://www.terraform.io/) Dify-mik Azure-mut ataatsikkut ikkussuilluarlugu. -- [Azure Terraform atorlugu @nikawang](https://github.com/nikawang/dify-azure-terraform) +- [Azure Terraform mung @nikawang](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [Google Cloud Terraform qachlot @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) ## Contributing @@ -228,7 +231,7 @@ At the same time, please consider supporting Dify by sharing it on social media ). Best for: sharing feedback and asking questions. * [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). * [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. -* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. +* [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. 
## Star History diff --git a/README_KR.md b/README_KR.md index a5f3bc68d04d74..6c3a9ed7f6ea06 100644 --- a/README_KR.md +++ b/README_KR.md @@ -17,7 +17,7 @@ alt="chat on Discord"> follow on Twitter + alt="follow on X(Twitter)"> Docker Pulls @@ -39,7 +39,6 @@ README بالعربية Türkçe README README Tiếng Việt -

@@ -195,10 +194,14 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 #### Terraform을 사용한 배포 +[terraform](https://www.terraform.io/)을 사용하여 단 한 번의 클릭으로 Dify를 클라우드 플랫폼에 배포하십시오 + ##### Azure Global -[terraform](https://www.terraform.io/)을 사용하여 Azure에 Dify를 원클릭으로 배포하세요. - [nikawang의 Azure Terraform](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [sotazum의 Google Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform) + ## 기여 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. diff --git a/README_PT.md b/README_PT.md new file mode 100644 index 00000000000000..3d66b768023f17 --- /dev/null +++ b/README_PT.md @@ -0,0 +1,241 @@ +![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) + +

+ 📌 Introduzindo o Dify Workflow com Upload de Arquivo: Recrie o Podcast Google NotebookLM +

+ +

+ Dify Cloud · + Auto-hospedagem · + Documentação · + Consultas empresariais +

+ +

+ + Static Badge + + Static Badge + + chat on Discord + + follow on X(Twitter) + + Docker Pulls + + Commits last month + + Issues closed + + Discussion posts +

+ +

+ README em Inglês + 简体中文版自述文件 + 日本語のREADME + README em Espanhol + README em Francês + README tlhIngan Hol + README em Coreano + README em Árabe + README em Turco + README em Vietnamita + README em Português - BR +

+ +Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. Sua interface intuitiva combina workflow de IA, pipeline RAG, capacidades de agente, gerenciamento de modelos, recursos de observabilidade e muito mais, permitindo que você vá rapidamente do protótipo à produção. Aqui está uma lista das principais funcionalidades: +

+ +**1. Workflow**: + Construa e teste workflows poderosos de IA em uma interface visual, aproveitando todos os recursos a seguir e muito mais. + + + https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa + + + +**2. Suporte abrangente a modelos**: + Integração perfeita com centenas de LLMs proprietários e de código aberto de diversas provedoras e soluções auto-hospedadas, abrangendo GPT, Mistral, Llama3 e qualquer modelo compatível com a API da OpenAI. A lista completa de provedores suportados pode ser encontrada [aqui](https://docs.dify.ai/getting-started/readme/model-providers). + +![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) + + +**3. IDE de Prompt**: + Interface intuitiva para criação de prompts, comparação de desempenho de modelos e adição de recursos como conversão de texto para fala em um aplicativo baseado em chat. + +**4. Pipeline RAG**: + Extensas capacidades de RAG que cobrem desde a ingestão de documentos até a recuperação, com suporte nativo para extração de texto de PDFs, PPTs e outros formatos de documentos comuns. + +**5. Capacidades de agente**: + Você pode definir agentes com base em LLM Function Calling ou ReAct e adicionar ferramentas pré-construídas ou personalizadas para o agente. O Dify oferece mais de 50 ferramentas integradas para agentes de IA, como Google Search, DALL·E, Stable Diffusion e WolframAlpha. + +**6. LLMOps**: + Monitore e analise os registros e o desempenho do aplicativo ao longo do tempo. É possível melhorar continuamente prompts, conjuntos de dados e modelos com base nos dados de produção e anotações. + +**7. Backend como Serviço**: + Todas os recursos do Dify vêm com APIs correspondentes, permitindo que você integre o Dify sem esforço na lógica de negócios da sua empresa. + + +## Comparação de recursos + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
| Recurso | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
| --- | --- | --- | --- | --- |
| Abordagem de Programação | Orientada a API + Aplicativo | Código Python | Orientada a Aplicativo | Orientada a API |
| LLMs Suportados | Variedade Rica | Variedade Rica | Variedade Rica | Apenas OpenAI |
| RAG Engine | ✅ | ✅ | ✅ | ✅ |
| Agente | ✅ | ✅ | ❌ | ✅ |
| Workflow | ✅ | ❌ | ✅ | ❌ |
| Observabilidade | ✅ | ✅ | ❌ | ❌ |
| Recursos Empresariais (SSO/Controle de Acesso) | ✅ | ❌ | ❌ | ❌ |
| Implantação Local | ✅ | ✅ | ✅ | ❌ |
+ +## Usando o Dify + +- **Nuvem
** +Oferecemos o serviço [Dify Cloud](https://dify.ai) para qualquer pessoa experimentar sem nenhuma configuração. Ele fornece todas as funcionalidades da versão auto-hospedada, incluindo 200 chamadas GPT-4 gratuitas no plano sandbox. + +- **Auto-hospedagem do Dify Community Edition
** +Configure rapidamente o Dify no seu ambiente com este [guia inicial](#quick-start). +Use nossa [documentação](https://docs.dify.ai) para referências adicionais e instruções mais detalhadas. + +- **Dify para empresas/organizações
** +Oferecemos recursos adicionais voltados para empresas. [Envie suas perguntas através deste chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) ou [envie-nos um e-mail](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) para discutir necessidades empresariais.
+ > Para startups e pequenas empresas que utilizam AWS, confira o [Dify Premium no AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) e implemente no seu próprio AWS VPC com um clique. É uma oferta AMI acessível com a opção de criar aplicativos com logotipo e marca personalizados. + + +## Mantendo-se atualizado + +Dê uma estrela no Dify no GitHub e seja notificado imediatamente sobre novos lançamentos. + +![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) + + + +## Início rápido +> Antes de instalar o Dify, certifique-se de que sua máquina atenda aos seguintes requisitos mínimos de sistema: +> +>- CPU >= 2 Núcleos +>- RAM >= 4 GiB + +
+ +A maneira mais fácil de iniciar o servidor Dify é executar nosso arquivo [docker-compose.yml](docker/docker-compose.yaml). Antes de rodar o comando de instalação, certifique-se de que o [Docker](https://docs.docker.com/get-docker/) e o [Docker Compose](https://docs.docker.com/compose/install/) estão instalados na sua máquina: + +```bash +cd docker +cp .env.example .env +docker compose up -d +``` + +Após a execução, você pode acessar o painel do Dify no navegador em [http://localhost/install](http://localhost/install) e iniciar o processo de inicialização. + +> Se você deseja contribuir com o Dify ou fazer desenvolvimento adicional, consulte nosso [guia para implantar a partir do código fonte](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code). + +## Próximos passos + +Se precisar personalizar a configuração, consulte os comentários no nosso arquivo [.env.example](docker/.env.example) e atualize os valores correspondentes no seu arquivo `.env`. Além disso, talvez seja necessário fazer ajustes no próprio arquivo `docker-compose.yaml`, como alterar versões de imagem, mapeamentos de portas ou montagens de volumes, com base no seu ambiente de implantação específico e nas suas necessidades. Após fazer quaisquer alterações, execute novamente `docker-compose up -d`. Você pode encontrar a lista completa de variáveis de ambiente disponíveis [aqui](https://docs.dify.ai/getting-started/install-self-hosted/environments). + +Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts](https://helm.sh/) e arquivos YAML contribuídos pela comunidade que permitem a implantação do Dify no Kubernetes. + +- [Helm Chart de @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) +- [Helm Chart de @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) +- [Arquivo YAML de @Winson-030](https://github.com/Winson-030/dify-kubernetes) + +#### Usando o Terraform para Implantação + +Implante o Dify na Plataforma Cloud com um único clique usando [terraform](https://www.terraform.io/) + +##### Azure Global +- [Azure Terraform por @nikawang](https://github.com/nikawang/dify-azure-terraform) + +##### Google Cloud +- [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) + +## Contribuindo + +Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências. + +> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c). + +**Contribuidores** + + + + + +## Comunidade e contato + +* [Discussões no GitHub](https://github.com/langgenius/dify/discussions). Melhor para: compartilhar feedback e fazer perguntas. +* [Problemas no GitHub](https://github.com/langgenius/dify/issues). Melhor para: relatar bugs encontrados no Dify.AI e propor novos recursos. Veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). +* [Discord](https://discord.gg/FngNHpbcY7). Melhor para: compartilhar suas aplicações e interagir com a comunidade. +* [X(Twitter)](https://twitter.com/dify_ai). 
Melhor para: compartilhar suas aplicações e interagir com a comunidade. + +## Histórico de estrelas + +[![Gráfico de Histórico de Estrelas](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) + +## Divulgação de segurança + +Para proteger sua privacidade, evite postar problemas de segurança no GitHub. Em vez disso, envie suas perguntas para security@dify.ai e forneceremos uma resposta mais detalhada. + +## Licença + +Este repositório está disponível sob a [Licença de Código Aberto Dify](LICENSE), que é essencialmente Apache 2.0 com algumas restrições adicionais. \ No newline at end of file diff --git a/README_TR.md b/README_TR.md index 54b6db3f823717..a75889e5760634 100644 --- a/README_TR.md +++ b/README_TR.md @@ -17,7 +17,7 @@ alt="Discord'da sohbet et"> Twitter'da takip et + alt="X(Twitter)'da takip et"> Docker Çekmeleri @@ -200,9 +200,13 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify' #### Dağıtım için Terraform Kullanımı +Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.terraform.io/) kullanarak + ##### Azure Global -[Terraform](https://www.terraform.io/) kullanarak Dify'ı Azure'a tek tıklamayla dağıtın. -- [@nikawang tarafından Azure Terraform](https://github.com/nikawang/dify-azure-terraform) +- [Azure Terraform tarafından @nikawang](https://github.com/nikawang/dify-azure-terraform) + +##### Google Cloud +- [Google Cloud Terraform tarafından @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) ## Katkıda Bulunma @@ -222,7 +226,7 @@ Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda p * [Github Tartışmaları](https://github.com/langgenius/dify/discussions). En uygun: geri bildirim paylaşmak ve soru sormak için. * [GitHub Sorunları](https://github.com/langgenius/dify/issues). En uygun: Dify.AI kullanırken karşılaştığınız hatalar ve özellik önerileri için. [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakın. * [Discord](https://discord.gg/FngNHpbcY7). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. -* [Twitter](https://twitter.com/dify_ai). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. +* [X(Twitter)](https://twitter.com/dify_ai). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. ## Star history diff --git a/README_VI.md b/README_VI.md index 6d4035eceb06de..8d49e4976626b6 100644 --- a/README_VI.md +++ b/README_VI.md @@ -17,7 +17,7 @@ alt="chat trên Discord"> theo dõi trên Twitter + alt="theo dõi trên X(Twitter)"> Docker Pulls @@ -196,10 +196,14 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có #### Sử dụng Terraform để Triển khai +Triển khai Dify lên nền tảng đám mây với một cú nhấp chuột bằng cách sử dụng [terraform](https://www.terraform.io/) + ##### Azure Global -Triển khai Dify lên Azure chỉ với một cú nhấp chuột bằng cách sử dụng [terraform](https://www.terraform.io/). - [Azure Terraform bởi @nikawang](https://github.com/nikawang/dify-azure-terraform) +##### Google Cloud +- [Google Cloud Terraform bởi @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) + ## Đóng góp Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. @@ -219,7 +223,7 @@ Triển khai Dify lên Azure chỉ với một cú nhấp chuột bằng cách s * [Thảo luận GitHub](https://github.com/langgenius/dify/discussions). Tốt nhất cho: chia sẻ phản hồi và đặt câu hỏi. 
 * [Vấn đề GitHub](https://github.com/langgenius/dify/issues). Tốt nhất cho: lỗi bạn gặp phải khi sử dụng Dify.AI và đề xuất tính năng. Xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
 * [Discord](https://discord.gg/FngNHpbcY7). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng.
-* [Twitter](https://twitter.com/dify_ai). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng.
+* [X(Twitter)](https://twitter.com/dify_ai). Tốt nhất cho: chia sẻ ứng dụng của bạn và giao lưu với cộng đồng.
 
 ## Lịch sử Yêu thích
 
diff --git a/api/.env.example b/api/.env.example
index 775149f8fd3888..5751605b48119c 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
 # The time in seconds after the signature is rejected
 FILES_ACCESS_TIMEOUT=300
 
+# Access token expiration time in minutes
+ACCESS_TOKEN_EXPIRE_MINUTES=60
+
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 
@@ -28,8 +31,17 @@ REDIS_HOST=localhost
 REDIS_PORT=6379
 REDIS_USERNAME=
 REDIS_PASSWORD=difyai123456
+REDIS_USE_SSL=false
 REDIS_DB=0
 
+# redis Sentinel configuration.
+REDIS_USE_SENTINEL=false
+REDIS_SENTINELS=
+REDIS_SENTINEL_SERVICE_NAME=
+REDIS_SENTINEL_USERNAME=
+REDIS_SENTINEL_PASSWORD=
+REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
+
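The Redis Sentinel settings introduced above ship empty. As a rough sketch of a filled-in configuration (assuming the comma-separated host:port list format the variable names suggest; the hostnames, master name, and password below are placeholders):

```bash
# Hypothetical Sentinel setup: three sentinels watching a master named "mymaster".
REDIS_USE_SENTINEL=true
REDIS_SENTINELS=sentinel-1:26379,sentinel-2:26379,sentinel-3:26379
REDIS_SENTINEL_SERVICE_NAME=mymaster
REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD=difyai123456
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
```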
 # PostgreSQL database configuration
 DB_USERNAME=postgres
 DB_PASSWORD=difyai123456
@@ -39,7 +51,7 @@ DB_DATABASE=dify
 
 # Storage configuration
 # use for store upload files, private keys...
-# storage type: local, s3, azure-blob, google-storage
+# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
 STORAGE_TYPE=local
 STORAGE_LOCAL_PATH=storage
 S3_USE_AWS_MANAGED_IAM=false
@@ -60,7 +72,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
 ALIYUN_OSS_ENDPOINT=your-endpoint
 ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
-
+# Don't start with '/'. OSS doesn't support leading slash in object names.
+ALIYUN_OSS_PATH=your-path
 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
 GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@@ -72,6 +85,18 @@ TENCENT_COS_SECRET_ID=your-secret-id
 TENCENT_COS_REGION=your-region
 TENCENT_COS_SCHEME=your-scheme
 
+# Huawei OBS Storage Configuration
+HUAWEI_OBS_BUCKET_NAME=your-bucket-name
+HUAWEI_OBS_SECRET_KEY=your-secret-key
+HUAWEI_OBS_ACCESS_KEY=your-access-key
+HUAWEI_OBS_SERVER=your-server-url
+
+# Baidu OBS Storage Configuration
+BAIDU_OBS_BUCKET_NAME=your-bucket-name
+BAIDU_OBS_SECRET_KEY=your-secret-key
+BAIDU_OBS_ACCESS_KEY=your-access-key
+BAIDU_OBS_ENDPOINT=your-server-url
+
 # OCI Storage configuration
 OCI_ENDPOINT=your-endpoint
 OCI_BUCKET_NAME=your-bucket-name
@@ -79,11 +104,24 @@ OCI_ACCESS_KEY=your-access-key
 OCI_SECRET_KEY=your-secret-key
 OCI_REGION=your-region
 
+# Volcengine tos Storage configuration
+VOLCENGINE_TOS_ENDPOINT=your-endpoint
+VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
+VOLCENGINE_TOS_ACCESS_KEY=your-access-key
+VOLCENGINE_TOS_SECRET_KEY=your-secret-key
+VOLCENGINE_TOS_REGION=your-region
+
+# Supabase Storage Configuration
+SUPABASE_BUCKET_NAME=your-bucket-name
+SUPABASE_API_KEY=your-access-key
+SUPABASE_URL=your-server-url
+
 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 
-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
+
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
 VECTOR_STORE=weaviate
 
 # Weaviate configuration
@@ -99,12 +137,18 @@ QDRANT_CLIENT_TIMEOUT=20
 QDRANT_GRPC_ENABLED=false
 QDRANT_GRPC_PORT=6334
 
+# Couchbase configuration
+COUCHBASE_CONNECTION_STRING=127.0.0.1
+COUCHBASE_USER=Administrator
+COUCHBASE_PASSWORD=password
+COUCHBASE_BUCKET_NAME=Embeddings
+COUCHBASE_SCOPE_NAME=_default
+
 # Milvus configuration
-MILVUS_HOST=127.0.0.1
-MILVUS_PORT=19530
+MILVUS_URI=http://127.0.0.1:19530
+MILVUS_TOKEN=
 MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
-MILVUS_SECURE=false
 
 # MyScale configuration
 MYSCALE_HOST=127.0.0.1
@@ -149,6 +193,8 @@ PGVECTOR_PORT=5433
 PGVECTOR_USER=postgres
 PGVECTOR_PASSWORD=postgres
 PGVECTOR_DATABASE=postgres
+PGVECTOR_MIN_CONNECTION=1
+PGVECTOR_MAX_CONNECTION=5
 
 # Tidb Vector configuration
 TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
@@ -157,6 +203,20 @@ TIDB_VECTOR_USER=xxx.root
 TIDB_VECTOR_PASSWORD=xxxxxx
 TIDB_VECTOR_DATABASE=dify
 
+# Tidb on qdrant configuration
+TIDB_ON_QDRANT_URL=http://127.0.0.1
+TIDB_ON_QDRANT_API_KEY=dify
+TIDB_ON_QDRANT_CLIENT_TIMEOUT=20
+TIDB_ON_QDRANT_GRPC_ENABLED=false
+TIDB_ON_QDRANT_GRPC_PORT=6334
+TIDB_PUBLIC_KEY=dify
+TIDB_PRIVATE_KEY=dify
+TIDB_API_URL=http://127.0.0.1
+TIDB_IAM_API_URL=http://127.0.0.1
+TIDB_REGION=regions/aws-us-east-1
+TIDB_PROJECT_ID=dify
+TIDB_SPEND_LIMIT=100
+
 # Chroma configuration
 CHROMA_HOST=127.0.0.1
 CHROMA_PORT=8000
@@ -182,14 +242,54 @@ OPENSEARCH_USER=admin
 OPENSEARCH_PASSWORD=admin
 OPENSEARCH_SECURE=true
 
+# Baidu configuration
+BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
+BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
+BAIDU_VECTOR_DB_ACCOUNT=root
+BAIDU_VECTOR_DB_API_KEY=dify
+BAIDU_VECTOR_DB_DATABASE=dify
+BAIDU_VECTOR_DB_SHARD=1
+BAIDU_VECTOR_DB_REPLICAS=3
+
+# Upstash configuration
+UPSTASH_VECTOR_URL=your-server-url
+UPSTASH_VECTOR_TOKEN=your-access-token + +# ViKingDB configuration +VIKINGDB_ACCESS_KEY=your-ak +VIKINGDB_SECRET_KEY=your-sk +VIKINGDB_REGION=cn-shanghai +VIKINGDB_HOST=api-vikingdb.xxx.volces.com +VIKINGDB_SCHEMA=http +VIKINGDB_CONNECTION_TIMEOUT=30 +VIKINGDB_SOCKET_TIMEOUT=30 + +# Lindorm configuration +LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070 +LINDORM_USERNAME=admin +LINDORM_PASSWORD=admin + +# OceanBase Vector configuration +OCEANBASE_VECTOR_HOST=127.0.0.1 +OCEANBASE_VECTOR_PORT=2881 +OCEANBASE_VECTOR_USER=root@test +OCEANBASE_VECTOR_PASSWORD=difyai123456 +OCEANBASE_VECTOR_DATABASE=test +OCEANBASE_MEMORY_LIMIT=6G + + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 -# Model Configuration +# Model configuration MULTIMODAL_SEND_IMAGE_FORMAT=base64 +MULTIMODAL_SEND_VIDEO_FORMAT=base64 PROMPT_GENERATION_MAX_TOKENS=512 +CODE_GENERATION_MAX_TOKENS=1024 # Mail configuration, support: resend, smtp MAIL_TYPE= @@ -221,13 +321,21 @@ ETL_TYPE=dify UNSTRUCTURED_API_URL= UNSTRUCTURED_API_KEY= +#ssrf SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= SSRF_DEFAULT_MAX_RETRIES=3 +SSRF_DEFAULT_TIME_OUT=5 +SSRF_DEFAULT_CONNECT_TIME_OUT=5 +SSRF_DEFAULT_READ_TIME_OUT=5 +SSRF_DEFAULT_WRITE_TIME_OUT=5 BATCH_UPLOAD_LIMIT=10 KEYWORD_DATA_SOURCE_TYPE=database +# Workflow file upload limit +WORKFLOW_FILE_UPLOAD_LIMIT=10 + # CODE EXECUTION CONFIGURATION CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 CODE_EXECUTION_API_KEY=dify-sandbox @@ -247,11 +355,22 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60 HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300 HTTP_REQUEST_MAX_READ_TIMEOUT=600 HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 -HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB -HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB +HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 +HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 + +# Respect X-* headers to redirect clients +RESPECT_XFORWARD_HEADERS_ENABLED=false # Log file path LOG_FILE= +# Log file max size, the unit is MB +LOG_FILE_MAX_SIZE=20 +# Log file max backup count +LOG_FILE_BACKUP_COUNT=5 +# Log dateformat +LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S +# Log Timezone +LOG_TZ=UTC # Indexing configuration INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000 @@ -260,6 +379,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000 WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 +MAX_VARIABLE_SIZE=204800 # App configuration APP_MAX_EXECUTION_TIME=1200 @@ -267,4 +387,18 @@ APP_MAX_ACTIVE_REQUESTS=0 # Celery beat configuration -CELERY_BEAT_SCHEDULER_TIME=1 \ No newline at end of file +CELERY_BEAT_SCHEDULER_TIME=1 + +# Position configuration +POSITION_TOOL_PINS= +POSITION_TOOL_INCLUDES= +POSITION_TOOL_EXCLUDES= + +POSITION_PROVIDER_PINS= +POSITION_PROVIDER_INCLUDES= +POSITION_PROVIDER_EXCLUDES= + +# Reset password token expiry minutes +RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 + +CREATE_TIDB_SERVICE_JOB_ENABLED=false \ No newline at end of file diff --git a/.idea/icon.png b/api/.idea/icon.png similarity index 100% rename from .idea/icon.png rename to api/.idea/icon.png diff --git a/.idea/vcs.xml b/api/.idea/vcs.xml similarity index 88% rename from .idea/vcs.xml rename to api/.idea/vcs.xml index ae8b1755c52a17..b7af618884ac3b 100644 --- a/.idea/vcs.xml +++ b/api/.idea/vcs.xml @@ -12,5 +12,6 @@ + - \ No newline at end of file + diff --git a/api/.vscode/launch.json.example b/api/.vscode/launch.json.example new file mode 100644 index 
00000000000000..b9e32e2511a0ca
--- /dev/null
+++ b/api/.vscode/launch.json.example
@@ -0,0 +1,61 @@
+{
+  "version": "0.2.0",
+  "compounds": [
+    {
+      "name": "Launch Flask and Celery",
+      "configurations": ["Python: Flask", "Python: Celery"]
+    }
+  ],
+  "configurations": [
+    {
+      "name": "Python: Flask",
+      "consoleName": "Flask",
+      "type": "debugpy",
+      "request": "launch",
+      "python": "${workspaceFolder}/.venv/bin/python",
+      "cwd": "${workspaceFolder}",
+      "envFile": ".env",
+      "module": "flask",
+      "justMyCode": true,
+      "jinja": true,
+      "env": {
+        "FLASK_APP": "app.py",
+        "GEVENT_SUPPORT": "True"
+      },
+      "args": [
+        "run",
+        "--port=5001"
+      ]
+    },
+    {
+      "name": "Python: Celery",
+      "consoleName": "Celery",
+      "type": "debugpy",
+      "request": "launch",
+      "python": "${workspaceFolder}/.venv/bin/python",
+      "cwd": "${workspaceFolder}",
+      "module": "celery",
+      "justMyCode": true,
+      "envFile": ".env",
+      "console": "integratedTerminal",
+      "env": {
+        "FLASK_APP": "app.py",
+        "FLASK_DEBUG": "1",
+        "GEVENT_SUPPORT": "True"
+      },
+      "args": [
+        "-A",
+        "app.celery",
+        "worker",
+        "-P",
+        "gevent",
+        "-c",
+        "1",
+        "--loglevel",
+        "DEBUG",
+        "-Q",
+        "dataset,generation,mail,ops_trace,app_deletion"
+      ]
+    }
+  ]
+}
diff --git a/api/Dockerfile b/api/Dockerfile
index 06a6f43631e3ab..51e2a10506474e 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -4,7 +4,11 @@ FROM python:3.10-slim-bookworm AS base
 WORKDIR /app/api
 
 # Install Poetry
-ENV POETRY_VERSION=1.8.3
+ENV POETRY_VERSION=1.8.4
+
+# if you are located in China, you can use the Aliyun mirror to speed up
+# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
+
 RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
 
 # Configure Poetry
@@ -16,6 +20,9 @@ ENV POETRY_REQUESTS_TIMEOUT=15
 
 FROM base AS packages
 
+# if you are located in China, you can use the Aliyun mirror to speed up
+# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
+
 RUN apt-get update \
     && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
 
@@ -43,10 +50,14 @@ WORKDIR /app/api
 
 RUN apt-get update \
     && apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
+    # if you are located in China, you can use the Aliyun mirror to speed up
+    # && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
     && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
     && apt-get update \
     # For Security
-    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
+    && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
+    # install a Chinese font to support the use of tools like matplotlib
+    && apt-get install -y fonts-noto-cjk \
     && apt-get autoremove -y \
     && rm -rf /var/lib/apt/lists/*
 
@@ -56,7 +67,7 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
 ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
 
 # Download nltk data
-RUN python -c "import nltk; nltk.download('punkt')"
+RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
 
 # Copy source code
 COPY . /app/api/
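The "Python: Celery" launch configuration added above maps one-to-one onto the worker command documented in api/README.md (next file below). Running the same thing from a terminal looks roughly like this, assuming the api/ directory with dependencies installed via Poetry; the DEBUG log level matches the launch.json args:

```bash
# Terminal equivalent of the "Python: Celery" debug configuration
# (run from the api/ directory after installing dependencies with Poetry).
poetry run python -m celery -A app.celery worker -P gevent -c 1 \
  --loglevel DEBUG -Q dataset,generation,mail,ops_trace,app_deletion
```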
diff --git a/api/README.md b/api/README.md
index 70ca2e86a85a7d..de2baee4c5b8a7 100644
--- a/api/README.md
+++ b/api/README.md
@@ -65,25 +65,24 @@ 8. Start Dify [web](../web) service.
 9. Setup your application by visiting `http://localhost:3000`...
 
-10. If you need to debug local async processing, please start the worker service.
+10. If you need to handle and debug the async tasks (e.g. dataset importing and document indexing), please start the worker service.
 
    ```bash
    poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
    ```
 
-   The started celery app handles the async tasks, e.g. dataset importing and documents indexing.
-
 ## Testing
 
 1. Install dependencies for both the backend and the test environment
 
    ```bash
-   poetry install --with dev
+   poetry install -C api --with dev
    ```
 
 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
 
    ```bash
-   cd ../
    poetry run -C api bash dev/pytest/pytest_all_tests.sh
    ```
+
+
diff --git a/api/app.py b/api/app.py
index ad219ca0d67459..60cd622ef4d0a8 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,6 +1,8 @@
 import os
 
-if os.environ.get("DEBUG", "false").lower() != "true":
+from configs import dify_config
+
+if not dify_config.DEBUG:
     from gevent import monkey
 
     monkey.patch_all()
@@ -10,243 +12,37 @@
     grpc.experimental.gevent.init_gevent()
 
 import json
-import logging
-import sys
 import threading
 import time
 import warnings
-from logging.handlers import RotatingFileHandler
 
-from flask import Flask, Response, request
-from flask_cors import CORS
-from werkzeug.exceptions import Unauthorized
+from flask import Response
 
-import contexts
-from commands import register_commands
-from configs import dify_config
+from app_factory import create_app
 
 # DO NOT REMOVE BELOW
-from events import event_handlers
-from extensions import (
-    ext_celery,
-    ext_code_based_extension,
-    ext_compress,
-    ext_database,
-    ext_hosting_provider,
-    ext_login,
-    ext_mail,
-    ext_migrate,
-    ext_redis,
-    ext_sentry,
-    ext_storage,
-)
+from events import event_handlers  # noqa: F401
 from extensions.ext_database import db
-from extensions.ext_login import login_manager
-from libs.passport import PassportService
 
 # TODO: Find a way to avoid importing models here
-from models import account, dataset, model, source, task, tool, tools, web
-from services.account_service import AccountService
+from models import account, dataset, model, source, task, tool, tools, web  # noqa: F401
 
 # DO NOT REMOVE ABOVE
 
 warnings.simplefilter("ignore", ResourceWarning)
 
-# fix windows platform
-if os.name == "nt":
-    os.system('tzutil /s "UTC"')
-else:
-    os.environ["TZ"] = "UTC"
+os.environ["TZ"] = "UTC"
+# Windows platform doesn't support tzset
+if hasattr(time, "tzset"):
     time.tzset()
 
 
-class DifyApp(Flask):
-    pass
-
-
-# -------------
-# Configuration
-# -------------
-
-
-config_type = os.getenv("EDITION", default="SELF_HOSTED")  # ce edition first
-
-
-# ----------------------------
-# Application Factory Function
-# ----------------------------
-
-
-def create_flask_app_with_configs() -> Flask:
-    """
-    create a raw flask app
-    with configs loaded from .env file
-    """
-    dify_app = DifyApp(__name__)
-    dify_app.config.from_mapping(dify_config.model_dump())
-
-    # populate configs into system environment variables
-    for key, value in dify_app.config.items():
-        if isinstance(value, str):
-            os.environ[key] = value
-        elif isinstance(value, int | float | bool):
-            os.environ[key] = str(value)
-        elif value is None:
-            os.environ[key] = ""
-
-    return dify_app
-
-
-def create_app() -> Flask:
-    app = create_flask_app_with_configs()
-
-    app.secret_key = app.config["SECRET_KEY"]
-
-    log_handlers
= None - log_file = app.config.get("LOG_FILE") - if log_file: - log_dir = os.path.dirname(log_file) - os.makedirs(log_dir, exist_ok=True) - log_handlers = [ - RotatingFileHandler( - filename=log_file, - maxBytes=1024 * 1024 * 1024, - backupCount=5, - ), - logging.StreamHandler(sys.stdout), - ] - - logging.basicConfig( - level=app.config.get("LOG_LEVEL"), - format=app.config.get("LOG_FORMAT"), - datefmt=app.config.get("LOG_DATEFORMAT"), - handlers=log_handlers, - force=True, - ) - log_tz = app.config.get("LOG_TZ") - if log_tz: - from datetime import datetime - - import pytz - - timezone = pytz.timezone(log_tz) - - def time_converter(seconds): - return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() - - for handler in logging.root.handlers: - handler.formatter.converter = time_converter - initialize_extensions(app) - register_blueprints(app) - register_commands(app) - - return app - - -def initialize_extensions(app): - # Since the application instance is now created, pass it to each Flask - # extension instance to bind it to the Flask application instance (app) - ext_compress.init_app(app) - ext_code_based_extension.init() - ext_database.init_app(app) - ext_migrate.init(app, db) - ext_redis.init_app(app) - ext_storage.init_app(app) - ext_celery.init_app(app) - ext_login.init_app(app) - ext_mail.init_app(app) - ext_hosting_provider.init_app(app) - ext_sentry.init_app(app) - - -# Flask-Login configuration -@login_manager.request_loader -def load_user_from_request(request_from_flask_login): - """Load user based on the request.""" - if request.blueprint not in ["console", "inner_api"]: - return None - # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get("Authorization", "") - if not auth_header: - auth_token = request.args.get("_token") - if not auth_token: - raise Unauthorized("Invalid Authorization token.") - else: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") - - decoded = PassportService().verify(auth_token) - user_id = decoded.get("user_id") - - account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) - if account: - contexts.tenant_id.set(account.current_tenant_id) - return account - - -@login_manager.unauthorized_handler -def unauthorized_handler(): - """Handle unauthorized requests.""" - return Response( - json.dumps({"code": "unauthorized", "message": "Unauthorized."}), - status=401, - content_type="application/json", - ) - - -# register blueprint routers -def register_blueprints(app): - from controllers.console import bp as console_app_bp - from controllers.files import bp as files_bp - from controllers.inner_api import bp as inner_api_bp - from controllers.service_api import bp as service_api_bp - from controllers.web import bp as web_bp - - CORS( - service_api_bp, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - ) - app.register_blueprint(service_api_bp) - - CORS( - web_bp, - resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(web_bp) - - CORS( - console_app_bp, - resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(console_app_bp) - - CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) - app.register_blueprint(files_bp) - - app.register_blueprint(inner_api_bp) - - # create app app = create_app() celery = app.extensions["celery"] -if app.config.get("TESTING"): +if dify_config.TESTING: print("App is running in TESTING mode") @@ -254,15 +50,15 @@ def register_blueprints(app): def after_request(response): """Add Version headers to the response.""" response.set_cookie("remember_token", "", expires=0) - response.headers.add("X-Version", app.config["CURRENT_VERSION"]) - response.headers.add("X-Env", app.config["DEPLOY_ENV"]) + response.headers.add("X-Version", dify_config.CURRENT_VERSION) + response.headers.add("X-Env", dify_config.DEPLOY_ENV) return response @app.route("/health") def health(): return Response( - json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}), + json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}), status=200, content_type="application/json", ) diff --git a/api/app_factory.py b/api/app_factory.py new file mode 100644 index 00000000000000..60a584798b608a --- /dev/null +++ b/api/app_factory.py @@ -0,0 +1,178 @@ +import os + +from configs import dify_config + +if not dify_config.DEBUG: + from gevent import monkey + + monkey.patch_all() + + import grpc.experimental.gevent + + grpc.experimental.gevent.init_gevent() + +import json + +from flask import Flask, Response, request +from flask_cors import CORS +from werkzeug.exceptions import Unauthorized + +import contexts +from commands import register_commands +from configs import dify_config +from extensions import ( + ext_celery, + ext_code_based_extension, + ext_compress, + ext_database, + ext_hosting_provider, + ext_logging, + 
ext_login, + ext_mail, + ext_migrate, + ext_proxy_fix, + ext_redis, + ext_sentry, + ext_storage, +) +from extensions.ext_database import db +from extensions.ext_login import login_manager +from libs.passport import PassportService +from services.account_service import AccountService + + +class DifyApp(Flask): + pass + + +# ---------------------------- +# Application Factory Function +# ---------------------------- +def create_flask_app_with_configs() -> Flask: + """ + create a raw flask app + with configs loaded from .env file + """ + dify_app = DifyApp(__name__) + dify_app.config.from_mapping(dify_config.model_dump()) + + # populate configs into system environment variables + for key, value in dify_app.config.items(): + if isinstance(value, str): + os.environ[key] = value + elif isinstance(value, int | float | bool): + os.environ[key] = str(value) + elif value is None: + os.environ[key] = "" + + return dify_app + + +def create_app() -> Flask: + app = create_flask_app_with_configs() + app.secret_key = dify_config.SECRET_KEY + initialize_extensions(app) + register_blueprints(app) + register_commands(app) + + return app + + +def initialize_extensions(app): + # Since the application instance is now created, pass it to each Flask + # extension instance to bind it to the Flask application instance (app) + ext_logging.init_app(app) + ext_compress.init_app(app) + ext_code_based_extension.init() + ext_database.init_app(app) + ext_migrate.init(app, db) + ext_redis.init_app(app) + ext_storage.init_app(app) + ext_celery.init_app(app) + ext_login.init_app(app) + ext_mail.init_app(app) + ext_hosting_provider.init_app(app) + ext_sentry.init_app(app) + ext_proxy_fix.init_app(app) + + +# Flask-Login configuration +@login_manager.request_loader +def load_user_from_request(request_from_flask_login): + """Load user based on the request.""" + if request.blueprint not in {"console", "inner_api"}: + return None + # Check if the user_id contains a dot, indicating the old format + auth_header = request.headers.get("Authorization", "") + if not auth_header: + auth_token = request.args.get("_token") + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + else: + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") + + decoded = PassportService().verify(auth_token) + user_id = decoded.get("user_id") + + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + if logged_in_account: + contexts.tenant_id.set(logged_in_account.current_tenant_id) + return logged_in_account + + +@login_manager.unauthorized_handler +def unauthorized_handler(): + """Handle unauthorized requests.""" + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) + + +# register blueprint routers +def register_blueprints(app): + from controllers.console import bp as console_app_bp + from controllers.files import bp as files_bp + from controllers.inner_api import bp as inner_api_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) + app.register_blueprint(service_api_bp) + + CORS( + web_bp, + resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(web_bp) + + CORS( + console_app_bp, + resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(console_app_bp) + + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + app.register_blueprint(files_bp) + + app.register_blueprint(inner_api_bp) diff --git a/api/commands.py b/api/commands.py index 41f1a6444c4581..10122ceb3dea2b 100644 --- a/api/commands.py +++ b/api/commands.py @@ -19,7 +19,7 @@ from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair -from models.account import Tenant +from models import Tenant from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation @@ -28,28 +28,28 @@ @click.command("reset-password", help="Reset the account password.") -@click.option("--email", prompt=True, help="The email address of the account whose password you need to reset") -@click.option("--new-password", prompt=True, help="the new password.") -@click.option("--password-confirm", prompt=True, help="the new password confirm.") +@click.option("--email", prompt=True, help="Account email to reset password for") +@click.option("--new-password", prompt=True, help="New password") +@click.option("--password-confirm", prompt=True, help="Confirm new password") def reset_password(email, new_password, password_confirm): """ Reset password of owner account Only available in SELF_HOSTED mode """ if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style("sorry. 
The two passwords do not match.", fg="red")) + click.echo(click.style("Passwords do not match.", fg="red")) return account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) + click.echo(click.style("Account not found for email: {}".format(email), fg="red")) return try: valid_password(new_password) except: - click.echo(click.style("sorry. The passwords must match {} ".format(password_pattern), fg="red")) + click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red")) return # generate password salt @@ -62,37 +62,37 @@ def reset_password(email, new_password, password_confirm): account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - click.echo(click.style("Congratulations! Password has been reset.", fg="green")) + click.echo(click.style("Password reset successfully.", fg="green")) @click.command("reset-email", help="Reset the account email.") -@click.option("--email", prompt=True, help="The old email address of the account whose email you need to reset") -@click.option("--new-email", prompt=True, help="the new email.") -@click.option("--email-confirm", prompt=True, help="the new email confirm.") +@click.option("--email", prompt=True, help="Current account email") +@click.option("--new-email", prompt=True, help="New email") +@click.option("--email-confirm", prompt=True, help="Confirm new email") def reset_email(email, new_email, email_confirm): """ Replace account email :return: """ if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style("Sorry, new email and confirm email do not match.", fg="red")) + click.echo(click.style("New emails do not match.", fg="red")) return account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo(click.style("sorry. the account: [{}] not exist .".format(email), fg="red")) + click.echo(click.style("Account not found for email: {}".format(email), fg="red")) return try: email_validate(new_email) except: - click.echo(click.style("sorry. {} is not a valid email. ".format(email), fg="red")) + click.echo(click.style("Invalid email: {}".format(new_email), fg="red")) return account.email = new_email db.session.commit() - click.echo(click.style("Congratulations!, email has been reset.", fg="green")) + click.echo(click.style("Email updated successfully.", fg="green")) @click.command( @@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm): ) @click.confirmation_option( prompt=click.style( - "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red" + "Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red" ) ) def reset_encrypt_key_pair(): @@ -114,13 +114,13 @@ def reset_encrypt_key_pair(): Only support SELF_HOSTED mode. """ if dify_config.EDITION != "SELF_HOSTED": - click.echo(click.style("Sorry, only support SELF_HOSTED mode.", fg="red")) + click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) return tenants = db.session.query(Tenant).all() for tenant in tenants: if not tenant: - click.echo(click.style("Sorry, no workspace found. Please enter /install to initialize.", fg="red")) + click.echo(click.style("No workspaces found. 
Run /install first.", fg="red")) return tenant.encrypt_public_key = generate_key_pair(tenant.id) @@ -131,18 +131,18 @@ def reset_encrypt_key_pair(): click.echo( click.style( - "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id), + "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id), fg="green", ) ) -@click.command("vdb-migrate", help="migrate vector db.") +@click.command("vdb-migrate", help="Migrate vector db.") @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ["knowledge", "all"]: + if scope in {"knowledge", "all"}: migrate_knowledge_vector_database() - if scope in ["annotation", "all"]: + if scope in {"annotation", "all"}: migrate_annotation_vector_database() @@ -150,7 +150,7 @@ def migrate_annotation_vector_database(): """ Migrate annotation datas to target vector database . """ - click.echo(click.style("Start migrate annotation data.", fg="green")) + click.echo(click.style("Starting annotation data migration.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 @@ -174,14 +174,14 @@ def migrate_annotation_vector_database(): f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." ) try: - click.echo("Create app annotation index: {}".format(app.id)) + click.echo("Creating app annotation index: {}".format(app.id)) app_annotation_setting = ( db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() ) if not app_annotation_setting: skipped_count = skipped_count + 1 - click.echo("App annotation setting is disabled: {}".format(app.id)) + click.echo("App annotation setting disabled: {}".format(app.id)) continue # get dataset_collection_binding info dataset_collection_binding = ( @@ -190,7 +190,7 @@ def migrate_annotation_vector_database(): .first() ) if not dataset_collection_binding: - click.echo("App annotation collection binding is not exist: {}".format(app.id)) + click.echo("App annotation collection binding not found: {}".format(app.id)) continue annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() dataset = Dataset( @@ -211,11 +211,11 @@ def migrate_annotation_vector_database(): documents.append(document) vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - click.echo(f"Start to migrate annotation, app_id: {app.id}.") + click.echo(f"Migrating annotations for app: {app.id}.") try: vector.delete() - click.echo(click.style(f"Successfully delete vector index for app: {app.id}.", fg="green")) + click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green")) except Exception as e: click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) raise e @@ -223,12 +223,12 @@ def migrate_annotation_vector_database(): try: click.echo( click.style( - f"Start to created vector index with {len(documents)} annotations for app {app.id}.", + f"Creating vector index with {len(documents)} annotations for app {app.id}.", fg="green", ) ) vector.create(documents) - click.echo(click.style(f"Successfully created vector index for app {app.id}.", fg="green")) + click.echo(click.style(f"Created vector index for app {app.id}.", fg="green")) except Exception as e: click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red")) raise e @@ -237,14 +237,14 @@ def migrate_annotation_vector_database(): 
except Exception as e: click.echo( click.style( - "Create app annotation index error: {} {}".format(e.__class__.__name__, str(e)), fg="red" + "Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)), fg="red" ) ) continue click.echo( click.style( - f"Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.", + f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.", fg="green", ) ) @@ -254,11 +254,33 @@ def migrate_knowledge_vector_database(): """ Migrate vector database datas to target vector database . """ - click.echo(click.style("Start migrate vector db.", fg="green")) + click.echo(click.style("Starting vector database migration.", fg="green")) create_count = 0 skipped_count = 0 total_count = 0 vector_type = dify_config.VECTOR_STORE + upper_colletion_vector_types = { + VectorType.MILVUS, + VectorType.PGVECTOR, + VectorType.RELYT, + VectorType.WEAVIATE, + VectorType.ORACLE, + VectorType.ELASTICSEARCH, + } + lower_colletion_vector_types = { + VectorType.ANALYTICDB, + VectorType.CHROMA, + VectorType.MYSCALE, + VectorType.PGVECTO_RS, + VectorType.TIDB_VECTOR, + VectorType.OPENSEARCH, + VectorType.TENCENT, + VectorType.BAIDU, + VectorType.VIKINGDB, + VectorType.UPSTASH, + VectorType.COUCHBASE, + VectorType.OCEANBASE, + } page = 1 while True: try: @@ -275,21 +297,18 @@ def migrate_knowledge_vector_database(): for dataset in datasets: total_count = total_count + 1 click.echo( - f"Processing the {total_count} dataset {dataset.id}. " - + f"{create_count} created, {skipped_count} skipped." + f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped." ) try: - click.echo("Create dataset vdb index: {}".format(dataset.id)) + click.echo("Creating dataset vector database index: {}".format(dataset.id)) if dataset.index_struct_dict: if dataset.index_struct_dict["type"] == vector_type: skipped_count = skipped_count + 1 continue collection_name = "" - if vector_type == VectorType.WEAVIATE: - dataset_id = dataset.id + dataset_id = dataset.id + if vector_type in upper_colletion_vector_types: collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: dataset_collection_binding = ( @@ -300,66 +319,24 @@ def migrate_knowledge_vector_database(): if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError("Dataset Collection Bindings is not exist!") + raise ValueError("Dataset Collection Binding not found") else: - dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.MILVUS: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.RELYT: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}} - 
dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.TENCENT: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.PGVECTOR: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.OPENSEARCH: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.OPENSEARCH, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.ANALYTICDB: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.ANALYTICDB, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.ELASTICSEARCH: - dataset_id = dataset.id - index_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} - dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type in lower_colletion_vector_types: + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() else: raise ValueError(f"Vector store {vector_type} is not supported.") + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} + dataset.index_struct = json.dumps(index_struct_dict) vector = Vector(dataset) - click.echo(f"Start to migrate dataset {dataset.id}.") + click.echo(f"Migrating dataset {dataset.id}.") try: vector.delete() click.echo( - click.style( - f"Successfully delete vector index {collection_name} for dataset {dataset.id}.", fg="green" - ) + click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green") ) except Exception as e: click.echo( @@ -411,14 +388,13 @@ def migrate_knowledge_vector_database(): try: click.echo( click.style( - f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.", + f"Creating vector index with {len(documents)} documents of {segments_count}" + f" segments for dataset {dataset.id}.", fg="green", ) ) vector.create(documents) - click.echo( - click.style(f"Successfully created vector index for dataset {dataset.id}.", fg="green") - ) + click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green")) except Exception as e: click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) raise e @@ -429,13 +405,13 @@ def migrate_knowledge_vector_database(): except Exception as e: db.session.rollback() click.echo( - click.style("Create dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + click.style("Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)), fg="red") ) continue click.echo( click.style( - f"Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.", fg="green" + f"Migration complete. Created {create_count} dataset indexes. 
Skipped {skipped_count} datasets.", fg="green" ) ) @@ -445,7 +421,7 @@ def convert_to_agent_apps(): """ Convert Agent Assistant to Agent App. """ - click.echo(click.style("Start convert to agent apps.", fg="green")) + click.echo(click.style("Starting convert to agent apps.", fg="green")) proceeded_app_ids = [] @@ -453,14 +429,14 @@ def convert_to_agent_apps(): # fetch first 1000 apps sql_query = """SELECT a.id AS id FROM apps a INNER JOIN app_model_configs am ON a.app_model_config_id=am.id - WHERE a.mode = 'chat' - AND am.agent_mode is not null + WHERE a.mode = 'chat' + AND am.agent_mode is not null AND ( - am.agent_mode like '%"strategy": "function_call"%' + am.agent_mode like '%"strategy": "function_call"%' OR am.agent_mode like '%"strategy": "react"%' - ) + ) AND ( - am.agent_mode like '{"enabled": true%' + am.agent_mode like '{"enabled": true%' OR am.agent_mode like '{"max_iteration": %' ) ORDER BY a.created_at DESC LIMIT 1000 """ @@ -496,23 +472,23 @@ def convert_to_agent_apps(): except Exception as e: click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red")) - click.echo(click.style("Congratulations! Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green")) + click.echo(click.style("Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green")) -@click.command("add-qdrant-doc-id-index", help="add qdrant doc_id index.") -@click.option("--field", default="metadata.doc_id", prompt=False, help="index field , default is metadata.doc_id.") +@click.command("add-qdrant-doc-id-index", help="Add Qdrant doc_id index.") +@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.") def add_qdrant_doc_id_index(field: str): - click.echo(click.style("Start add qdrant doc_id index.", fg="green")) + click.echo(click.style("Starting Qdrant doc_id index creation.", fg="green")) vector_type = dify_config.VECTOR_STORE if vector_type != "qdrant": - click.echo(click.style("Sorry, only support qdrant vector store.", fg="red")) + click.echo(click.style("This command only supports Qdrant vector store.", fg="red")) return create_count = 0 try: bindings = db.session.query(DatasetCollectionBinding).all() if not bindings: - click.echo(click.style("Sorry, no dataset collection bindings found.", fg="red")) + click.echo(click.style("No dataset collection bindings found.", fg="red")) return import qdrant_client from qdrant_client.http.exceptions import UnexpectedResponse @@ -522,7 +498,7 @@ def add_qdrant_doc_id_index(field: str): for binding in bindings: if dify_config.QDRANT_URL is None: - raise ValueError("Qdrant url is required.") + raise ValueError("Qdrant URL is required.") qdrant_config = QdrantConfig( endpoint=dify_config.QDRANT_URL, api_key=dify_config.QDRANT_API_KEY, @@ -539,40 +515,39 @@ def add_qdrant_doc_id_index(field: str): except UnexpectedResponse as e: # Collection does not exist, so return if e.status_code == 404: - click.echo( - click.style(f"Collection not found, collection_name:{binding.collection_name}.", fg="red") - ) + click.echo(click.style(f"Collection not found: {binding.collection_name}.", fg="red")) continue # Some other error occurred, so re-raise the exception else: click.echo( click.style( - f"Failed to create qdrant index, collection_name:{binding.collection_name}.", fg="red" + f"Failed to create Qdrant index for collection: {binding.collection_name}.", fg="red" ) ) except Exception as e: - click.echo(click.style("Failed to create qdrant client.", 
fg="red")) + click.echo(click.style("Failed to create Qdrant client.", fg="red")) - click.echo(click.style(f"Congratulations! Create {create_count} collection indexes.", fg="green")) + click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green")) @click.command("create-tenant", help="Create account and tenant.") -@click.option("--email", prompt=True, help="The email address of the tenant account.") +@click.option("--email", prompt=True, help="Tenant account email.") +@click.option("--name", prompt=True, help="Workspace name.") @click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant(email: str, language: Optional[str] = None): +def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): """ Create tenant account """ if not email: - click.echo(click.style("Sorry, email is required.", fg="red")) + click.echo(click.style("Email is required.", fg="red")) return # Create account email = email.strip() if "@" not in email: - click.echo(click.style("Sorry, invalid email address.", fg="red")) + click.echo(click.style("Invalid email address.", fg="red")) return account_name = email.split("@")[0] @@ -580,29 +555,31 @@ def create_tenant(email: str, language: Optional[str] = None): if language not in languages: language = "en-US" + name = name.strip() + # generate random password new_password = secrets.token_urlsafe(16) # register account account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) - TenantService.create_owner_tenant_if_not_exist(account) + TenantService.create_owner_tenant_if_not_exist(account, name) click.echo( click.style( - "Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password), + "Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password), fg="green", ) ) -@click.command("upgrade-db", help="upgrade the database") +@click.command("upgrade-db", help="Upgrade the database") def upgrade_db(): click.echo("Preparing database migration...") lock = redis_client.lock(name="db_upgrade_lock", timeout=60) if lock.acquire(blocking=False): try: - click.echo(click.style("Start database migration.", fg="green")) + click.echo(click.style("Starting database migration.", fg="green")) # run db migration import flask_migrate @@ -612,7 +589,7 @@ def upgrade_db(): click.echo(click.style("Database migration successful!", fg="green")) except Exception as e: - logging.exception(f"Database migration failed, error: {e}") + logging.exception(f"Database migration failed: {e}") finally: lock.release() else: @@ -624,7 +601,7 @@ def fix_app_site_missing(): """ Fix app related site missing issue. 
""" - click.echo(click.style("Start fix app related site missing issue.", fg="green")) + click.echo(click.style("Starting fix for missing app-related sites.", fg="green")) failed_app_ids = [] while True: @@ -647,22 +624,22 @@ def fix_app_site_missing(): if tenant: accounts = tenant.get_accounts() if not accounts: - print("Fix app {} failed.".format(app.id)) + print("Fix failed for app {}".format(app.id)) continue account = accounts[0] - print("Fix app {} related site missing issue.".format(app.id)) + print("Fixing missing site for app {}".format(app.id)) app_was_created.send(app, account=account) except Exception as e: failed_app_ids.append(app_id) - click.echo(click.style("Fix app {} related site missing issue failed!".format(app_id), fg="red")) + click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red")) logging.exception(f"Fix app related site missing issue failed, error: {e}") continue if not processed_count: break - click.echo(click.style("Congratulations! Fix app related site missing issue successful!", fg="green")) + click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) def register_commands(app): diff --git a/api/configs/__init__.py b/api/configs/__init__.py index c0e28c34e1e1ea..3a172601c96382 100644 --- a/api/configs/__init__.py +++ b/api/configs/__init__.py @@ -1,3 +1,3 @@ from .app_config import DifyConfig -dify_config = DifyConfig() \ No newline at end of file +dify_config = DifyConfig() diff --git a/api/configs/app_config.py b/api/configs/app_config.py index b277760edd7b2c..61de73c8689f8b 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -1,4 +1,3 @@ -from pydantic import Field, computed_field from pydantic_settings import SettingsConfigDict from configs.deploy import DeploymentConfig @@ -24,42 +23,16 @@ class DifyConfig( # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, ): - DEBUG: bool = Field(default=False, description='whether to enable debug mode.') - model_config = SettingsConfigDict( # read from dotenv format config file - env_file='.env', - env_file_encoding='utf-8', + env_file=".env", + env_file_encoding="utf-8", frozen=True, # ignore extra attributes - extra='ignore', + extra="ignore", ) - CODE_MAX_NUMBER: int = 9223372036854775807 - CODE_MIN_NUMBER: int = -9223372036854775808 - CODE_MAX_STRING_LENGTH: int = 80000 - CODE_MAX_STRING_ARRAY_LENGTH: int = 30 - CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30 - CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000 - - HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300 - HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600 - HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600 - HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10 - - @computed_field - def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str: - return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB' - - HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024 - - @computed_field - def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str: - return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB' - - SSRF_PROXY_HTTP_URL: str | None = None - SSRF_PROXY_HTTPS_URL: str | None = None - - MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.') - - MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. 
default is 5KB.') + # Before adding any config, + # please consider to arrange it in the proper config group of existed or added + # for better readability and maintainability. + # Thanks for your concentration and consideration. diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 219b315784323e..66d6a55b4c7eaf 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -4,24 +4,30 @@ class DeploymentConfig(BaseSettings): """ - Deployment configs + Configuration settings for application deployment """ + APPLICATION_NAME: str = Field( - description='application name', - default='langgenius/dify', + description="Name of the application, used for identification and logging purposes", + default="langgenius/dify", + ) + + DEBUG: bool = Field( + description="Enable debug mode for additional logging and development features", + default=False, ) TESTING: bool = Field( - description='', + description="Enable testing mode for running automated tests", default=False, ) EDITION: str = Field( - description='deployment edition', - default='SELF_HOSTED', + description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')", + default="SELF_HOSTED", ) DEPLOY_ENV: str = Field( - description='deployment environment, default to PRODUCTION.', - default='PRODUCTION', + description="Deployment environment (e.g., 'PRODUCTION', 'DEVELOPMENT'), default to PRODUCTION", + default="PRODUCTION", ) diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index b5d884e10e7b4d..eda6345e145a95 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -4,16 +4,17 @@ class EnterpriseFeatureConfig(BaseSettings): """ - Enterprise feature configs. + Configuration for enterprise-level features. **Before using, please contact business@dify.ai by email to inquire about licensing matters.** """ + ENTERPRISE_ENABLED: bool = Field( - description='whether to enable enterprise features.' - 'Before using, please contact business@dify.ai by email to inquire about licensing matters.', + description="Enable or disable enterprise-level features." + "Before using, please contact business@dify.ai by email to inquire about licensing matters.", default=False, ) CAN_REPLACE_LOGO: bool = Field( - description='whether to allow replacing enterprise logo.', + description="Allow customization of the enterprise logo.", default=False, ) diff --git a/api/configs/extra/notion_config.py b/api/configs/extra/notion_config.py index b77e8adaaeba52..f9c4d7346399ad 100644 --- a/api/configs/extra/notion_config.py +++ b/api/configs/extra/notion_config.py @@ -6,29 +6,31 @@ class NotionConfig(BaseSettings): """ - Notion integration configs + Configuration settings for Notion integration """ + NOTION_CLIENT_ID: Optional[str] = Field( - description='Notion client ID', + description="Client ID for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) NOTION_CLIENT_SECRET: Optional[str] = Field( - description='Notion client secret key', + description="Client secret for Notion API authentication. Required for OAuth 2.0 flow.", default=None, ) NOTION_INTEGRATION_TYPE: Optional[str] = Field( - description='Notion integration type, default to None, available values: internal.', + description="Type of Notion integration." 
+ " Set to 'internal' for internal integrations, or None for public integrations.", default=None, ) NOTION_INTERNAL_SECRET: Optional[str] = Field( - description='Notion internal secret key', + description="Secret key for internal Notion integrations. Required when NOTION_INTEGRATION_TYPE is 'internal'.", default=None, ) NOTION_INTEGRATION_TOKEN: Optional[str] = Field( - description='Notion integration token', + description="Integration token for Notion API access. Used for direct API calls without OAuth flow.", default=None, ) diff --git a/api/configs/extra/sentry_config.py b/api/configs/extra/sentry_config.py index e6517f730a577a..f76a6bdb95ca5b 100644 --- a/api/configs/extra/sentry_config.py +++ b/api/configs/extra/sentry_config.py @@ -6,19 +6,23 @@ class SentryConfig(BaseSettings): """ - Sentry configs + Configuration settings for Sentry error tracking and performance monitoring """ + SENTRY_DSN: Optional[str] = Field( - description='Sentry DSN', + description="Sentry Data Source Name (DSN)." + " This is the unique identifier of your Sentry project, used to send events to the correct project.", default=None, ) SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field( - description='Sentry trace sample rate', + description="Sample rate for Sentry performance monitoring traces." + " Value between 0.0 and 1.0, where 1.0 means 100% of traces are sent to Sentry.", default=1.0, ) SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field( - description='Sentry profiles sample rate', + description="Sample rate for Sentry profiling." + " Value between 0.0 and 1.0, where 1.0 means 100% of profiles are sent to Sentry.", default=1.0, ) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 369b25d788a440..f368a194693f73 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -1,6 +1,15 @@ -from typing import Optional - -from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field +from typing import Annotated, Literal, Optional + +from pydantic import ( + AliasChoices, + Field, + HttpUrl, + NegativeInt, + NonNegativeInt, + PositiveFloat, + PositiveInt, + computed_field, +) from pydantic_settings import BaseSettings from configs.feature.hosted_service import HostedServiceConfig @@ -8,443 +17,734 @@ class SecurityConfig(BaseSettings): """ - Secret Key configs + Security-related configurations for the application """ - SECRET_KEY: Optional[str] = Field( - description='Your App secret key will be used for securely signing the session cookie' - 'Make sure you are changing this key for your deployment with a strong key.' - 'You can generate a strong key using `openssl rand -base64 42`.' - 'Alternatively you can set it with `SECRET_KEY` environment variable.', - default=None, + + SECRET_KEY: str = Field( + description="Secret key for secure session cookie signing." + "Make sure you are changing this key for your deployment with a strong key." 
+ "Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.", + default="", + ) + + RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a password reset token remains valid", + default=5, + ) + + LOGIN_DISABLED: bool = Field( + description="Whether to disable login checks", + default=False, + ) + + ADMIN_API_KEY_ENABLE: bool = Field( + description="Whether to enable admin api key for authentication", + default=False, ) - RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( - description='Expiry time in hours for reset token', - default=24, + ADMIN_API_KEY: Optional[str] = Field( + description="admin api key for authentication", + default=None, ) class AppExecutionConfig(BaseSettings): """ - App Execution configs + Configuration parameters for application execution """ + APP_MAX_EXECUTION_TIME: PositiveInt = Field( - description='execution timeout in seconds for app execution', + description="Maximum allowed execution time for the application in seconds", default=1200, ) APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field( - description='max active request per app, 0 means unlimited', + description="Maximum number of concurrent active requests per app (0 for unlimited)", default=0, ) class CodeExecutionSandboxConfig(BaseSettings): """ - Code Execution Sandbox configs + Configuration for the code execution sandbox environment """ - CODE_EXECUTION_ENDPOINT: str = Field( - description='endpoint URL of code execution servcie', - default='http://sandbox:8194', + + CODE_EXECUTION_ENDPOINT: HttpUrl = Field( + description="URL endpoint for the code execution service", + default="http://sandbox:8194", ) CODE_EXECUTION_API_KEY: str = Field( - description='API key for code execution service', - default='dify-sandbox', + description="API key for accessing the code execution service", + default="dify-sandbox", + ) + + CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field( + description="Connection timeout in seconds for code execution requests", + default=10.0, + ) + + CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field( + description="Read timeout in seconds for code execution requests", + default=60.0, + ) + + CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field( + description="Write timeout in seconds for code execution request", + default=10.0, + ) + + CODE_MAX_NUMBER: PositiveInt = Field( + description="Maximum allowed numeric value in code execution", + default=9223372036854775807, + ) + + CODE_MIN_NUMBER: NegativeInt = Field( + description="Minimum allowed numeric value in code execution", + default=-9223372036854775807, + ) + + CODE_MAX_DEPTH: PositiveInt = Field( + description="Maximum allowed depth for nested structures in code execution", + default=5, + ) + + CODE_MAX_PRECISION: PositiveInt = Field( + description="Maximum number of decimal places for floating-point numbers in code execution", + default=20, + ) + + CODE_MAX_STRING_LENGTH: PositiveInt = Field( + description="Maximum allowed length for strings in code execution", + default=80000, + ) + + CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field( + description="Maximum allowed length for string arrays in code execution", + default=30, + ) + + CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field( + description="Maximum allowed length for object arrays in code execution", + default=30, + ) + + CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field( + description="Maximum allowed length for numeric arrays in code execution", + default=1000, ) class 
EndpointConfig(BaseSettings): """ - Module URL configs + Configuration for various application endpoints and URLs """ + CONSOLE_API_URL: str = Field( - description='The backend URL prefix of the console API.' - 'used to concatenate the login authorization callback or notion integration callback.', - default='', + description="Base URL for the console API," + " used for login authentication callbacks or Notion integration callbacks", + default="", ) CONSOLE_WEB_URL: str = Field( - description='The front-end URL prefix of the console web.' - 'used to concatenate some front-end addresses and for CORS configuration use.', - default='', + description="Base URL for the console web interface, used for frontend references and CORS configuration", + default="", ) SERVICE_API_URL: str = Field( - description='Service API Url prefix.' - 'used to display Service API Base Url to the front-end.', - default='', + description="Base URL for the service API, displayed to users for API access", + default="", ) APP_WEB_URL: str = Field( - description='WebApp Url prefix.' - 'used to display WebAPP API Base Url to the front-end.', - default='', + description="Base URL for the web application, used for frontend references", + default="", ) class FileAccessConfig(BaseSettings): """ - File Access configs + Configuration for file access and handling """ + FILES_URL: str = Field( - description='File preview or download Url prefix.' - ' used to display File preview or download Url to the front-end or as Multi-model inputs;' - 'Url is signed and has expiration time.', - validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'), + description="Base URL for file preview or download," + " used for frontend display and multimodal inputs;" + " the URL is signed and has an expiration time.", + validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"), alias_priority=1, - default='', + default="", ) FILES_ACCESS_TIMEOUT: int = Field( - description='timeout in seconds for file accessing', + description="Expiration time in seconds for file access URLs", default=300, ) class FileUploadConfig(BaseSettings): """ - File Uploading configs + Configuration for file upload limitations """ + UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field( - description='size limit in Megabytes for uploading files', + description="Maximum allowed file size for uploads in megabytes", default=15, ) UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field( - description='batch size limit for uploading files', + description="Maximum number of files allowed in a single upload batch", default=5, ) UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( - description='image file size limit in Megabytes for uploading files', + description="Maximum allowed image file size for uploads in megabytes", default=10, ) + UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed video file size for uploads in megabytes", + default=100, + ) + + UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field( + description="Maximum allowed audio file size for uploads in megabytes", + default=50, + ) + BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( - description='', # todo: to be clarified + description="Maximum number of files allowed in a batch upload operation", default=20, ) + WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a workflow upload operation", + default=10, + ) + class HttpConfig(BaseSettings): """ - HTTP configs + HTTP-related configurations for the application """ + API_COMPRESSION_ENABLED: bool = Field( -
description='whether to enable HTTP response compression of gzip', + description="Enable or disable gzip compression for HTTP responses", default=False, ) inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field( - description='', - validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'), - default='', + description="Comma-separated list of allowed origins for CORS in the console", + validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"), + default="", ) @computed_field @property def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: - return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',') + return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field( - description='', - validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'), - default='*', + description="Comma-separated list of allowed origins for CORS in the web API", + validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"), + default="*", ) @computed_field @property def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: - return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',') + return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") + + HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[ + PositiveInt, Field(ge=10, description="Maximum connection timeout in seconds for HTTP requests") + ] = 10 + + HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[ + PositiveInt, Field(ge=60, description="Maximum read timeout in seconds for HTTP requests") + ] = 60 + + HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[ + PositiveInt, Field(ge=10, description="Maximum write timeout in seconds for HTTP requests") + ] = 20 + + HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field( + description="Maximum allowed size in bytes for binary data in HTTP requests", + default=10 * 1024 * 1024, + ) + + HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field( + description="Maximum allowed size in bytes for text data in HTTP requests", + default=1 * 1024 * 1024, + ) + + SSRF_DEFAULT_MAX_RETRIES: PositiveInt = Field( + description="Maximum number of retries for network requests (SSRF)", + default=3, + ) + + SSRF_PROXY_ALL_URL: Optional[str] = Field( + description="Proxy URL for HTTP or HTTPS requests to prevent Server-Side Request Forgery (SSRF)", + default=None, + ) + + SSRF_PROXY_HTTP_URL: Optional[str] = Field( + description="Proxy URL for HTTP requests to prevent Server-Side Request Forgery (SSRF)", + default=None, + ) + + SSRF_PROXY_HTTPS_URL: Optional[str] = Field( + description="Proxy URL for HTTPS requests to prevent Server-Side Request Forgery (SSRF)", + default=None, + ) + + SSRF_DEFAULT_TIME_OUT: PositiveFloat = Field( + description="The default timeout period used for network requests (SSRF)", + default=5, + ) + + SSRF_DEFAULT_CONNECT_TIME_OUT: PositiveFloat = Field( + description="The default connect timeout period used for network requests (SSRF)", + default=5, + ) + + SSRF_DEFAULT_READ_TIME_OUT: PositiveFloat = Field( + description="The default read timeout period used for network requests (SSRF)", + default=5, + ) + + SSRF_DEFAULT_WRITE_TIME_OUT: PositiveFloat = Field( + description="The default write timeout period used for network requests (SSRF)", + default=5, + ) + + RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field( + description="Enable or disable the X-Forwarded-For Proxy Fix middleware from Werkzeug" + " to respect X-* headers set by a reverse proxy", + default=False, + ) class InnerAPIConfig(BaseSettings): """ - Inner API configs + Configuration for internal API functionality """ + INNER_API: bool = Field( - description='whether to enable the inner API',
+ description="Enable or disable the internal API", default=False, ) INNER_API_KEY: Optional[str] = Field( - description='The inner API key is used to authenticate the inner API', + description="API key for accessing the internal API", default=None, ) class LoggingConfig(BaseSettings): """ - Logging configs + Configuration for application logging """ LOG_LEVEL: str = Field( - description='Log output level, default to INFO.' - 'It is recommended to set it to ERROR for production.', - default='INFO', + description="Logging level, default to INFO. Set to ERROR for production environments.", + default="INFO", ) LOG_FILE: Optional[str] = Field( - description='logging output file path', + description="File path for log output.", default=None, ) + LOG_FILE_MAX_SIZE: PositiveInt = Field( + description="Maximum file size for file rotation retention, the unit is megabytes (MB)", + default=20, + ) + + LOG_FILE_BACKUP_COUNT: PositiveInt = Field( + description="Maximum file backup count file rotation retention", + default=5, + ) + LOG_FORMAT: str = Field( - description='log format', - default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s', + description="Format string for log messages", + default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", ) LOG_DATEFORMAT: Optional[str] = Field( - description='log date format', + description="Date format string for log timestamps", default=None, ) LOG_TZ: Optional[str] = Field( - description='specify log timezone, eg: America/New_York', - default=None, + description="Timezone for log timestamps (e.g., 'America/New_York')", + default="UTC", ) class ModelLoadBalanceConfig(BaseSettings): """ - Model load balance configs + Configuration for model load balancing """ + MODEL_LB_ENABLED: bool = Field( - description='whether to enable model load balancing', + description="Enable or disable load balancing for models", default=False, ) class BillingConfig(BaseSettings): """ - Platform Billing Configurations + Configuration for platform billing features """ + BILLING_ENABLED: bool = Field( - description='whether to enable billing', + description="Enable or disable billing functionality", default=False, ) class UpdateConfig(BaseSettings): """ - Update configs + Configuration for application update checks """ + CHECK_UPDATE_URL: str = Field( - description='url for checking updates', - default='https://updates.dify.ai', + description="URL to check for application updates", + default="https://updates.dify.ai", ) class WorkflowConfig(BaseSettings): """ - Workflow feature configs + Configuration for workflow execution """ WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field( - description='max execution steps in single workflow execution', + description="Maximum number of steps allowed in a single workflow execution", default=500, ) WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field( - description='max execution time in seconds in single workflow execution', + description="Maximum execution time in seconds for a single workflow", default=1200, ) WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field( - description='max depth of calling in single workflow execution', + description="Maximum allowed depth for nested workflow calls", default=5, ) + MAX_VARIABLE_SIZE: PositiveInt = Field( + description="Maximum size in bytes for a single variable in workflows. 
Default to 200 KB.", + default=200 * 1024, + ) + -class OAuthConfig(BaseSettings): +class AuthConfig(BaseSettings): """ - oauth configs + Configuration for authentication and OAuth """ + OAUTH_REDIRECT_PATH: str = Field( - description='redirect path for OAuth', - default='/console/api/oauth/authorize', + description="Redirect path for OAuth authentication callbacks", + default="/console/api/oauth/authorize", ) GITHUB_CLIENT_ID: Optional[str] = Field( - description='GitHub client id for OAuth', + description="GitHub OAuth client ID", default=None, ) GITHUB_CLIENT_SECRET: Optional[str] = Field( - description='GitHub client secret key for OAuth', + description="GitHub OAuth client secret", default=None, ) GOOGLE_CLIENT_ID: Optional[str] = Field( - description='Google client id for OAuth', + description="Google OAuth client ID", default=None, ) GOOGLE_CLIENT_SECRET: Optional[str] = Field( - description='Google client secret key for OAuth', + description="Google OAuth client secret", default=None, ) + ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field( + description="Expiration time for access tokens in minutes", + default=60, + ) + class ModerationConfig(BaseSettings): """ - Moderation in app configs. + Configuration for content moderation """ - # todo: to be clarified in usage and unit - OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field( - description='buffer size for moderation', + MODERATION_BUFFER_SIZE: PositiveInt = Field( + description="Size of the buffer for content moderation processing", default=300, ) class ToolConfig(BaseSettings): """ - Tool configs + Configuration for tool management """ TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field( - description='max age in seconds for tool icon caching', + description="Maximum age in seconds for caching tool icons", default=3600, ) class MailConfig(BaseSettings): """ - Mail Configurations + Configuration for email services """ MAIL_TYPE: Optional[str] = Field( - description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.', + description="Email service provider type ('smtp' or 'resend'), default to None.", default=None, ) MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( - description='default email address for sending from ', + description="Default email address to use as the sender", default=None, ) RESEND_API_KEY: Optional[str] = Field( - description='API key for Resend', + description="API key for Resend email service", default=None, ) RESEND_API_URL: Optional[str] = Field( - description='API URL for Resend', + description="API URL for Resend email service", default=None, ) SMTP_SERVER: Optional[str] = Field( - description='smtp server host', + description="SMTP server hostname", default=None, ) SMTP_PORT: Optional[int] = Field( - description='smtp server port', + description="SMTP server port number", default=465, ) SMTP_USERNAME: Optional[str] = Field( - description='smtp server username', + description="Username for SMTP authentication", default=None, ) SMTP_PASSWORD: Optional[str] = Field( - description='smtp server password', + description="Password for SMTP authentication", default=None, ) SMTP_USE_TLS: bool = Field( - description='whether to use TLS connection to smtp server', + description="Enable TLS encryption for SMTP connections", default=False, ) SMTP_OPPORTUNISTIC_TLS: bool = Field( - description='whether to use opportunistic TLS connection to smtp server', + description="Enable opportunistic TLS for SMTP connections", default=False, ) + EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field( + 
description="Maximum number of emails allowed to be sent from the same IP address in a minute", + default=50, + ) + class RagEtlConfig(BaseSettings): """ - RAG ETL Configurations. + Configuration for RAG ETL processes """ + # TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config ETL_TYPE: str = Field( - description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ', - default='dify', + description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'", + default="dify", ) KEYWORD_DATA_SOURCE_TYPE: str = Field( - description='source type for keyword data, default to `database`, available values are `database` .', - default='database', + description="Data source type for keyword extraction" + " ('database' or other supported types), default to 'database'", + default="database", ) UNSTRUCTURED_API_URL: Optional[str] = Field( - description='API URL for Unstructured', + description="API URL for Unstructured.io service", default=None, ) UNSTRUCTURED_API_KEY: Optional[str] = Field( - description='API key for Unstructured', + description="API key for Unstructured.io service", default=None, ) class DataSetConfig(BaseSettings): """ - Dataset configs + Configuration for dataset management """ - CLEAN_DAY_SETTING: PositiveInt = Field( - description='interval in days for cleaning up dataset', + PLAN_SANDBOX_CLEAN_DAY_SETTING: PositiveInt = Field( + description="Interval in days for dataset cleanup operations - plan: sandbox", default=30, ) + PLAN_PRO_CLEAN_DAY_SETTING: PositiveInt = Field( + description="Interval in days for dataset cleanup operations - plan: pro and team", + default=7, + ) + DATASET_OPERATOR_ENABLED: bool = Field( - description='whether to enable dataset operator', + description="Enable or disable dataset operator functionality", default=False, ) + TIDB_SERVERLESS_NUMBER: PositiveInt = Field( + description="number of tidb serverless cluster", + default=500, + ) + + CREATE_TIDB_SERVICE_JOB_ENABLED: bool = Field( + description="Enable or disable create tidb service job", + default=False, + ) + + class WorkspaceConfig(BaseSettings): """ - Workspace configs + Configuration for workspace management """ INVITE_EXPIRY_HOURS: PositiveInt = Field( - description='workspaces invitation expiration in hours', + description="Expiration time in hours for workspace invitation links", default=72, ) class IndexingConfig(BaseSettings): """ - Indexing configs. 
+ Configuration for indexing operations """ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field( - description='max segmentation token length for indexing', + description="Maximum token length for text segmentation during indexing", default=1000, ) -class ImageFormatConfig(BaseSettings): - MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( - description='multi model send image format, support base64, url, default is base64', - default='base64', +class VisionFormatConfig(BaseSettings): + MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field( + description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64", + default="base64", + ) + + MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field( + description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64", + default="base64", ) class CeleryBeatConfig(BaseSettings): CELERY_BEAT_SCHEDULER_TIME: int = Field( - description='the time of the celery scheduler, default to 1 day', + description="Interval in days for Celery Beat scheduler execution, default to 1 day", default=1, ) +class PositionConfig(BaseSettings): + POSITION_PROVIDER_PINS: str = Field( + description="Comma-separated list of pinned model providers", + default="", + ) + + POSITION_PROVIDER_INCLUDES: str = Field( + description="Comma-separated list of included model providers", + default="", + ) + + POSITION_PROVIDER_EXCLUDES: str = Field( + description="Comma-separated list of excluded model providers", + default="", + ) + + POSITION_TOOL_PINS: str = Field( + description="Comma-separated list of pinned tools", + default="", + ) + + POSITION_TOOL_INCLUDES: str = Field( + description="Comma-separated list of included tools", + default="", + ) + + POSITION_TOOL_EXCLUDES: str = Field( + description="Comma-separated list of excluded tools", + default="", + ) + + @computed_field + def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] + + @computed_field + def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_TOOL_PINS_LIST(self) -> list[str]: + return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] + + @computed_field + def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} + + @computed_field + def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: + return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} + + +class LoginConfig(BaseSettings): + ENABLE_EMAIL_CODE_LOGIN: bool = Field( + description="whether to enable email code login", + default=False, + ) + ENABLE_EMAIL_PASSWORD_LOGIN: bool = Field( + description="whether to enable email password login", + default=True, + ) + ENABLE_SOCIAL_OAUTH_LOGIN: bool = Field( + description="whether to enable GitHub/Google OAuth login", + default=False, + ) + EMAIL_CODE_LOGIN_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="expiry time in minutes for email code login token", + default=5, + ) + ALLOW_REGISTER: bool = Field( + description="whether to enable account registration", + default=False, + ) +
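# [Editor's sketch, not part of the patch] All six POSITION_*_LIST/_SET computed fields
# in PositionConfig above share one comma-split pattern; the helper name below is
# illustrative, not from the codebase.
def _parse_csv(raw: str) -> list[str]:
    # Split on commas, strip surrounding whitespace, and drop empty entries,
    # mirroring POSITION_PROVIDER_PINS_LIST and friends.
    return [item.strip() for item in raw.split(",") if item.strip() != ""]

assert _parse_csv("openai, anthropic ,,") == ["openai", "anthropic"]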
ALLOW_CREATE_WORKSPACE: bool = Field( + description="whether to enable workspace creation", + default=False, + ) + + class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, + AuthConfig, BillingConfig, CodeExecutionSandboxConfig, DataSetConfig, @@ -452,21 +752,21 @@ class FeatureConfig( FileAccessConfig, FileUploadConfig, HttpConfig, - ImageFormatConfig, + VisionFormatConfig, InnerAPIConfig, IndexingConfig, LoggingConfig, MailConfig, ModelLoadBalanceConfig, ModerationConfig, - OAuthConfig, + PositionConfig, RagEtlConfig, SecurityConfig, ToolConfig, UpdateConfig, WorkflowConfig, WorkspaceConfig, - + LoginConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/configs/feature/hosted_service/__init__.py b/api/configs/feature/hosted_service/__init__.py index 88fe188587e7a1..7f103be8f4f909 100644 --- a/api/configs/feature/hosted_service/__init__.py +++ b/api/configs/feature/hosted_service/__init__.py @@ -6,190 +6,188 @@ class HostedOpenAiConfig(BaseSettings): """ - Hosted OpenAI service config + Configuration for hosted OpenAI service """ HOSTED_OPENAI_API_KEY: Optional[str] = Field( - description='', + description="API key for hosted OpenAI service", default=None, ) HOSTED_OPENAI_API_BASE: Optional[str] = Field( - description='', + description="Base URL for hosted OpenAI API", default=None, ) HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( - description='', + description="Organization ID for hosted OpenAI service", default=None, ) HOSTED_OPENAI_TRIAL_ENABLED: bool = Field( - description='', + description="Enable trial access to hosted OpenAI service", default=False, ) HOSTED_OPENAI_TRIAL_MODELS: str = Field( - description='', - default='gpt-3.5-turbo,' - 'gpt-3.5-turbo-1106,' - 'gpt-3.5-turbo-instruct,' - 'gpt-3.5-turbo-16k,' - 'gpt-3.5-turbo-16k-0613,' - 'gpt-3.5-turbo-0613,' - 'gpt-3.5-turbo-0125,' - 'text-davinci-003', + description="Comma-separated list of available models for trial access", + default="gpt-3.5-turbo," + "gpt-3.5-turbo-1106," + "gpt-3.5-turbo-instruct," + "gpt-3.5-turbo-16k," + "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-0613," + "gpt-3.5-turbo-0125," + "text-davinci-003", ) HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="Quota limit for hosted OpenAI service usage", default=200, ) HOSTED_OPENAI_PAID_ENABLED: bool = Field( - description='', + description="Enable paid access to hosted OpenAI service", default=False, ) HOSTED_OPENAI_PAID_MODELS: str = Field( - description='', - default='gpt-4,' - 'gpt-4-turbo-preview,' - 'gpt-4-turbo-2024-04-09,' - 'gpt-4-1106-preview,' - 'gpt-4-0125-preview,' - 'gpt-3.5-turbo,' - 'gpt-3.5-turbo-16k,' - 'gpt-3.5-turbo-16k-0613,' - 'gpt-3.5-turbo-1106,' - 'gpt-3.5-turbo-0613,' - 'gpt-3.5-turbo-0125,' - 'gpt-3.5-turbo-instruct,' - 'text-davinci-003', + description="Comma-separated list of available models for paid access", + default="gpt-4," + "gpt-4-turbo-preview," + "gpt-4-turbo-2024-04-09," + "gpt-4-1106-preview," + "gpt-4-0125-preview," + "gpt-3.5-turbo," + "gpt-3.5-turbo-16k," + "gpt-3.5-turbo-16k-0613," + "gpt-3.5-turbo-1106," + "gpt-3.5-turbo-0613," + "gpt-3.5-turbo-0125," + "gpt-3.5-turbo-instruct," + "text-davinci-003", ) class HostedAzureOpenAiConfig(BaseSettings): """ - Hosted OpenAI service config + Configuration for hosted Azure OpenAI service """ HOSTED_AZURE_OPENAI_ENABLED: bool = Field( - description='', + description="Enable hosted Azure OpenAI service", default=False, )
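# [Editor's sketch, not part of the patch] FeatureConfig above works because
# pydantic-settings collects fields from every BaseSettings base class, so a single
# FeatureConfig() instance reads all of the grouped settings from the environment at
# once. The _Demo classes below are hypothetical stand-ins for the real config groups.
from pydantic import Field
from pydantic_settings import BaseSettings

class _SecurityDemo(BaseSettings):
    SECRET_KEY: str = Field(default="")

class _LoginDemo(BaseSettings):
    ALLOW_REGISTER: bool = Field(default=False)

class _FeatureDemo(_SecurityDemo, _LoginDemo):
    pass

# Exposes both SECRET_KEY and ALLOW_REGISTER, each overridable via env vars.
print(_FeatureDemo().model_dump())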
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( - description='', + description="API key for hosted Azure OpenAI service", default=None, ) HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( - description='', + description="Base URL for hosted Azure OpenAI API", default=None, ) HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="Quota limit for hosted Azure OpenAI service usage", default=200, ) class HostedAnthropicConfig(BaseSettings): """ - Hosted Azure OpenAI service config + Configuration for hosted Anthropic service """ HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field( - description='', + description="Base URL for hosted Anthropic API", default=None, ) HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( - description='', + description="API key for hosted Anthropic service", default=None, ) HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field( - description='', + description="Enable trial access to hosted Anthropic service", default=False, ) HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( - description='', + description="Quota limit for hosted Anthropic service usage", default=600000, ) HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( - description='', + description="Enable paid access to hosted Anthropic service", default=False, ) class HostedMinmaxConfig(BaseSettings): """ - Hosted Minmax service config + Configuration for hosted Minmax service """ HOSTED_MINIMAX_ENABLED: bool = Field( - description='', + description="Enable hosted Minmax service", default=False, ) class HostedSparkConfig(BaseSettings): """ - Hosted Spark service config + Configuration for hosted Spark service """ HOSTED_SPARK_ENABLED: bool = Field( - description='', + description="Enable hosted Spark service", default=False, ) class HostedZhipuAIConfig(BaseSettings): """ - Hosted Minmax service config + Configuration for hosted ZhipuAI service """ HOSTED_ZHIPUAI_ENABLED: bool = Field( - description='', + description="Enable hosted ZhipuAI service", default=False, ) class HostedModerationConfig(BaseSettings): """ - Hosted Moderation service config + Configuration for hosted Moderation service """ HOSTED_MODERATION_ENABLED: bool = Field( - description='', + description="Enable hosted Moderation service", default=False, ) HOSTED_MODERATION_PROVIDERS: str = Field( - description='', - default='', + description="Comma-separated list of moderation providers", + default="", ) class HostedFetchAppTemplateConfig(BaseSettings): """ - Hosted Moderation service config + Configuration for fetching app templates """ HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field( - description='the mode for fetching app templates,' ' default to remote,' ' available values: remote, db, builtin', - default='remote', + description="Mode for fetching app templates ('remote', 'db', or 'builtin'), default to 'remote'", + default="remote", ) HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field( - description='the domain for fetching remote app templates', - default='https://tmpl.dify.ai', + description="Domain for fetching remote app templates", + default="https://tmpl.dify.ai", ) @@ -202,7 +200,6 @@ class HostedServiceConfig( HostedOpenAiConfig, HostedSparkConfig, HostedZhipuAIConfig, - # moderation HostedModerationConfig, ): diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 07688e9aebfdc4..57cc805ebf5a59 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,20 +1,29 @@ from typing import Any, Optional from urllib.parse import quote_plus
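# [Editor's sketch, not part of the patch] quote_plus, imported above, is what the
# SQLALCHEMY_DATABASE_URI property further below relies on so that credentials
# containing characters such as '@' or ':' survive embedding in the connection URI.
# The password value here is hypothetical.
from urllib.parse import quote_plus

password = "p@ss:word"
print(f"postgresql://postgres:{quote_plus(password)}@localhost:5432/dify")
# -> postgresql://postgres:p%40ss%3Aword@localhost:5432/dify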
-from pydantic import Field, NonNegativeInt, PositiveInt, computed_field +from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic_settings import BaseSettings from configs.middleware.cache.redis_config import RedisConfig from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig +from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig +from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig from configs.middleware.storage.oci_storage_config import OCIStorageConfig +from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig +from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig +from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig from configs.middleware.vdb.chroma_config import ChromaConfig +from configs.middleware.vdb.couchbase_config import CouchbaseConfig +from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig +from configs.middleware.vdb.lindorm_config import LindormConfig from configs.middleware.vdb.milvus_config import MilvusConfig from configs.middleware.vdb.myscale_config import MyScaleConfig +from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig from configs.middleware.vdb.opensearch_config import OpenSearchConfig from configs.middleware.vdb.oracle_config import OracleConfig from configs.middleware.vdb.pgvector_config import PGVectorConfig @@ -22,114 +31,124 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig from configs.middleware.vdb.relyt_config import RelytConfig from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig +from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig +from configs.middleware.vdb.upstash_config import UpstashConfig +from configs.middleware.vdb.vikingdb_config import VikingDBConfig from configs.middleware.vdb.weaviate_config import WeaviateConfig class StorageConfig(BaseSettings): STORAGE_TYPE: str = Field( - description='storage type,' - ' default to `local`,' - ' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.', - default='local', + description="Type of storage to use." + " Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', " + "'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.", + default="local", ) STORAGE_LOCAL_PATH: str = Field( - description='local storage path', - default='storage', + description="Path for local storage when STORAGE_TYPE is set to 'local'.", + default="storage", ) class VectorStoreConfig(BaseSettings): VECTOR_STORE: Optional[str] = Field( - description='vector store type', + description="Type of vector store to use for efficient similarity search." 
+ " Set to None if not using a vector store.", default=None, ) + VECTOR_STORE_WHITELIST_ENABLE: Optional[bool] = Field( + description="Enable whitelist for vector store.", + default=False, + ) + class KeywordStoreConfig(BaseSettings): KEYWORD_STORE: str = Field( - description='keyword store type', - default='jieba', + description="Method for keyword extraction and storage." + " Default is 'jieba', a Chinese text segmentation library.", + default="jieba", ) class DatabaseConfig: DB_HOST: str = Field( - description='db host', - default='localhost', + description="Hostname or IP address of the database server.", + default="localhost", ) DB_PORT: PositiveInt = Field( - description='db port', + description="Port number for database connection.", default=5432, ) DB_USERNAME: str = Field( - description='db username', - default='postgres', + description="Username for database authentication.", + default="postgres", ) DB_PASSWORD: str = Field( - description='db password', - default='', + description="Password for database authentication.", + default="", ) DB_DATABASE: str = Field( - description='db database', - default='dify', + description="Name of the database to connect to.", + default="dify", ) DB_CHARSET: str = Field( - description='db charset', - default='', + description="Character set for database connection.", + default="", ) DB_EXTRAS: str = Field( - description='db extras options. Example: keepalives_idle=60&keepalives=1', - default='', + description="Additional database connection parameters. Example: 'keepalives_idle=60&keepalives=1'", + default="", ) SQLALCHEMY_DATABASE_URI_SCHEME: str = Field( - description='db uri scheme', - default='postgresql', + description="Database URI scheme for SQLAlchemy connection.", + default="postgresql", ) @computed_field @property def SQLALCHEMY_DATABASE_URI(self) -> str: db_extras = ( - f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" - if self.DB_CHARSET - else self.DB_EXTRAS + f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS ).strip("&") db_extras = f"?{db_extras}" if db_extras else "" - return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" - f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" - f"{db_extras}") + return ( + f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" + f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" + f"{db_extras}" + ) SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field( - description='pool size of SqlAlchemy', + description="Maximum number of database connections in the pool.", default=30, ) SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field( - description='max overflows for SqlAlchemy', + description="Maximum number of connections that can be created beyond the pool_size.", default=10, ) SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field( - description='SqlAlchemy pool recycle', + description="Number of seconds after which a connection is automatically recycled.", default=3600, ) SQLALCHEMY_POOL_PRE_PING: bool = Field( - description='whether to enable pool pre-ping in SqlAlchemy', + description="If True, enables connection pool pre-ping feature to check connections.", default=False, ) SQLALCHEMY_ECHO: bool | str = Field( - description='whether to enable SqlAlchemy echo', + description="If True, SQLAlchemy will log all SQL statements.", default=False, ) @@ -137,35 +156,69 @@ def SQLALCHEMY_DATABASE_URI(self) -> str: @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> 
dict[str, Any]: return { - 'pool_size': self.SQLALCHEMY_POOL_SIZE, - 'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW, - 'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE, - 'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING, - 'connect_args': {'options': '-c timezone=UTC'}, + "pool_size": self.SQLALCHEMY_POOL_SIZE, + "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, + "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, + "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, + "connect_args": {"options": "-c timezone=UTC"}, } class CeleryConfig(DatabaseConfig): CELERY_BACKEND: str = Field( - description='Celery backend, available values are `database`, `redis`', - default='database', + description="Backend for Celery task results. Options: 'database', 'redis'.", + default="database", ) CELERY_BROKER_URL: Optional[str] = Field( - description='CELERY_BROKER_URL', + description="URL of the message broker for Celery tasks.", default=None, ) + CELERY_USE_SENTINEL: Optional[bool] = Field( + description="Whether to use Redis Sentinel for high availability.", + default=False, + ) + + CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field( + description="Name of the Redis Sentinel master.", + default=None, + ) + + CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + description="Timeout for Redis Sentinel socket operations in seconds.", + default=0.1, + ) + @computed_field @property def CELERY_RESULT_BACKEND(self) -> str | None: - return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \ - if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL + return ( + "db+{}".format(self.SQLALCHEMY_DATABASE_URI) + if self.CELERY_BACKEND == "database" + else self.CELERY_BROKER_URL + ) @computed_field @property def BROKER_USE_SSL(self) -> bool: - return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False + return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False + + +class InternalTestConfig(BaseSettings): + """ + Configuration settings for Internal Test + """ + + AWS_SECRET_ACCESS_KEY: Optional[str] = Field( + description="Internal test AWS secret access key", + default=None, + ) + + AWS_ACCESS_KEY_ID: Optional[str] = Field( + description="Internal test AWS access key ID", + default=None, + ) class MiddlewareConfig( @@ -174,16 +227,18 @@ class MiddlewareConfig( DatabaseConfig, KeywordStoreConfig, RedisConfig, - # configs of storage and storage providers StorageConfig, AliyunOSSStorageConfig, AzureBlobStorageConfig, + BaiduOBSStorageConfig, GoogleCloudStorageConfig, - TencentCloudCOSStorageConfig, - S3StorageConfig, + HuaweiCloudOBSStorageConfig, OCIStorageConfig, - + S3StorageConfig, + SupabaseStorageConfig, + TencentCloudCOSStorageConfig, + VolcengineTOSStorageConfig, # configs of vdb and vdb providers VectorStoreConfig, AnalyticdbConfig, @@ -199,5 +254,14 @@ class MiddlewareConfig( TencentVectorDBConfig, TiDBVectorConfig, WeaviateConfig, + ElasticsearchConfig, + CouchbaseConfig, + InternalTestConfig, + VikingDBConfig, + UpstashConfig, + TidbOnQdrantConfig, + LindormConfig, + OceanBaseVectorConfig, + BaiduVectorDBConfig, ): pass diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 436ba5d4c01f5c..26b9b1347c61be 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -1,39 +1,70 @@ from typing import Optional -from pydantic import Field, NonNegativeInt, PositiveInt +from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt from 
pydantic_settings import BaseSettings class RedisConfig(BaseSettings): """ - Redis configs + Configuration settings for Redis connection """ + REDIS_HOST: str = Field( - description='Redis host', - default='localhost', + description="Hostname or IP address of the Redis server", + default="localhost", ) REDIS_PORT: PositiveInt = Field( - description='Redis port', + description="Port number on which the Redis server is listening", default=6379, ) REDIS_USERNAME: Optional[str] = Field( - description='Redis username', + description="Username for Redis authentication (if required)", default=None, ) REDIS_PASSWORD: Optional[str] = Field( - description='Redis password', + description="Password for Redis authentication (if required)", default=None, ) REDIS_DB: NonNegativeInt = Field( - description='Redis database id, default to 0', + description="Redis database number to use (0-15)", default=0, ) REDIS_USE_SSL: bool = Field( - description='whether to use SSL for Redis connection', + description="Enable SSL/TLS for the Redis connection", default=False, ) + + REDIS_USE_SENTINEL: Optional[bool] = Field( + description="Enable Redis Sentinel mode for high availability", + default=False, + ) + + REDIS_SENTINELS: Optional[str] = Field( + description="Comma-separated list of Redis Sentinel nodes (host:port)", + default=None, + ) + + REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field( + description="Name of the Redis Sentinel service to monitor", + default=None, + ) + + REDIS_SENTINEL_USERNAME: Optional[str] = Field( + description="Username for Redis Sentinel authentication (if required)", + default=None, + ) + + REDIS_SENTINEL_PASSWORD: Optional[str] = Field( + description="Password for Redis Sentinel authentication (if required)", + default=None, + ) + + REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field( + description="Socket timeout in seconds for Redis Sentinel connections", + default=0.1, + ) diff --git a/api/configs/middleware/storage/aliyun_oss_storage_config.py b/api/configs/middleware/storage/aliyun_oss_storage_config.py index 19e6cafb1282b1..07eb527170b2ea 100644 --- a/api/configs/middleware/storage/aliyun_oss_storage_config.py +++ b/api/configs/middleware/storage/aliyun_oss_storage_config.py @@ -6,35 +6,40 @@ class AliyunOSSStorageConfig(BaseSettings): """ - Aliyun storage configs + Configuration settings for Aliyun Object Storage Service (OSS) """ ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field( - description='Aliyun OSS bucket name', + description="Name of the Aliyun OSS bucket to store and retrieve objects", default=None, ) ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( - description='Aliyun OSS access key', + description="Access key ID for authenticating with Aliyun OSS", default=None, ) ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( - description='Aliyun OSS secret key', + description="Secret access key for authenticating with Aliyun OSS", default=None, ) ALIYUN_OSS_ENDPOINT: Optional[str] = Field( - description='Aliyun OSS endpoint URL', + description="URL of the Aliyun OSS endpoint for your chosen region", default=None, ) ALIYUN_OSS_REGION: Optional[str] = Field( - description='Aliyun OSS region', + description="Aliyun OSS region where your bucket is located (e.g., 'oss-cn-hangzhou')", default=None, ) ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( - description='Aliyun OSS authentication version', + description="Version of the authentication protocol to use with Aliyun OSS (e.g., 'v4')", + default=None, + ) + + ALIYUN_OSS_PATH: Optional[str] = Field( + description="Base path 
within the bucket to store objects (e.g., 'my-app-data/')", default=None, ) diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index 2566fbd5da6b96..f2d94b12ffa979 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -6,40 +6,40 @@ class S3StorageConfig(BaseSettings): """ - S3 storage configs + Configuration settings for S3-compatible object storage """ S3_ENDPOINT: Optional[str] = Field( - description='S3 storage endpoint', + description="URL of the S3-compatible storage endpoint (e.g., 'https://s3.amazonaws.com')", default=None, ) S3_REGION: Optional[str] = Field( - description='S3 storage region', + description="Region where the S3 bucket is located (e.g., 'us-east-1')", default=None, ) S3_BUCKET_NAME: Optional[str] = Field( - description='S3 storage bucket name', + description="Name of the S3 bucket to store and retrieve objects", default=None, ) S3_ACCESS_KEY: Optional[str] = Field( - description='S3 storage access key', + description="Access key ID for authenticating with the S3 service", default=None, ) S3_SECRET_KEY: Optional[str] = Field( - description='S3 storage secret key', + description="Secret access key for authenticating with the S3 service", default=None, ) S3_ADDRESS_STYLE: str = Field( - description='S3 storage address style', - default='auto', + description="S3 addressing style: 'auto', 'path', or 'virtual'", + default="auto", ) S3_USE_AWS_MANAGED_IAM: bool = Field( - description='whether to use aws managed IAM for S3', + description="Use AWS managed IAM roles for authentication instead of access/secret keys", default=False, ) diff --git a/api/configs/middleware/storage/azure_blob_storage_config.py b/api/configs/middleware/storage/azure_blob_storage_config.py index 26e441c89bd4e4..b7ab5247a9d4dd 100644 --- a/api/configs/middleware/storage/azure_blob_storage_config.py +++ b/api/configs/middleware/storage/azure_blob_storage_config.py @@ -6,25 +6,25 @@ class AzureBlobStorageConfig(BaseSettings): """ - Azure Blob storage configs + Configuration settings for Azure Blob Storage """ AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field( - description='Azure Blob account name', + description="Name of the Azure Storage account (e.g., 'mystorageaccount')", default=None, ) AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( - description='Azure Blob account key', + description="Access key for authenticating with the Azure Storage account", default=None, ) AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( - description='Azure Blob container name', + description="Name of the Azure Blob container to store and retrieve objects", default=None, ) AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( - description='Azure Blob account URL', + description="URL of the Azure Blob storage endpoint (e.g., 'https://mystorageaccount.blob.core.windows.net')", default=None, ) diff --git a/api/configs/middleware/storage/baidu_obs_storage_config.py b/api/configs/middleware/storage/baidu_obs_storage_config.py new file mode 100644 index 00000000000000..c511628a1514a7 --- /dev/null +++ b/api/configs/middleware/storage/baidu_obs_storage_config.py @@ -0,0 +1,29 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class BaiduOBSStorageConfig(BaseModel): + """ + Configuration settings for Baidu Object Storage Service (OBS) + """ + + BAIDU_OBS_BUCKET_NAME: Optional[str] = Field( + description="Name of the Baidu OBS bucket to store and 
retrieve objects (e.g., 'my-obs-bucket')", + default=None, + ) + + BAIDU_OBS_ACCESS_KEY: Optional[str] = Field( + description="Access Key ID for authenticating with Baidu OBS", + default=None, + ) + + BAIDU_OBS_SECRET_KEY: Optional[str] = Field( + description="Secret Access Key for authenticating with Baidu OBS", + default=None, + ) + + BAIDU_OBS_ENDPOINT: Optional[str] = Field( + description="URL of the Baidu OBS endpoint for your chosen region (e.g., 'https://.bj.bcebos.com')", + default=None, + ) diff --git a/api/configs/middleware/storage/google_cloud_storage_config.py b/api/configs/middleware/storage/google_cloud_storage_config.py index e1b0e34e0c32fd..e5d763d7f5c615 100644 --- a/api/configs/middleware/storage/google_cloud_storage_config.py +++ b/api/configs/middleware/storage/google_cloud_storage_config.py @@ -6,15 +6,15 @@ class GoogleCloudStorageConfig(BaseSettings): """ - Google Cloud storage configs + Configuration settings for Google Cloud Storage """ GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field( - description='Google Cloud storage bucket name', + description="Name of the Google Cloud Storage bucket to store and retrieve objects (e.g., 'my-gcs-bucket')", default=None, ) GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( - description='Google Cloud storage service account json base64', + description="Base64-encoded JSON key file for Google Cloud service account authentication", default=None, ) diff --git a/api/configs/middleware/storage/huawei_obs_storage_config.py b/api/configs/middleware/storage/huawei_obs_storage_config.py new file mode 100644 index 00000000000000..3e9e7543ab2bab --- /dev/null +++ b/api/configs/middleware/storage/huawei_obs_storage_config.py @@ -0,0 +1,29 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class HuaweiCloudOBSStorageConfig(BaseModel): + """ + Configuration settings for Huawei Cloud Object Storage Service (OBS) + """ + + HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field( + description="Name of the Huawei Cloud OBS bucket to store and retrieve objects (e.g., 'my-obs-bucket')", + default=None, + ) + + HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field( + description="Access Key ID for authenticating with Huawei Cloud OBS", + default=None, + ) + + HUAWEI_OBS_SECRET_KEY: Optional[str] = Field( + description="Secret Access Key for authenticating with Huawei Cloud OBS", + default=None, + ) + + HUAWEI_OBS_SERVER: Optional[str] = Field( + description="Endpoint URL for Huawei Cloud OBS (e.g., 'https://obs.cn-north-4.myhuaweicloud.com')", + default=None, + ) diff --git a/api/configs/middleware/storage/oci_storage_config.py b/api/configs/middleware/storage/oci_storage_config.py index 6c0c06746954f0..edc245bcac59bb 100644 --- a/api/configs/middleware/storage/oci_storage_config.py +++ b/api/configs/middleware/storage/oci_storage_config.py @@ -6,31 +6,30 @@ class OCIStorageConfig(BaseSettings): """ - OCI storage configs + Configuration settings for Oracle Cloud Infrastructure (OCI) Object Storage """ OCI_ENDPOINT: Optional[str] = Field( - description='OCI storage endpoint', + description="URL of the OCI Object Storage endpoint (e.g., 'https://objectstorage.us-phoenix-1.oraclecloud.com')", default=None, ) OCI_REGION: Optional[str] = Field( - description='OCI storage region', + description="OCI region where the bucket is located (e.g., 'us-phoenix-1')", default=None, ) OCI_BUCKET_NAME: Optional[str] = Field( - description='OCI storage bucket name', + description="Name of the OCI Object Storage bucket to store and retrieve
objects (e.g., 'my-oci-bucket')", default=None, ) OCI_ACCESS_KEY: Optional[str] = Field( - description='OCI storage access key', + description="Access key (also known as API key) for authenticating with OCI Object Storage", default=None, ) OCI_SECRET_KEY: Optional[str] = Field( - description='OCI storage secret key', + description="Secret key associated with the access key for authenticating with OCI Object Storage", default=None, ) - diff --git a/api/configs/middleware/storage/supabase_storage_config.py b/api/configs/middleware/storage/supabase_storage_config.py new file mode 100644 index 00000000000000..a3e905b21c63e9 --- /dev/null +++ b/api/configs/middleware/storage/supabase_storage_config.py @@ -0,0 +1,24 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class SupabaseStorageConfig(BaseModel): + """ + Configuration settings for Supabase Object Storage Service + """ + + SUPABASE_BUCKET_NAME: Optional[str] = Field( + description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')", + default=None, + ) + + SUPABASE_API_KEY: Optional[str] = Field( + description="API KEY for authenticating with Supabase", + default=None, + ) + + SUPABASE_URL: Optional[str] = Field( + description="URL of the Supabase", + default=None, + ) diff --git a/api/configs/middleware/storage/tencent_cos_storage_config.py b/api/configs/middleware/storage/tencent_cos_storage_config.py index 1060c7b93e0bf1..255c4e8938e0fb 100644 --- a/api/configs/middleware/storage/tencent_cos_storage_config.py +++ b/api/configs/middleware/storage/tencent_cos_storage_config.py @@ -6,30 +6,30 @@ class TencentCloudCOSStorageConfig(BaseSettings): """ - Tencent Cloud COS storage configs + Configuration settings for Tencent Cloud Object Storage (COS) """ TENCENT_COS_BUCKET_NAME: Optional[str] = Field( - description='Tencent Cloud COS bucket name', + description="Name of the Tencent Cloud COS bucket to store and retrieve objects", default=None, ) TENCENT_COS_REGION: Optional[str] = Field( - description='Tencent Cloud COS region', + description="Tencent Cloud region where the COS bucket is located (e.g., 'ap-guangzhou')", default=None, ) TENCENT_COS_SECRET_ID: Optional[str] = Field( - description='Tencent Cloud COS secret id', + description="SecretId for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) TENCENT_COS_SECRET_KEY: Optional[str] = Field( - description='Tencent Cloud COS secret key', + description="SecretKey for authenticating with Tencent Cloud COS (part of API credentials)", default=None, ) TENCENT_COS_SCHEME: Optional[str] = Field( - description='Tencent Cloud COS scheme', + description="Protocol scheme for COS requests: 'https' (recommended) or 'http'", default=None, ) diff --git a/api/configs/middleware/storage/volcengine_tos_storage_config.py b/api/configs/middleware/storage/volcengine_tos_storage_config.py new file mode 100644 index 00000000000000..89ea8850023009 --- /dev/null +++ b/api/configs/middleware/storage/volcengine_tos_storage_config.py @@ -0,0 +1,34 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class VolcengineTOSStorageConfig(BaseModel): + """ + Configuration settings for Volcengine Tinder Object Storage (TOS) + """ + + VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field( + description="Name of the Volcengine TOS bucket to store and retrieve objects (e.g., 'my-tos-bucket')", + default=None, + ) + + VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field( + description="Access Key ID for authenticating 
with Volcengine TOS", + default=None, + ) + + VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field( + description="Secret Access Key for authenticating with Volcengine TOS", + default=None, + ) + + VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field( + description="URL of the Volcengine TOS endpoint (e.g., 'https://tos-cn-beijing.volces.com')", + default=None, + ) + + VOLCENGINE_TOS_REGION: Optional[str] = Field( + description="Volcengine region where the TOS bucket is located (e.g., 'cn-beijing')", + default=None, + ) diff --git a/api/configs/middleware/vdb/analyticdb_config.py b/api/configs/middleware/vdb/analyticdb_config.py index db2899265e204f..247a8ea555948c 100644 --- a/api/configs/middleware/vdb/analyticdb_config.py +++ b/api/configs/middleware/vdb/analyticdb_config.py @@ -5,40 +5,38 @@ class AnalyticdbConfig(BaseModel): """ - Configuration for connecting to AnalyticDB. + Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL. Refer to the following documentation for details on obtaining credentials: https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled """ - ANALYTICDB_KEY_ID : Optional[str] = Field( - default=None, - description="The Access Key ID provided by Alibaba Cloud for authentication." + ANALYTICDB_KEY_ID: Optional[str] = Field( + default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication." ) - ANALYTICDB_KEY_SECRET : Optional[str] = Field( - default=None, - description="The Secret Access Key corresponding to the Access Key ID for secure access." + ANALYTICDB_KEY_SECRET: Optional[str] = Field( + default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access." ) - ANALYTICDB_REGION_ID : Optional[str] = Field( + ANALYTICDB_REGION_ID: Optional[str] = Field( default=None, - description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')." + description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').", ) - ANALYTICDB_INSTANCE_ID : Optional[str] = Field( + ANALYTICDB_INSTANCE_ID: Optional[str] = Field( default=None, - description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').." + description="The unique identifier of the AnalyticDB instance you want to connect to.", ) - ANALYTICDB_ACCOUNT : Optional[str] = Field( + ANALYTICDB_ACCOUNT: Optional[str] = Field( default=None, - description="The account name used to log in to the AnalyticDB instance." + description="The account name used to log in to the AnalyticDB instance" + " (usually the initial account created with the instance).", ) - ANALYTICDB_PASSWORD : Optional[str] = Field( - default=None, - description="The password associated with the AnalyticDB account for authentication." + ANALYTICDB_PASSWORD: Optional[str] = Field( + default=None, description="The password associated with the AnalyticDB account for database authentication." ) - ANALYTICDB_NAMESPACE : Optional[str] = Field( - default=None, - description="The namespace within AnalyticDB for schema isolation." + ANALYTICDB_NAMESPACE: Optional[str] = Field( + default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)." 
) - ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field( + ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field( default=None, - description="The password for accessing the specified namespace within the AnalyticDB instance." + description="The password for accessing the specified namespace within the AnalyticDB instance" + " (if namespace feature is enabled).", ) diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py new file mode 100644 index 00000000000000..44742c2e2f4349 --- /dev/null +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -0,0 +1,45 @@ +from typing import Optional + +from pydantic import Field, NonNegativeInt, PositiveInt +from pydantic_settings import BaseSettings + + +class BaiduVectorDBConfig(BaseSettings): + """ + Configuration settings for Baidu Vector Database + """ + + BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field( + description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')", + default=None, + ) + + BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field( + description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)", + default=30000, + ) + + BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field( + description="Account for authenticating with the Baidu Vector Database", + default=None, + ) + + BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field( + description="API key for authenticating with the Baidu Vector Database service", + default=None, + ) + + BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field( + description="Name of the specific Baidu Vector Database to connect to", + default=None, + ) + + BAIDU_VECTOR_DB_SHARD: PositiveInt = Field( + description="Number of shards for the Baidu Vector Database (default is 1)", + default=1, + ) + + BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field( + description="Number of replicas for the Baidu Vector Database (default is 3)", + default=3, + ) diff --git a/api/configs/middleware/vdb/chroma_config.py b/api/configs/middleware/vdb/chroma_config.py index f365879efb1a19..e83a9902dee903 100644 --- a/api/configs/middleware/vdb/chroma_config.py +++ b/api/configs/middleware/vdb/chroma_config.py @@ -6,35 +6,35 @@ class ChromaConfig(BaseSettings): """ - Chroma configs + Configuration settings for Chroma vector database """ CHROMA_HOST: Optional[str] = Field( - description='Chroma host', + description="Hostname or IP address of the Chroma server (e.g., 'localhost' or '192.168.1.100')", default=None, ) CHROMA_PORT: PositiveInt = Field( - description='Chroma port', + description="Port number on which the Chroma server is listening (default is 8000)", default=8000, ) CHROMA_TENANT: Optional[str] = Field( - description='Chroma database', + description="Tenant identifier for multi-tenancy support in Chroma", default=None, ) CHROMA_DATABASE: Optional[str] = Field( - description='Chroma database', + description="Name of the Chroma database to connect to", default=None, ) CHROMA_AUTH_PROVIDER: Optional[str] = Field( - description='Chroma authentication provider', + description="Authentication provider for Chroma (e.g., 'basic', 'token', or a custom provider)", default=None, ) CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( - description='Chroma authentication credentials', + description="Authentication credentials for Chroma (format depends on the auth provider)", default=None, ) diff --git a/api/configs/middleware/vdb/couchbase_config.py b/api/configs/middleware/vdb/couchbase_config.py new file mode 100644 
index 00000000000000..391089ec6e8d00 --- /dev/null +++ b/api/configs/middleware/vdb/couchbase_config.py @@ -0,0 +1,34 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class CouchbaseConfig(BaseModel): + """ + Configuration settings for Couchbase vector database + """ + + COUCHBASE_CONNECTION_STRING: Optional[str] = Field( + description="Connection string for the Couchbase server", + default=None, + ) + + COUCHBASE_USER: Optional[str] = Field( + description="Username for authenticating with Couchbase", + default=None, + ) + + COUCHBASE_PASSWORD: Optional[str] = Field( + description="Password for authenticating with Couchbase", + default=None, + ) + + COUCHBASE_BUCKET_NAME: Optional[str] = Field( + description="Name of the Couchbase bucket to connect to", + default=None, + ) + + COUCHBASE_SCOPE_NAME: Optional[str] = Field( + description="Name of the Couchbase scope to connect to", + default=None, + ) diff --git a/api/configs/middleware/vdb/elasticsearch_config.py b/api/configs/middleware/vdb/elasticsearch_config.py new file mode 100644 index 00000000000000..df8182985dc193 --- /dev/null +++ b/api/configs/middleware/vdb/elasticsearch_config.py @@ -0,0 +1,30 @@ +from typing import Optional + +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings + + +class ElasticsearchConfig(BaseSettings): + """ + Configuration settings for Elasticsearch + """ + + ELASTICSEARCH_HOST: Optional[str] = Field( + description="Hostname or IP address of the Elasticsearch server (e.g., 'localhost' or '192.168.1.100')", + default="127.0.0.1", + ) + + ELASTICSEARCH_PORT: PositiveInt = Field( + description="Port number on which the Elasticsearch server is listening (default is 9200)", + default=9200, + ) + + ELASTICSEARCH_USERNAME: Optional[str] = Field( + description="Username for authenticating with Elasticsearch (default is 'elastic')", + default="elastic", + ) + + ELASTICSEARCH_PASSWORD: Optional[str] = Field( + description="Password for authenticating with Elasticsearch (default is 'elastic')", + default="elastic", + ) diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py new file mode 100644 index 00000000000000..0f6c6528066747 --- /dev/null +++ b/api/configs/middleware/vdb/lindorm_config.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class LindormConfig(BaseSettings): + """ + Configuration settings for Lindorm database + """ + + LINDORM_URL: Optional[str] = Field( + description="URL of the Lindorm server", + default=None, + ) + LINDORM_USERNAME: Optional[str] = Field( + description="Username for authenticating with Lindorm", + default=None, + ) + LINDORM_PASSWORD: Optional[str] = Field( + description="Password for authenticating with Lindorm", + default=None, + ) diff --git a/api/configs/middleware/vdb/milvus_config.py b/api/configs/middleware/vdb/milvus_config.py index 01502d45901764..231cbbbe8ffc9e 100644 --- a/api/configs/middleware/vdb/milvus_config.py +++ b/api/configs/middleware/vdb/milvus_config.py @@ -1,40 +1,35 @@ from typing import Optional -from pydantic import Field, PositiveInt +from pydantic import Field from pydantic_settings import BaseSettings class MilvusConfig(BaseSettings): """ - Milvus configs + Configuration settings for Milvus vector database """ - MILVUS_HOST: Optional[str] = Field( - description='Milvus host', - default=None, + MILVUS_URI: Optional[str] = Field( + description="URI for connecting to the Milvus server (e.g., 'http://localhost:19530' or 'https://milvus-instance.example.com:19530')", + default="http://127.0.0.1:19530", ) - MILVUS_PORT: PositiveInt = Field( - description='Milvus RestFul API port', - default=9091, +
MILVUS_TOKEN: Optional[str] = Field( + description="Authentication token for Milvus, if token-based authentication is enabled", + default=None, ) MILVUS_USER: Optional[str] = Field( - description='Milvus user', + description="Username for authenticating with Milvus, if username/password authentication is enabled", default=None, ) MILVUS_PASSWORD: Optional[str] = Field( - description='Milvus password', + description="Password for authenticating with Milvus, if username/password authentication is enabled", default=None, ) - MILVUS_SECURE: bool = Field( - description='whether to use SSL connection for Milvus', - default=False, - ) - MILVUS_DATABASE: str = Field( - description='Milvus database, default to `default`', - default='default', + description="Name of the Milvus database to connect to (default is 'default')", + default="default", ) diff --git a/api/configs/middleware/vdb/myscale_config.py b/api/configs/middleware/vdb/myscale_config.py index 895cd6f1769d6e..5896c19d27d117 100644 --- a/api/configs/middleware/vdb/myscale_config.py +++ b/api/configs/middleware/vdb/myscale_config.py @@ -1,38 +1,37 @@ - from pydantic import BaseModel, Field, PositiveInt class MyScaleConfig(BaseModel): """ - MyScale configs + Configuration settings for MyScale vector database """ MYSCALE_HOST: str = Field( - description='MyScale host', - default='localhost', + description="Hostname or IP address of the MyScale server (e.g., 'localhost' or 'myscale.example.com')", + default="localhost", ) MYSCALE_PORT: PositiveInt = Field( - description='MyScale port', + description="Port number on which the MyScale server is listening (default is 8123)", default=8123, ) MYSCALE_USER: str = Field( - description='MyScale user', - default='default', + description="Username for authenticating with MyScale (default is 'default')", + default="default", ) MYSCALE_PASSWORD: str = Field( - description='MyScale password', - default='', + description="Password for authenticating with MyScale (default is an empty string)", + default="", ) MYSCALE_DATABASE: str = Field( - description='MyScale database name', - default='default', + description="Name of the MyScale database to connect to (default is 'default')", + default="default", ) MYSCALE_FTS_PARAMS: str = Field( - description='MyScale fts index parameters', - default='', + description="Additional parameters for the MyScale Full Text Search index", + default="", ) diff --git a/api/configs/middleware/vdb/oceanbase_config.py b/api/configs/middleware/vdb/oceanbase_config.py new file mode 100644 index 00000000000000..87427af960202d --- /dev/null +++ b/api/configs/middleware/vdb/oceanbase_config.py @@ -0,0 +1,35 @@ +from typing import Optional + +from pydantic import Field, PositiveInt +from pydantic_settings import BaseSettings + + +class OceanBaseVectorConfig(BaseSettings): + """ + Configuration settings for OceanBase Vector database + """ + + OCEANBASE_VECTOR_HOST: Optional[str] = Field( + description="Hostname or IP address of the OceanBase Vector server (e.g.
'localhost')", + default=None, + ) + + OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field( + description="Port number on which the OceanBase Vector server is listening (default is 2881)", + default=2881, + ) + + OCEANBASE_VECTOR_USER: Optional[str] = Field( + description="Username for authenticating with the OceanBase Vector database", + default=None, + ) + + OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field( + description="Password for authenticating with the OceanBase Vector database", + default=None, + ) + + OCEANBASE_VECTOR_DATABASE: Optional[str] = Field( + description="Name of the OceanBase Vector database to connect to", + default=None, + ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 15d6f5b6a97126..81dde4c04d472e 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -6,30 +6,30 @@ class OpenSearchConfig(BaseSettings): """ - OpenSearch configs + Configuration settings for OpenSearch """ OPENSEARCH_HOST: Optional[str] = Field( - description='OpenSearch host', + description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')", default=None, ) OPENSEARCH_PORT: PositiveInt = Field( - description='OpenSearch port', + description="Port number on which the OpenSearch server is listening (default is 9200)", default=9200, ) OPENSEARCH_USER: Optional[str] = Field( - description='OpenSearch user', + description="Username for authenticating with OpenSearch", default=None, ) OPENSEARCH_PASSWORD: Optional[str] = Field( - description='OpenSearch password', + description="Password for authenticating with OpenSearch", default=None, ) OPENSEARCH_SECURE: bool = Field( - description='whether to use SSL connection for OpenSearch', + description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)", default=False, ) diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py index 888fc19492d67e..5d2cf67ba37b34 100644 --- a/api/configs/middleware/vdb/oracle_config.py +++ b/api/configs/middleware/vdb/oracle_config.py @@ -6,30 +6,30 @@ class OracleConfig(BaseSettings): """ - ORACLE configs + Configuration settings for Oracle database """ ORACLE_HOST: Optional[str] = Field( - description='ORACLE host', + description="Hostname or IP address of the Oracle database server (e.g., 'localhost' or 'oracle.example.com')", default=None, ) - ORACLE_PORT: Optional[PositiveInt] = Field( - description='ORACLE port', + ORACLE_PORT: PositiveInt = Field( + description="Port number on which the Oracle database server is listening (default is 1521)", default=1521, ) ORACLE_USER: Optional[str] = Field( - description='ORACLE user', + description="Username for authenticating with the Oracle database", default=None, ) ORACLE_PASSWORD: Optional[str] = Field( - description='ORACLE password', + description="Password for authenticating with the Oracle database", default=None, ) ORACLE_DATABASE: Optional[str] = Field( - description='ORACLE database', + description="Name of the Oracle database or service to connect to (e.g., 'ORCL' or 'pdborcl')", default=None, ) diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 8a677f60a3a851..4561a9a7ca9626 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -6,30 +6,40 @@ class PGVectorConfig(BaseSettings): """ - 
PGVector configs + Configuration settings for PGVector (PostgreSQL with vector extension) """ PGVECTOR_HOST: Optional[str] = Field( - description='PGVector host', + description="Hostname or IP address of the PostgreSQL server with PGVector extension (e.g., 'localhost')", default=None, ) - PGVECTOR_PORT: Optional[PositiveInt] = Field( - description='PGVector port', + PGVECTOR_PORT: PositiveInt = Field( + description="Port number on which the PostgreSQL server is listening (default is 5433)", default=5433, ) PGVECTOR_USER: Optional[str] = Field( - description='PGVector user', + description="Username for authenticating with the PostgreSQL database", default=None, ) PGVECTOR_PASSWORD: Optional[str] = Field( - description='PGVector password', + description="Password for authenticating with the PostgreSQL database", default=None, ) PGVECTOR_DATABASE: Optional[str] = Field( - description='PGVector database', + description="Name of the PostgreSQL database to connect to", default=None, ) + + PGVECTOR_MIN_CONNECTION: PositiveInt = Field( + description="Minimum number of connections in the PostgreSQL connection pool (default is 1)", + default=1, + ) + + PGVECTOR_MAX_CONNECTION: PositiveInt = Field( + description="Maximum number of connections in the PostgreSQL connection pool (default is 5)", + default=5, + ) diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py index 39f52f22ff6c95..fa3bca5bb75bc5 100644 --- a/api/configs/middleware/vdb/pgvectors_config.py +++ b/api/configs/middleware/vdb/pgvectors_config.py @@ -6,30 +6,30 @@ class PGVectoRSConfig(BaseSettings): """ - PGVectoRS configs + Configuration settings for PGVecto.RS (Rust-based vector extension for PostgreSQL) """ PGVECTO_RS_HOST: Optional[str] = Field( - description='PGVectoRS host', + description="Hostname or IP address of the PostgreSQL server with PGVecto.RS extension (e.g., 'localhost')", default=None, ) - PGVECTO_RS_PORT: Optional[PositiveInt] = Field( - description='PGVectoRS port', + PGVECTO_RS_PORT: PositiveInt = Field( + description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)", default=5431, ) PGVECTO_RS_USER: Optional[str] = Field( - description='PGVectoRS user', + description="Username for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) PGVECTO_RS_PASSWORD: Optional[str] = Field( - description='PGVectoRS password', + description="Password for authenticating with the PostgreSQL database using PGVecto.RS", default=None, ) PGVECTO_RS_DATABASE: Optional[str] = Field( - description='PGVectoRS database', + description="Name of the PostgreSQL database with PGVecto.RS extension to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index c85bf9c7dc6047..b70f6246523c57 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -6,30 +6,30 @@ class QdrantConfig(BaseSettings): """ - Qdrant configs + Configuration settings for Qdrant vector database """ QDRANT_URL: Optional[str] = Field( - description='Qdrant url', + description="URL of the Qdrant server (e.g., 'http://localhost:6333' or 'https://qdrant.example.com')", default=None, ) QDRANT_API_KEY: Optional[str] = Field( - description='Qdrant api key', + description="API key for authenticating with the Qdrant server", default=None, ) QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( - description='Qdrant client timeout in seconds', + description="Timeout in seconds for Qdrant client operations (default is
20 seconds)", default=20, ) QDRANT_GRPC_ENABLED: bool = Field( - description='whether enable grpc support for Qdrant connection', + description="Whether to enable gRPC support for Qdrant connection (True for gRPC, False for HTTP)", default=False, ) QDRANT_GRPC_PORT: PositiveInt = Field( - description='Qdrant grpc port', + description="Port number for gRPC connection to Qdrant server (default is 6334)", default=6334, ) diff --git a/api/configs/middleware/vdb/relyt_config.py b/api/configs/middleware/vdb/relyt_config.py index be93185f3ccab1..5ffbea7b19bb8f 100644 --- a/api/configs/middleware/vdb/relyt_config.py +++ b/api/configs/middleware/vdb/relyt_config.py @@ -6,30 +6,30 @@ class RelytConfig(BaseSettings): """ - Relyt configs + Configuration settings for Relyt database """ RELYT_HOST: Optional[str] = Field( - description='Relyt host', + description="Hostname or IP address of the Relyt server (e.g., 'localhost' or 'relyt.example.com')", default=None, ) RELYT_PORT: PositiveInt = Field( - description='Relyt port', + description="Port number on which the Relyt server is listening (default is 9200)", default=9200, ) RELYT_USER: Optional[str] = Field( - description='Relyt user', + description="Username for authenticating with the Relyt database", default=None, ) RELYT_PASSWORD: Optional[str] = Field( - description='Relyt password', + description="Password for authenticating with the Relyt database", default=None, ) RELYT_DATABASE: Optional[str] = Field( - description='Relyt database', - default='default', + description="Name of the Relyt database to connect to (default is 'default')", + default="default", ) diff --git a/api/configs/middleware/vdb/tencent_vector_config.py b/api/configs/middleware/vdb/tencent_vector_config.py index 531ec840686eea..9cf4d07f6fe660 100644 --- a/api/configs/middleware/vdb/tencent_vector_config.py +++ b/api/configs/middleware/vdb/tencent_vector_config.py @@ -6,45 +6,45 @@ class TencentVectorDBConfig(BaseSettings): """ - Tencent Vector configs + Configuration settings for Tencent Vector Database """ TENCENT_VECTOR_DB_URL: Optional[str] = Field( - description='Tencent Vector URL', + description="URL of the Tencent Vector Database service (e.g., 'https://vectordb.tencentcloudapi.com')", default=None, ) TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( - description='Tencent Vector API key', + description="API key for authenticating with the Tencent Vector Database service", default=None, ) TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field( - description='Tencent Vector timeout in seconds', + description="Timeout in seconds for Tencent Vector Database operations (default is 30 seconds)", default=30, ) TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( - description='Tencent Vector username', + description="Username for authenticating with the Tencent Vector Database (if required)", default=None, ) TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( - description='Tencent Vector password', + description="Password for authenticating with the Tencent Vector Database (if required)", default=None, ) TENCENT_VECTOR_DB_SHARD: PositiveInt = Field( - description='Tencent Vector sharding number', + description="Number of shards for the Tencent Vector Database (default is 1)", default=1, ) TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field( - description='Tencent Vector replicas', + description="Number of replicas for the Tencent Vector Database (default is 2)", default=2, ) TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( - description='Tencent Vector Database', + description="Name of 
the specific Tencent Vector Database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/tidb_on_qdrant_config.py b/api/configs/middleware/vdb/tidb_on_qdrant_config.py new file mode 100644 index 00000000000000..d2625af2644785 --- /dev/null +++ b/api/configs/middleware/vdb/tidb_on_qdrant_config.py @@ -0,0 +1,70 @@ +from typing import Optional + +from pydantic import Field, NonNegativeInt, PositiveInt +from pydantic_settings import BaseSettings + + +class TidbOnQdrantConfig(BaseSettings): + """ + Configuration settings for TiDB on Qdrant + """ + + TIDB_ON_QDRANT_URL: Optional[str] = Field( + description="URL of the TiDB on Qdrant server", + default=None, + ) + + TIDB_ON_QDRANT_API_KEY: Optional[str] = Field( + description="API key for authenticating with the TiDB on Qdrant server", + default=None, + ) + + TIDB_ON_QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( + description="Timeout in seconds for TiDB on Qdrant client operations (default is 20 seconds)", + default=20, + ) + + TIDB_ON_QDRANT_GRPC_ENABLED: bool = Field( + description="Whether to enable gRPC support for the TiDB on Qdrant connection", + default=False, + ) + + TIDB_ON_QDRANT_GRPC_PORT: PositiveInt = Field( + description="Port number for gRPC connection to the TiDB on Qdrant server (default is 6334)", + default=6334, + ) + + TIDB_PUBLIC_KEY: Optional[str] = Field( + description="Public key of the TiDB account", + default=None, + ) + + TIDB_PRIVATE_KEY: Optional[str] = Field( + description="Private key of the TiDB account", + default=None, + ) + + TIDB_API_URL: Optional[str] = Field( + description="URL of the TiDB API", + default=None, + ) + + TIDB_IAM_API_URL: Optional[str] = Field( + description="URL of the TiDB IAM API", + default=None, + ) + + TIDB_REGION: Optional[str] = Field( + description="Region of the TiDB Serverless cluster (default is 'regions/aws-us-east-1')", + default="regions/aws-us-east-1", + ) + + TIDB_PROJECT_ID: Optional[str] = Field( + description="Project ID of the TiDB account", + default=None, + ) + + TIDB_SPEND_LIMIT: Optional[int] = Field( + description="Spending limit for the TiDB Serverless cluster (default is 100)", + default=100, + ) diff --git a/api/configs/middleware/vdb/tidb_vector_config.py b/api/configs/middleware/vdb/tidb_vector_config.py index 8d459691a895bd..bc68be69d86ad7 100644 --- a/api/configs/middleware/vdb/tidb_vector_config.py +++ b/api/configs/middleware/vdb/tidb_vector_config.py @@ -6,30 +6,30 @@ class TiDBVectorConfig(BaseSettings): """ - TiDB Vector configs + Configuration settings for TiDB Vector database """ TIDB_VECTOR_HOST: Optional[str] = Field( - description='TiDB Vector host', + description="Hostname or IP address of the TiDB Vector server (e.g., 'localhost' or 'tidb.example.com')", default=None, ) TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( - description='TiDB Vector port', + description="Port number on which the TiDB Vector server is listening (default is 4000)", default=4000, ) TIDB_VECTOR_USER: Optional[str] = Field( - description='TiDB Vector user', + description="Username for authenticating with the TiDB Vector database", default=None, ) TIDB_VECTOR_PASSWORD: Optional[str] = Field( - description='TiDB Vector password', + description="Password for authenticating with the TiDB Vector database", default=None, ) TIDB_VECTOR_DATABASE: Optional[str] = Field( - description='TiDB Vector database', + description="Name of the TiDB Vector database to connect to", default=None, ) diff --git a/api/configs/middleware/vdb/upstash_config.py b/api/configs/middleware/vdb/upstash_config.py new file mode 100644 index 00000000000000..412c56374ad41d --- /dev/null +++ b/api/configs/middleware/vdb/upstash_config.py @@ -0,0 +1,20 @@ +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class UpstashConfig(BaseSettings):
+ """ + Configuration settings for Upstash vector database + """ + + UPSTASH_VECTOR_URL: Optional[str] = Field( + description="URL of the upstash server (e.g., 'https://vector.upstash.io')", + default=None, + ) + + UPSTASH_VECTOR_TOKEN: Optional[str] = Field( + description="Token for authenticating with the upstash server", + default=None, + ) diff --git a/api/configs/middleware/vdb/vikingdb_config.py b/api/configs/middleware/vdb/vikingdb_config.py new file mode 100644 index 00000000000000..3e718481dc7e05 --- /dev/null +++ b/api/configs/middleware/vdb/vikingdb_config.py @@ -0,0 +1,49 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class VikingDBConfig(BaseModel): + """ + Configuration for connecting to Volcengine VikingDB. + Refer to the following documentation for details on obtaining credentials: + https://www.volcengine.com/docs/6291/65568 + """ + + VIKINGDB_ACCESS_KEY: Optional[str] = Field( + description="The Access Key provided by Volcengine VikingDB for API authentication." + "Refer to the following documentation for details on obtaining credentials:" + "https://www.volcengine.com/docs/6291/65568", + default=None, + ) + + VIKINGDB_SECRET_KEY: Optional[str] = Field( + description="The Secret Key provided by Volcengine VikingDB for API authentication.", + default=None, + ) + + VIKINGDB_REGION: str = Field( + description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').", + default="cn-shanghai", + ) + + VIKINGDB_HOST: str = Field( + description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \ + 'api-vikingdb.mlp.cn-shanghai.volces.com')", + default="api-vikingdb.mlp.cn-shanghai.volces.com", + ) + + VIKINGDB_SCHEME: str = Field( + description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').", + default="http", + ) + + VIKINGDB_CONNECTION_TIMEOUT: int = Field( + description="The connection timeout of the Volcengine VikingDB service.", + default=30, + ) + + VIKINGDB_SOCKET_TIMEOUT: int = Field( + description="The socket timeout of the Volcengine VikingDB service.", + default=30, + ) diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index b985ecea121f9e..25000e8bde2907 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ b/api/configs/middleware/vdb/weaviate_config.py @@ -6,25 +6,25 @@ class WeaviateConfig(BaseSettings): """ - Weaviate configs + Configuration settings for Weaviate vector database """ WEAVIATE_ENDPOINT: Optional[str] = Field( - description='Weaviate endpoint URL', + description="URL of the Weaviate server (e.g., 'http://localhost:8080' or 'https://weaviate.example.com')", default=None, ) WEAVIATE_API_KEY: Optional[str] = Field( - description='Weaviate API key', + description="API key for authenticating with the Weaviate server", default=None, ) WEAVIATE_GRPC_ENABLED: bool = Field( - description='whether to enable gRPC for Weaviate connection', + description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)", default=True, ) WEAVIATE_BATCH_SIZE: PositiveInt = Field( - description='Weaviate batch size', + description="Number of objects to be processed in a single batch operation (default is 100)", default=100, ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 247fcde655a180..b5cb1f06d951f0 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -8,11 +8,11 @@ class 
PackagingInfo(BaseSettings): """ CURRENT_VERSION: str = Field( - description='Dify version', - default='0.7.0', + description="Dify version", + default="0.11.0", ) COMMIT_SHA: str = Field( description="SHA-1 checksum of the git commit used to build the app", - default='', + default="", ) diff --git a/api/constants/__init__.py b/api/constants/__init__.py index e22c3268ef428b..05795e11d7dcc5 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1 +1,24 @@ +from configs import dify_config + HIDDEN_VALUE = "[__HIDDEN__]" +UUID_NIL = "00000000-0000-0000-0000-000000000000" + +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] +IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) + +VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"] +VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS]) + +AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"] +AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS]) + + +if dify_config.ETL_TYPE == "Unstructured": + DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"] + DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub")) + if dify_config.UNSTRUCTURED_API_URL: + DOCUMENT_EXTENSIONS.append("ppt") + DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) +else: + DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] + DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS]) diff --git a/api/constants/recommended_apps.json b/api/constants/recommended_apps.json index df4adc4a1f81d4..3779fb0180ede4 100644 --- a/api/constants/recommended_apps.json +++ b/api/constants/recommended_apps.json @@ -320,7 +320,7 @@ "icon_background": "#FFEAD5", "id": "e9870913-dd01-4710-9f06-15d4180ca1ce", "mode": "advanced-chat", - "name": "Knowledge Retreival + Chatbot " + "name": "Knowledge Retrieval + Chatbot " }, "app_id": "e9870913-dd01-4710-9f06-15d4180ca1ce", "category": "Workflow", @@ -423,7 +423,7 @@ "name": "Website Generator" }, "a23b57fa-85da-49c0-a571-3aff375976c1": { - "export_data": "app:\n icon: \"\\U0001F911\"\n icon_background: '#E4FBCC'\n mode: agent-chat\n name: Investment Analysis Report Copilot\nmodel_config:\n agent_mode:\n enabled: true\n max_iteration: 5\n strategy: function_call\n tools:\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: Analytics\n tool_name: yahoo_finance_analytics\n tool_parameters:\n end_date: ''\n start_date: ''\n symbol: ''\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: News\n tool_name: yahoo_finance_news\n tool_parameters:\n symbol: ''\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: Ticker\n tool_name: yahoo_finance_ticker\n tool_parameters:\n symbol: ''\n annotation_reply:\n enabled: false\n chat_prompt_config: {}\n completion_prompt_config: {}\n dataset_configs:\n datasets:\n datasets: []\n retrieval_model: single\n dataset_query_variable: ''\n external_data_tools: []\n file_upload:\n image:\n detail: high\n enabled: false\n number_limits: 3\n transfer_methods:\n - remote_url\n - local_file\n model:\n completion_params:\n frequency_penalty: 0.5\n max_tokens: 4096\n presence_penalty: 0.5\n stop: []\n temperature: 0.2\n top_p: 0.75\n mode: chat\n name: 
gpt-4-1106-preview\n provider: openai\n more_like_this:\n enabled: false\n opening_statement: 'Welcome to your personalized Investment Analysis Copilot service,\n where we delve into the depths of stock analysis to provide you with comprehensive\n insights. To begin our journey into the financial world, try to ask:\n\n '\n pre_prompt: \"# Job Description: Data Analysis Copilot\\n## Character\\nMy primary\\\n \\ goal is to provide user with expert data analysis advice. Using extensive and\\\n \\ detailed data. Tell me the stock (with ticket symbol) you want to analyze. I\\\n \\ will do all fundemental, technical, market sentiment, and Marcoeconomical analysis\\\n \\ for the stock as an expert. \\n\\n## Skills \\n### Skill 1: Search for stock information\\\n \\ using 'Ticker' from Yahoo Finance \\n### Skill 2: Search for recent news using\\\n \\ 'News' for the target company. \\n### Skill 3: Search for financial figures and\\\n \\ analytics using 'Analytics' for the target company\\n\\n## Workflow\\nAsks the\\\n \\ user which stocks with ticker name need to be analyzed and then performs the\\\n \\ following analysis in sequence. \\n**Part I: Fundamental analysis: financial\\\n \\ reporting analysis\\n*Objective 1: In-depth analysis of the financial situation\\\n \\ of the target company.\\n*Steps:\\n1. Identify the object of analysis:\\n\\n\\n\\n2. Access to financial\\\n \\ reports \\n\\n- Obtain the key data\\\n \\ of the latest financial report of the target company {{company}} organized by\\\n \\ Yahoo Finance. \\n\\n\\n\\n3. Vertical Analysis:\\n- Get the insight of the company's\\\n \\ balance sheet Income Statement and cash flow. \\n- Analyze Income Statement:\\\n \\ Analyze the proportion of each type of income and expense to total income. /Analyze\\\n \\ Balance Sheet: Analyze the proportion of each asset and liability to total assets\\\n \\ or total liabilities./ Analyze Cash Flow \\n-\\n4. Ratio Analysis:\\n\\\n - analyze the Profitability Ratios Solvency Ratios Operational Efficiency Ratios\\\n \\ and Market Performance Ratios of the company. \\n(Profitability Ratios: Such\\\n \\ as net profit margin gross profit margin operating profit margin to assess the\\\n \\ company's profitability.)\\n(Solvency Ratios: Such as debt-to-asset ratio interest\\\n \\ coverage ratio to assess the company's ability to pay its debts.)\\n(Operational\\\n \\ Efficiency Ratios: Such as inventory turnover accounts receivable turnover to\\\n \\ assess the company's operational efficiency.)\\n(Market Performance Ratios: Such\\\n \\ as price-to-earnings ratio price-to-book ratio to assess the company's market\\\n \\ performance.)>\\n-\\n5. Comprehensive Analysis and Conclusion:\\n- Combine the above analyses to\\\n \\ evaluate the company's financial health profitability solvency and operational\\\n \\ efficiency comprehensively. Identify the main financial risks and potential\\\n \\ opportunities facing the company.\\n-\\nOrganize and output [Record 1.1] [Record 1.2] [Record\\\n \\ 1.3] [Record 1.4] [Record 1.5] \\nPart II: Foundamental Analysis: Industry\\n\\\n *Objective 2: To analyze the position and competitiveness of the target company\\\n \\ {{company}} in the industry. \\n\\n\\n* Steps:\\n1. Determine the industry classification:\\n\\\n - Define the industry to which the target company belongs.\\n- Search for company\\\n \\ information to determine its main business and industry.\\n-\\n2. 
Market Positioning and Segmentation\\\n \\ analysis:\\n- To assess the company's market positioning and segmentation. \\n\\\n - Understand the company's market share growth rate and competitors in the industry\\\n \\ to analyze them. \\n-\\n3. Analysis \\n- Analyze the development\\\n \\ trend of the industry. \\n- \\n4. Competitors\\n- Analyze the competition around the target company \\n-\\\n \\ \\nOrganize\\\n \\ and output [Record 2.1] [Record 2.2] [Record 2.3] [Record 2.4]\\nCombine the\\\n \\ above Record and output all the analysis in the form of a investment analysis\\\n \\ report. Use markdown syntax for a structured output. \\n\\n## Constraints\\n- Your\\\n \\ responses should be strictly on analysis tasks. Use a structured language and\\\n \\ think step by step. \\n- The language you use should be identical to the user's\\\n \\ language.\\n- Avoid addressing questions regarding work tools and regulations.\\n\\\n - Give a structured response using bullet points and markdown syntax. Give an\\\n \\ introduction to the situation first then analyse the main trend in the graph.\\\n \\ \\n\"\n prompt_type: simple\n retriever_resource:\n enabled: true\n sensitive_word_avoidance:\n configs: []\n enabled: false\n type: ''\n speech_to_text:\n enabled: false\n suggested_questions:\n - 'Analyze the stock of Tesla. '\n - What are some recent development on Nvidia?\n - 'Do a fundamental analysis for Amazon. '\n suggested_questions_after_answer:\n enabled: true\n text_to_speech:\n enabled: false\n user_input_form:\n - text-input:\n default: ''\n label: company\n required: false\n variable: company\n", + "export_data": "app:\n icon: \"\\U0001F911\"\n icon_background: '#E4FBCC'\n mode: agent-chat\n name: Investment Analysis Report Copilot\nmodel_config:\n agent_mode:\n enabled: true\n max_iteration: 5\n strategy: function_call\n tools:\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: Analytics\n tool_name: yahoo_finance_analytics\n tool_parameters:\n end_date: ''\n start_date: ''\n symbol: ''\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: News\n tool_name: yahoo_finance_news\n tool_parameters:\n symbol: ''\n - enabled: true\n isDeleted: false\n notAuthor: false\n provider_id: yahoo\n provider_name: yahoo\n provider_type: builtin\n tool_label: Ticker\n tool_name: yahoo_finance_ticker\n tool_parameters:\n symbol: ''\n annotation_reply:\n enabled: false\n chat_prompt_config: {}\n completion_prompt_config: {}\n dataset_configs:\n datasets:\n datasets: []\n retrieval_model: single\n dataset_query_variable: ''\n external_data_tools: []\n file_upload:\n image:\n detail: high\n enabled: false\n number_limits: 3\n transfer_methods:\n - remote_url\n - local_file\n model:\n completion_params:\n frequency_penalty: 0.5\n max_tokens: 4096\n presence_penalty: 0.5\n stop: []\n temperature: 0.2\n top_p: 0.75\n mode: chat\n name: gpt-4-1106-preview\n provider: openai\n more_like_this:\n enabled: false\n opening_statement: 'Welcome to your personalized Investment Analysis Copilot service,\n where we delve into the depths of stock analysis to provide you with comprehensive\n insights. To begin our journey into the financial world, try to ask:\n\n '\n pre_prompt: \"# Job Description: Data Analysis Copilot\\n## Character\\nMy primary\\\n \\ goal is to provide user with expert data analysis advice. Using extensive and\\\n \\ detailed data. 
Tell me the stock (with ticker symbol) you want to analyze. I\\\n \\ will do all fundamental, technical, market sentiment, and macroeconomic analysis\\\n \\ for the stock as an expert. \\n\\n## Skills \\n### Skill 1: Search for stock information\\\n \\ using 'Ticker' from Yahoo Finance \\n### Skill 2: Search for recent news using\\\n \\ 'News' for the target company. \\n### Skill 3: Search for financial figures and\\\n \\ analytics using 'Analytics' for the target company\\n\\n## Workflow\\nAsks the\\\n \\ user which stocks with ticker name need to be analyzed and then performs the\\\n \\ following analysis in sequence. \\n**Part I: Fundamental analysis: financial\\\n \\ reporting analysis\\n*Objective 1: In-depth analysis of the financial situation\\\n \\ of the target company.\\n*Steps:\\n1. Identify the object of analysis:\\n\\n\\n\\n2. Access to financial\\\n \\ reports \\n\\n- Obtain the key data\\\n \\ of the latest financial report of the target company {{company}} organized by\\\n \\ Yahoo Finance. \\n\\n\\n\\n3. Vertical Analysis:\\n- Get the insight of the company's\\\n \\ balance sheet Income Statement and cash flow. \\n- Analyze Income Statement:\\\n \\ Analyze the proportion of each type of income and expense to total income. /Analyze\\\n \\ Balance Sheet: Analyze the proportion of each asset and liability to total assets\\\n \\ or total liabilities./ Analyze Cash Flow \\n-\\n4. Ratio Analysis:\\n\\\n - analyze the Profitability Ratios Solvency Ratios Operational Efficiency Ratios\\\n \\ and Market Performance Ratios of the company. \\n(Profitability Ratios: Such\\\n \\ as net profit margin gross profit margin operating profit margin to assess the\\\n \\ company's profitability.)\\n(Solvency Ratios: Such as debt-to-asset ratio interest\\\n \\ coverage ratio to assess the company's ability to pay its debts.)\\n(Operational\\\n \\ Efficiency Ratios: Such as inventory turnover accounts receivable turnover to\\\n \\ assess the company's operational efficiency.)\\n(Market Performance Ratios: Such\\\n \\ as price-to-earnings ratio price-to-book ratio to assess the company's market\\\n \\ performance.)>\\n-\\n5. Comprehensive Analysis and Conclusion:\\n- Combine the above analyses to\\\n \\ evaluate the company's financial health profitability solvency and operational\\\n \\ efficiency comprehensively. Identify the main financial risks and potential\\\n \\ opportunities facing the company.\\n-\\nOrganize and output [Record 1.1] [Record 1.2] [Record\\\n \\ 1.3] [Record 1.4] [Record 1.5] \\nPart II: Fundamental Analysis: Industry\\n\\\n *Objective 2: To analyze the position and competitiveness of the target company\\\n \\ {{company}} in the industry. \\n\\n\\n* Steps:\\n1. Determine the industry classification:\\n\\\n - Define the industry to which the target company belongs.\\n- Search for company\\\n \\ information to determine its main business and industry.\\n-\\n2. Market Positioning and Segmentation\\\n \\ analysis:\\n- To assess the company's market positioning and segmentation. \\n\\\n - Understand the company's market share growth rate and competitors in the industry\\\n \\ to analyze them. \\n-\\n3. Analysis \\n- Analyze the development\\\n \\ trend of the industry. \\n- \\n4. Competitors\\n- Analyze the competition around the target company \\n-\\\n \\ \\nOrganize\\\n \\ and output [Record 2.1] [Record 2.2] [Record 2.3] [Record 2.4]\\nCombine the\\\n \\ above Record and output all the analysis in the form of an investment analysis\\\n \\ report.
Use markdown syntax for a structured output. \\n\\n## Constraints\\n- Your\\\n \\ responses should be strictly on analysis tasks. Use a structured language and\\\n \\ think step by step. \\n- The language you use should be identical to the user's\\\n \\ language.\\n- Avoid addressing questions regarding work tools and regulations.\\n\\\n - Give a structured response using bullet points and markdown syntax. Give an\\\n \\ introduction to the situation first then analyse the main trend in the graph.\\\n \\ \\n\"\n prompt_type: simple\n retriever_resource:\n enabled: true\n sensitive_word_avoidance:\n configs: []\n enabled: false\n type: ''\n speech_to_text:\n enabled: false\n suggested_questions:\n - 'Analyze the stock of Tesla. '\n - What are some recent development on Nvidia?\n - 'Do a fundamental analysis for Amazon. '\n suggested_questions_after_answer:\n enabled: true\n text_to_speech:\n enabled: false\n user_input_form:\n - text-input:\n default: ''\n label: company\n required: false\n variable: company\n", "icon": "🤑", "icon_background": "#E4FBCC", "id": "a23b57fa-85da-49c0-a571-3aff375976c1", @@ -438,8 +438,8 @@ "mode": "advanced-chat", "name": "Workflow Planning Assistant " }, - "e9d92058-7d20-4904-892f-75d90bef7587":{"export_data":"app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: advanced-chat\n name: 'Automated Email Reply '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n isInIteration: false\n sourceType: code\n targetType: iteration\n id: 1716909112104-source-1716909114582-target\n source: '1716909112104'\n sourceHandle: source\n target: '1716909114582'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: iteration\n targetType: template-transform\n id: 1716909114582-source-1716913435742-target\n source: '1716909114582'\n sourceHandle: source\n target: '1716913435742'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: template-transform\n targetType: answer\n id: 1716913435742-source-1716806267180-target\n source: '1716913435742'\n sourceHandle: source\n target: '1716806267180'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: start\n targetType: tool\n id: 1716800588219-source-1716946869294-target\n source: '1716800588219'\n sourceHandle: source\n target: '1716946869294'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: tool\n targetType: code\n id: 1716946869294-source-1716909112104-target\n source: '1716946869294'\n sourceHandle: source\n target: '1716909112104'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: tool\n targetType: code\n id: 1716946889408-source-1716909122343-target\n source: '1716946889408'\n sourceHandle: source\n target: '1716909122343'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: code\n targetType: code\n id: 1716909122343-source-1716951357236-target\n source: '1716909122343'\n sourceHandle: source\n target: '1716951357236'\n targetHandle: target\n type: custom\n zIndex: 
1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: code\n targetType: llm\n id: 1716951357236-source-1716913272656-target\n source: '1716951357236'\n sourceHandle: source\n target: '1716913272656'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: llm\n id: 1716951236700-source-1716951159073-target\n source: '1716951236700'\n sourceHandle: source\n target: '1716951159073'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: template-transform\n id: 1716951159073-source-1716952228079-target\n source: '1716951159073'\n sourceHandle: source\n target: '1716952228079'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: tool\n id: 1716952228079-source-1716952912103-target\n source: '1716952228079'\n sourceHandle: source\n target: '1716952912103'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: question-classifier\n id: 1716913272656-source-1716960721611-target\n source: '1716913272656'\n sourceHandle: source\n target: '1716960721611'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: llm\n id: 1716960721611-1-1716909125498-target\n source: '1716960721611'\n sourceHandle: '1'\n target: '1716909125498'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: llm\n id: 1716960721611-2-1716960728136-target\n source: '1716960721611'\n sourceHandle: '2'\n target: '1716960728136'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: variable-aggregator\n id: 1716909125498-source-1716960791399-target\n source: '1716909125498'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: variable-aggregator\n targetType: template-transform\n id: 1716960791399-source-1716951236700-target\n source: '1716960791399'\n sourceHandle: source\n target: '1716951236700'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: template-transform\n id: 1716960721611-1716960736883-1716960834468-target\n source: '1716960721611'\n sourceHandle: '1716960736883'\n target: '1716960834468'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: variable-aggregator\n id: 1716960728136-source-1716960791399-target\n source: '1716960728136'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: variable-aggregator\n id: 1716960834468-source-1716960791399-target\n source: '1716960834468'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n nodes:\n - 
data:\n desc: ''\n selected: false\n title: Start\n type: start\n variables:\n - label: Your Email\n max_length: 256\n options: []\n required: true\n type: text-input\n variable: email\n - label: Maximum Number of Email you want to retrieve\n max_length: 256\n options: []\n required: true\n type: number\n variable: maxResults\n height: 115\n id: '1716800588219'\n position:\n x: 30\n y: 445\n positionAbsolute:\n x: 30\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n answer: '{{#1716913435742.output#}}'\n desc: ''\n selected: false\n title: Direct Reply\n type: answer\n variables: []\n height: 106\n id: '1716806267180'\n position:\n x: 4700\n y: 445\n positionAbsolute:\n x: 4700\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n code: \"def main(message: str) -> dict:\\n import json\\n \\n # Parse\\\n \\ the JSON string\\n parsed_data = json.loads(message)\\n \\n # Extract\\\n \\ all the \\\"id\\\" values\\n ids = [msg['id'] for msg in parsed_data['messages']]\\n\\\n \\ \\n return {\\n \\\"result\\\": ids\\n }\"\n code_language: python3\n desc: ''\n outputs:\n result:\n children: null\n type: array[string]\n selected: false\n title: 'Code: Extract Email ID'\n type: code\n variables:\n - value_selector:\n - '1716946869294'\n - text\n variable: message\n height: 53\n id: '1716909112104'\n position:\n x: 638\n y: 445\n positionAbsolute:\n x: 638\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n height: 490\n iterator_selector:\n - '1716909112104'\n - result\n output_selector:\n - '1716909125498'\n - text\n output_type: array[string]\n selected: false\n startNodeType: tool\n start_node_id: '1716946889408'\n title: 'Iteraction '\n type: iteration\n width: 3393.7520359289056\n height: 490\n id: '1716909114582'\n position:\n x: 942\n y: 445\n positionAbsolute:\n x: 942\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 3394\n zIndex: 1\n - data:\n desc: ''\n isInIteration: true\n isIterationStart: true\n iteration_id: '1716909114582'\n provider_id: e64b4c7f-2795-499c-8d11-a971a7d57fc9\n provider_name: List and Get Gmail\n provider_type: api\n selected: false\n title: getMessage\n tool_configurations: {}\n tool_label: getMessage\n tool_name: getMessage\n tool_parameters:\n format:\n type: mixed\n value: full\n id:\n type: mixed\n value: '{{#1716909114582.item#}}'\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n extent: parent\n height: 53\n id: '1716946889408'\n parentId: '1716909114582'\n position:\n x: 117\n y: 85\n positionAbsolute:\n x: 1059\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1001\n - data:\n code: \"\\ndef main(email_json: dict) -> dict:\\n import json \\n email_dict\\\n \\ = json.loads(email_json)\\n base64_data = email_dict['payload']['parts'][0]['body']['data']\\n\\\n \\n return {\\n \\\"result\\\": base64_data, \\n }\\n\"\n code_language: python3\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n outputs:\n result:\n children: null\n type: string\n selected: false\n title: 'Code: Extract Email Body'\n type: code\n variables:\n - value_selector:\n - '1716946889408'\n - text\n variable: email_json\n extent: parent\n height: 53\n id: '1716909122343'\n parentId: '1716909114582'\n position:\n x: 421\n y: 85\n positionAbsolute:\n x: 
1363\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Generate reply. '\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 982014aa-702b-4d7c-ae1f-08dbceb6e930\n role: system\n text: \" \\nRespond to the emails. \\n\\n{{#1716913272656.text#}}\\n\\\n \"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 127\n id: '1716909125498'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 85\n positionAbsolute:\n x: 2567\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: fd8de569-c099-4320-955b-61aa4b054789\n role: system\n text: \"\\nYou need to transform the input data (in base64 encoding)\\\n \\ to text. Input base64. Output text. \\n\\n{{#1716909122343.result#}}\\n\\\n \"\n selected: false\n title: 'Base64 Decoder '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: false\n extent: parent\n height: 97\n id: '1716913272656'\n parentId: '1716909114582'\n position:\n x: 1025\n y: 85\n positionAbsolute:\n x: 1967\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 | join(\"\\n\\n -------------------------\\n\\n\") }}'\n title: 'Template '\n type: template-transform\n variables:\n - value_selector:\n - '1716909114582'\n - output\n variable: arg1\n height: 53\n id: '1716913435742'\n position:\n x: 4396\n y: 445\n positionAbsolute:\n x: 4396\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n provider_id: e64b4c7f-2795-499c-8d11-a971a7d57fc9\n provider_name: List and Get Gmail\n provider_type: api\n selected: false\n title: listMessages\n tool_configurations: {}\n tool_label: listMessages\n tool_name: listMessages\n tool_parameters:\n maxResults:\n type: variable\n value:\n - '1716800588219'\n - maxResults\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n height: 53\n id: '1716946869294'\n position:\n x: 334\n y: 445\n positionAbsolute:\n x: 334\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: b7fd0ec5-864a-42c6-9d04-a1958bd4fc0d\n role: system\n text: \"\\nYou need to encode the input data from text to base64. Input\\\n \\ text. Output base64 encoding. 
Output nothing other than base64 encoding.\\\n \\ \\n\\n{{#1716951236700.output#}}\\n \"\n selected: false\n title: Base64 Encoder\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1716951159073'\n parentId: '1716909114582'\n position:\n x: 2525.7520359289056\n y: 85\n positionAbsolute:\n x: 3467.7520359289056\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: Generaate MIME email template\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: \"Content-Type: text/plain; charset=\\\"utf-8\\\"\\r\\nContent-Transfer-Encoding:\\\n \\ 7bit\\r\\nMIME-Version: 1.0\\r\\nTo: {{ emailMetadata.recipientEmail }} #\\\n \\ xiaoyi@dify.ai\\r\\nFrom: {{ emailMetadata.senderEmail }} # sxy.hj156@gmail.com\\r\\\n \\nSubject: Re: {{ emailMetadata.subject }} \\r\\n\\r\\n{{ text }}\\r\\n\"\n title: 'Template: Reply Email'\n type: template-transform\n variables:\n - value_selector:\n - '1716951357236'\n - result\n variable: emailMetadata\n - value_selector:\n - '1716960791399'\n - output\n variable: text\n extent: parent\n height: 83\n id: '1716951236700'\n parentId: '1716909114582'\n position:\n x: 2231.269960149744\n y: 85\n positionAbsolute:\n x: 3173.269960149744\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n code: \"def main(email_json: dict) -> dict:\\n import json\\n if isinstance(email_json,\\\n \\ str): \\n email_json = json.loads(email_json)\\n\\n subject = None\\n\\\n \\ recipient_email = None \\n sender_email = None\\n \\n headers\\\n \\ = email_json['payload']['headers']\\n for header in headers:\\n \\\n \\ if header['name'] == 'Subject':\\n subject = header['value']\\n\\\n \\ elif header['name'] == 'To':\\n recipient_email = header['value']\\n\\\n \\ elif header['name'] == 'From':\\n sender_email = header['value']\\n\\\n \\n return {\\n \\\"result\\\": [subject, recipient_email, sender_email]\\n\\\n \\ }\\n\"\n code_language: python3\n desc: \"Recipient, Sender, Subject\\uFF0COutput Array[String]\"\n isInIteration: true\n iteration_id: '1716909114582'\n outputs:\n result:\n children: null\n type: array[string]\n selected: false\n title: Extract Email Metadata\n type: code\n variables:\n - value_selector:\n - '1716946889408'\n - text\n variable: email_json\n extent: parent\n height: 101\n id: '1716951357236'\n parentId: '1716909114582'\n position:\n x: 725\n y: 85\n positionAbsolute:\n x: 1667\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: '{\"raw\": \"{{ encoded_message }}\"}'\n title: \"Template\\uFF1AEmail Request Body\"\n type: template-transform\n variables:\n - value_selector:\n - '1716951159073'\n - text\n variable: encoded_message\n extent: parent\n height: 53\n id: '1716952228079'\n parentId: '1716909114582'\n position:\n x: 2828.4325280181324\n y: 86.31950791077293\n positionAbsolute:\n x: 3770.4325280181324\n y: 531.3195079107729\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n provider_id: 038963aa-43c8-47fc-be4b-0255c19959c1\n provider_name: Draft Gmail\n provider_type: api\n selected: false\n title: createDraft\n 
tool_configurations: {}\n tool_label: createDraft\n tool_name: createDraft\n tool_parameters:\n message:\n type: mixed\n value: '{{#1716952228079.output#}}'\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n extent: parent\n height: 53\n id: '1716952912103'\n parentId: '1716909114582'\n position:\n x: 3133.7520359289056\n y: 85\n positionAbsolute:\n x: 4075.7520359289056\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n classes:\n - id: '1'\n name: 'Technical questions related to the product'\n - id: '2'\n name: Unrelated to technical topics, non-technical\n - id: '1716960736883'\n name: Other questions\n desc: ''\n instructions: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1716800588219'\n - sys.query\n selected: false\n title: Question Classifier\n topics: []\n type: question-classifier\n extent: parent\n height: 255\n id: '1716960721611'\n parentId: '1716909114582'\n position:\n x: 1325\n y: 85\n positionAbsolute:\n x: 2267\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - id: a639bbf8-bc58-42a2-b477-6748e80ecda2\n role: system\n text: \" \\nRespond to the emails. \\n\\n{{#1716913272656.text#}}\\n\\\n \"\n selected: false\n title: 'LLM - Non-technical'\n type: llm\n variables: []\n vision:\n enabled: false\n extent: parent\n height: 97\n id: '1716960728136'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 251\n positionAbsolute:\n x: 2567\n y: 696\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n output_type: string\n selected: false\n title: Variable Aggregator\n type: variable-aggregator\n variables:\n - - '1716909125498'\n - text\n - - '1716960728136'\n - text\n - - '1716960834468'\n - output\n extent: parent\n height: 164\n id: '1716960791399'\n parentId: '1716909114582'\n position:\n x: 1931.2699601497438\n y: 85\n positionAbsolute:\n x: 2873.269960149744\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: Other questions\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: 'Sorry, I cannot answer that. This is outside my capabilities. 
'\n title: 'Direct Reply '\n type: template-transform\n variables: []\n extent: parent\n height: 83\n id: '1716960834468'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 385.57142857142856\n positionAbsolute:\n x: 2567\n y: 830.5714285714286\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n author: Dify\n desc: ''\n height: 153\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":3,\"mode\":\"normal\",\"style\":\"font-size:\n 14px;\",\"text\":\"OpenAPI-Swagger for all custom tools: \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":3},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"openapi:\n 3.0.0\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"info:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" title:\n Gmail API\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n OpenAPI schema for Gmail API methods `users.messages.get`, `users.messages.list`,\n and `users.drafts.create`.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" version:\n 1.0.0\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"servers:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n url: https://gmail.googleapis.com\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Gmail API Server\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"paths:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
/gmail/v1/users/{userId}/messages/{id}:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" get:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n Get a message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Retrieves a specific message by ID.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n getMessage\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. 
The special value `me` can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: id\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the message to retrieve.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: format\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n query\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n false\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" enum:\n [full, metadata, minimal, raw]\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" default:\n full\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
description:\n The format to return the message in.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" labelIds:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" snippet:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" historyId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" internalDate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" payload:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" sizeEstimate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" raw:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''403'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Forbidden\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''404'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Not Found\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" /gmail/v1/users/{userId}/messages:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" get:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n List messages.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Lists the messages in the user''s mailbox.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n listMessages\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: 
userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. The special value `me` can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: maxResults\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n query\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" format:\n int32\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" default:\n 100\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Maximum number of messages to return.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" messages:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" nextPageToken:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" resultSizeEstimate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" /gmail/v1/users/{userId}/drafts:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" post:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n Creates a new draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n createDraft\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
tags:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n Drafts\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. The special value \\\"me\\\" can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" requestBody:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" message:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" raw:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The entire email message in an RFC 2822 formatted and base64url encoded\n string.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response with the created draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The immutable ID of the draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" message:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The immutable ID of the message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the thread the message belongs to.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
labelIds:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" snippet:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n A short part of the message text.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" historyId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the last history record that modified this message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''400'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Bad Request - The request is invalid.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized - Authentication is 
required.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''403'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Forbidden - The user does not have permission to create drafts.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''404'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Not Found - The specified user does not exist.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''500'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Internal Server Error - An error occurred on the server.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"components:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" securitySchemes:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" OAuth2:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n oauth2\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" flows:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" authorizationCode:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" authorizationUrl:\n 
https://accounts.google.com/o/oauth2/auth\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" tokenUrl:\n https://oauth2.googleapis.com/token\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" scopes:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://mail.google.com/:\n All access to Gmail.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://www.googleapis.com/auth/gmail.compose:\n Send email on your behalf.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://www.googleapis.com/auth/gmail.modify:\n Modify your email.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"security:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n OAuth2:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://mail.google.com/\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://www.googleapis.com/auth/gmail.compose\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://www.googleapis.com/auth/gmail.modify\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: yellow\n title: ''\n type: ''\n width: 367\n height: 153\n id: '1718992681576'\n position:\n x: 321.9646831030669\n y: 538.1642616264143\n positionAbsolute:\n x: 321.9646831030669\n y: 538.1642616264143\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 367\n - data:\n author: Dify\n desc: ''\n height: 158\n selected: false\n showAuthor: true\n text: 
'{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Replace\n custom tools after added this template to your own workspace. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Fill\n in \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"your\n email \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"and\n the \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"maximum\n number of results you want to retrieve from your inbox \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"to\n get started. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 287\n height: 158\n id: '1718992805687'\n position:\n x: 18.571428571428356\n y: 237.80887395992687\n positionAbsolute:\n x: 18.571428571428356\n y: 237.80887395992687\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 287\n - data:\n author: Dify\n desc: ''\n height: 375\n selected: true\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"font-size:\n 16px;\",\"text\":\"Steps within Iteraction node: \",\"type\":\"text\",\"version\":1},{\"type\":\"linebreak\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"1.\n getMessage: This step retrieves the incoming email message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"2.\n Code: Extract Email Body: Custom code is executed to extract the body of\n the email from the retrieved message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"3.\n Extract Email Metadata: Extracts metadata from the email, such as the recipient,\n sender, subject, and other relevant information.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"4.\n Base64 Decoder: Decodes the email content from Base64 encoding.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"5.\n Question Classifier (gpt-3.5-turbo): Uses a GPT-3.5-turbo model to classify\n the 
email content into different categories. For each classified question,\n the workflow uses a GPT-4o model to generate an appropriate reply:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"6.\n Template: Reply Email: Uses a template to generate a MIME email format for\n the reply.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"7.\n Base64 Encoder: Encodes the generated reply email content back to Base64.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"8.\n Template: Email Request: Prepares the email request using a template.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"9.\n createDraft: Creates a draft of the email reply.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"This\n workflow automates the process of reading, classifying, responding to, and\n drafting replies to incoming emails, leveraging advanced language models\n to generate contextually appropriate responses.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 640\n height: 375\n id: '1718993366836'\n position:\n x: 966.7525290975368\n y: 971.80362905854\n positionAbsolute:\n x: 966.7525290975368\n y: 971.80362905854\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 640\n - data:\n author: Dify\n desc: ''\n height: 400\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":3,\"mode\":\"normal\",\"style\":\"font-size:\n 16px;\",\"text\":\"Preparation\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":3},{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Enable\n Gmail API in Google Cloud Console\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Configure\n OAuth Client ID, OAuth Client Secrets, and OAuth Consent Screen for the\n Web Application in Google Cloud 
Console\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":2},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Use\n Postman to authorize and obtain the OAuth Access Token (Google''s Access\n Token will expire after 1 hour and cannot be used for a long time)\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":3}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"list\",\"version\":1,\"listType\":\"bullet\",\"start\":1,\"tag\":\"ul\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Users\n who want to try building an AI auto-reply email can refer to this document\n to use Postman (Postman.com) to obtain all the above keys: https://blog.postman.com/how-to-access-google-apis-using-oauth-in-postman/.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Developers\n who want to use Google OAuth to call the Gmail API to develop corresponding\n plugins can refer to this official document: \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"https://developers.google.com/identity/protocols/oauth2/web-server.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"At\n this stage, it is still a bit difficult to reproduce this example within\n the Dify platform. 
If you have development capabilities, developing the\n corresponding plugin externally and using an external database to automatically\n read and write the user''s Access Token and write the Refresh Token would\n be a better choice.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 608\n height: 400\n id: '1718993557447'\n position:\n x: 354.0157230378119\n y: -1.2732157979666\n positionAbsolute:\n x: 354.0157230378119\n y: -1.2732157979666\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 608\n viewport:\n x: 147.09446825757777\n y: 101.03530130020579\n zoom: 0.9548416039104178\n","icon":"\ud83e\udd16","icon_background":"#FFEAD5","id":"e9d92058-7d20-4904-892f-75d90bef7587","mode":"advanced-chat","name":"Automated Email Reply "}, - "98b87f88-bd22-4d86-8b74-86beba5e0ed4":{"export_data":"app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: workflow\n name: 'Book Translation '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n isInIteration: false\n sourceType: start\n targetType: code\n id: 1711067409646-source-1717916867969-target\n source: '1711067409646'\n sourceHandle: source\n target: '1717916867969'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: code\n targetType: iteration\n id: 1717916867969-source-1717916955547-target\n source: '1717916867969'\n sourceHandle: source\n target: '1717916955547'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916961837-source-1717916977413-target\n source: '1717916961837'\n sourceHandle: source\n target: '1717916977413'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916977413-source-1717916984996-target\n source: '1717916977413'\n sourceHandle: source\n target: '1717916984996'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916984996-source-1717916991709-target\n source: '1717916984996'\n sourceHandle: source\n target: '1717916991709'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: false\n sourceType: iteration\n targetType: template-transform\n id: 1717916955547-source-1717917057450-target\n source: '1717916955547'\n sourceHandle: source\n target: '1717917057450'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: template-transform\n targetType: end\n id: 1717917057450-source-1711068257370-target\n source: '1717917057450'\n sourceHandle: source\n target: '1711068257370'\n targetHandle: target\n type: custom\n zIndex: 0\n nodes:\n - data:\n desc: ''\n selected: false\n title: Start\n type: start\n variables:\n - label: Input 
Text\n max_length: null\n options: []\n required: true\n type: paragraph\n variable: input_text\n dragging: false\n height: 89\n id: '1711067409646'\n position:\n x: 30\n y: 301.5\n positionAbsolute:\n x: 30\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1717917057450'\n - output\n variable: final\n selected: false\n title: End\n type: end\n height: 89\n id: '1711068257370'\n position:\n x: 2291\n y: 301.5\n positionAbsolute:\n x: 2291\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n code: \"\\ndef main(input_text: str) -> str:\\n token_limit = 1000\\n overlap\\\n \\ = 100\\n chunk_size = int(token_limit * 6 * (4/3))\\n\\n # Initialize\\\n \\ variables\\n chunks = []\\n start_index = 0\\n text_length = len(input_text)\\n\\\n \\n # Loop until the end of the text is reached\\n while start_index\\\n \\ < text_length:\\n # If we are not at the beginning, adjust the start_index\\\n \\ to ensure overlap\\n if start_index > 0:\\n start_index\\\n \\ -= overlap\\n\\n # Calculate end index for the current chunk\\n \\\n \\ end_index = start_index + chunk_size\\n if end_index > text_length:\\n\\\n \\ end_index = text_length\\n\\n # Add the current chunk\\\n \\ to the list\\n chunks.append(input_text[start_index:end_index])\\n\\\n \\n # Update the start_index for the next chunk\\n start_index\\\n \\ += chunk_size\\n\\n return {\\n \\\"chunks\\\": chunks,\\n }\\n\"\n code_language: python3\n dependencies: []\n desc: 'token_limit = 1000\n\n overlap = 100'\n outputs:\n chunks:\n children: null\n type: array[string]\n selected: false\n title: Code\n type: code\n variables:\n - value_selector:\n - '1711067409646'\n - input_text\n variable: input_text\n height: 101\n id: '1717916867969'\n position:\n x: 336\n y: 301.5\n positionAbsolute:\n x: 336\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: 'Take care with the maximum number of iterations. '\n height: 203\n iterator_selector:\n - '1717916867969'\n - chunks\n output_selector:\n - '1717916991709'\n - text\n output_type: array[string]\n selected: false\n startNodeType: llm\n start_node_id: '1717916961837'\n title: Iteration\n type: iteration\n width: 1289\n height: 203\n id: '1717916955547'\n position:\n x: 638\n y: 301.5\n positionAbsolute:\n x: 638\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 1289\n zIndex: 1\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n isIterationStart: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 7261280b-cb27-4f84-8363-b93e09246d16\n role: system\n text: \" Identify the technical terms in the user's input. Use the following\\\n \\ format {XXX} -> {XXX} to show the corresponding technical terms before\\\n \\ and after translation. 
\\n\\n \\n{{#1717916955547.item#}}\\n\\\n \\n\\n| \\u82F1\\u6587 | \\u4E2D\\u6587 |\\n| --- | --- |\\n| Prompt\\\n \\ Engineering | \\u63D0\\u793A\\u8BCD\\u5DE5\\u7A0B |\\n| Text Generation \\_\\\n | \\u6587\\u672C\\u751F\\u6210 |\\n| Token \\_| Token |\\n| Prompt \\_| \\u63D0\\\n \\u793A\\u8BCD |\\n| Meta Prompting \\_| \\u5143\\u63D0\\u793A |\\n| diffusion\\\n \\ models \\_| \\u6269\\u6563\\u6A21\\u578B |\\n| Agent \\_| \\u667A\\u80FD\\u4F53\\\n \\ |\\n| Transformer \\_| Transformer |\\n| Zero Shot \\_| \\u96F6\\u6837\\u672C\\\n \\ |\\n| Few Shot \\_| \\u5C11\\u6837\\u672C |\\n| chat window \\_| \\u804A\\u5929\\\n \\ |\\n| context | \\u4E0A\\u4E0B\\u6587 |\\n| stock photo \\_| \\u56FE\\u5E93\\u7167\\\n \\u7247 |\\n\\n\\n \"\n selected: false\n title: 'Identify Terms '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916961837'\n parentId: '1717916955547'\n position:\n x: 117\n y: 85\n positionAbsolute:\n x: 755\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1001\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 05e03f0d-c1a9-43ab-b4c0-44b55049434d\n role: system\n text: \" You are a professional translator proficient in Simplified\\\n \\ Chinese, especially skilled in translating professional academic papers\\\n \\ into easy-to-understand popular science articles. Please help me translate\\\n \\ the following English paragraph into Chinese, in a style similar to\\\n \\ Chinese popular science articles.\\n \\nTranslate directly\\\n \\ based on the English content, maintain the original format and do not\\\n \\ omit any information. \\n \\n{{#1717916955547.item#}}\\n\\\n \\n{{#1717916961837.text#}}\\n \"\n selected: false\n title: 1st Translation\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916977413'\n parentId: '1717916955547'\n position:\n x: 421\n y: 85\n positionAbsolute:\n x: 1059\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 9e6cc050-465e-4632-abc9-411acb255a95\n role: system\n text: \"\\nBased on the results of the direct translation, point out\\\n \\ specific issues it has. 
Accurate descriptions are required, avoiding\\\n \\ vague statements, and there's no need to add content or formats that\\\n \\ were not present in the original text, including but not limited to:\\\n \\ \\n- inconsistent with chinese expression habits, clearly indicate where\\\n \\ it does not conform\\n- Clumsy sentences, specify the location, no need\\\n \\ to offer suggestions for modification, which will be fixed during free\\\n \\ translation\\n- Obscure and difficult to understand, attempts to explain\\\n \\ may be made\\n- \\u65E0\\u6F0F\\u8BD1\\uFF08\\u539F\\u2F42\\u4E2D\\u7684\\u5173\\\n \\u952E\\u8BCD\\u3001\\u53E5\\u2F26\\u3001\\u6BB5\\u843D\\u90FD\\u5E94\\u4F53\\u73B0\\\n \\u5728\\u8BD1\\u2F42\\u4E2D\\uFF09\\u3002\\n- \\u2F46\\u9519\\u8BD1\\uFF08\\u770B\\\n \\u9519\\u539F\\u2F42\\u3001\\u8BEF\\u89E3\\u539F\\u2F42\\u610F\\u601D\\u5747\\u7B97\\\n \\u9519\\u8BD1\\uFF09\\u3002\\n- \\u2F46\\u6709\\u610F\\u589E\\u52A0\\u6216\\u8005\\\n \\u5220\\u51CF\\u7684\\u539F\\u2F42\\u5185\\u5BB9\\uFF08\\u7FFB\\u8BD1\\u5E76\\u2FAE\\\n \\u521B\\u4F5C\\uFF0C\\u9700\\u5C0A\\u91CD\\u4F5C\\u8005\\u89C2 \\u70B9\\uFF1B\\u53EF\\\n \\u4EE5\\u9002\\u5F53\\u52A0\\u8BD1\\u8005\\u6CE8\\u8BF4\\u660E\\uFF09\\u3002\\n-\\\n \\ \\u8BD1\\u2F42\\u6D41\\u7545\\uFF0C\\u7B26\\u5408\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\\n \\u60EF\\u3002\\n- \\u5173\\u4E8E\\u2F08\\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6280\\\n \\u672F\\u56FE\\u4E66\\u4E2D\\u7684\\u2F08\\u540D\\u901A\\u5E38\\u4E0D\\u7FFB\\u8BD1\\\n \\uFF0C\\u4F46\\u662F\\u2F00\\u4E9B\\u4F17\\u6240 \\u5468\\u77E5\\u7684\\u2F08\\u540D\\\n \\u9700\\u2F64\\u4E2D\\u2F42\\uFF08\\u5982\\u4E54\\u5E03\\u65AF\\uFF09\\u3002\\n-\\\n \\ \\u5173\\u4E8E\\u4E66\\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6709\\u4E2D\\u2F42\\u7248\\\n \\u7684\\u56FE\\u4E66\\uFF0C\\u8BF7\\u2F64\\u4E2D\\u2F42\\u7248\\u4E66\\u540D\\uFF1B\\\n \\u2F46\\u4E2D\\u2F42\\u7248 \\u7684\\u56FE\\u4E66\\uFF0C\\u76F4\\u63A5\\u2F64\\u82F1\\\n \\u2F42\\u4E66\\u540D\\u3002\\n- \\u5173\\u4E8E\\u56FE\\u8868\\u7684\\u7FFB\\u8BD1\\\n \\u3002\\u8868\\u683C\\u4E2D\\u7684\\u8868\\u9898\\u3001\\u8868\\u5B57\\u548C\\u6CE8\\\n \\u89E3\\u7B49\\u5747\\u9700\\u7FFB\\u8BD1\\u3002\\u56FE\\u9898 \\u9700\\u8981\\u7FFB\\\n \\u8BD1\\u3002\\u754C\\u2FAF\\u622A\\u56FE\\u4E0D\\u9700\\u8981\\u7FFB\\u8BD1\\u56FE\\\n \\u5B57\\u3002\\u89E3\\u91CA\\u6027\\u56FE\\u9700\\u8981\\u6309\\u7167\\u4E2D\\u82F1\\\n \\u2F42 \\u5BF9\\u7167\\u683C\\u5F0F\\u7ED9\\u51FA\\u56FE\\u5B57\\u7FFB\\u8BD1\\u3002\\\n \\n- \\u5173\\u4E8E\\u82F1\\u2F42\\u672F\\u8BED\\u7684\\u8868\\u8FF0\\u3002\\u82F1\\\n \\u2F42\\u672F\\u8BED\\u2FB8\\u6B21\\u51FA\\u73B0\\u65F6\\uFF0C\\u5E94\\u8BE5\\u6839\\\n \\u636E\\u8BE5\\u672F\\u8BED\\u7684 \\u6D41\\u2F8F\\u60C5\\u51B5\\uFF0C\\u4F18\\u5148\\\n \\u4F7F\\u2F64\\u7B80\\u5199\\u5F62\\u5F0F\\uFF0C\\u5E76\\u5728\\u5176\\u540E\\u4F7F\\\n \\u2F64\\u62EC\\u53F7\\u52A0\\u82F1\\u2F42\\u3001\\u4E2D\\u2F42 \\u5168\\u79F0\\u6CE8\\\n \\u89E3\\uFF0C\\u683C\\u5F0F\\u4E3A\\uFF08\\u4E3E\\u4F8B\\uFF09\\uFF1AHTML\\uFF08\\\n Hypertext Markup Language\\uFF0C\\u8D85\\u2F42\\u672C\\u6807\\u8BC6\\u8BED\\u2F94\\\n \\uFF09\\u3002\\u7136\\u540E\\u5728\\u4E0B\\u2F42\\u4E2D\\u76F4\\u63A5\\u4F7F\\u2F64\\\n \\u7B80\\u5199\\u5F62 \\u5F0F\\u3002\\u5F53\\u7136\\uFF0C\\u5FC5\\u8981\\u65F6\\u4E5F\\\n \\u53EF\\u4EE5\\u6839\\u636E\\u8BED\\u5883\\u4F7F\\u2F64\\u4E2D\\u3001\\u82F1\\u2F42\\\n \\u5168\\u79F0\\u3002\\n- \\u5173\\u4E8E\\u4EE3\\u7801\\u6E05\\u5355\\u548C\\u4EE3\\\n \\u7801\\u2F5A\\u6BB5\\u3002\\u539F\\u4E66\\u4E2D\\u5305\\u542B\\u7684\\u7A0B\\u5E8F\\\n 
\\u4EE3\\u7801\\u4E0D\\u8981\\u6C42\\u8BD1\\u8005\\u5F55 \\u2F0A\\uFF0C\\u4F46\\u5E94\\\n \\u8BE5\\u4F7F\\u2F64\\u201C\\u539F\\u4E66P99\\u2EDA\\u4EE3\\u78011\\u201D\\uFF08\\\n \\u5373\\u539F\\u4E66\\u7B2C99\\u2EDA\\u4E2D\\u7684\\u7B2C\\u2F00\\u6BB5\\u4EE3 \\u7801\\\n \\uFF09\\u7684\\u683C\\u5F0F\\u4F5C\\u51FA\\u6807\\u6CE8\\u3002\\u540C\\u65F6\\uFF0C\\\n \\u8BD1\\u8005\\u5E94\\u8BE5\\u5728\\u6709\\u6761\\u4EF6\\u7684\\u60C5\\u51B5\\u4E0B\\\n \\u68C0\\u6838\\u4EE3 \\u7801\\u7684\\u6B63\\u786E\\u6027\\uFF0C\\u5BF9\\u53D1\\u73B0\\\n \\u7684\\u9519\\u8BEF\\u4EE5\\u8BD1\\u8005\\u6CE8\\u5F62\\u5F0F\\u8BF4\\u660E\\u3002\\\n \\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E2D\\u7684\\u6CE8 \\u91CA\\u8981\\u6C42\\u7FFB\\u8BD1\\\n \\uFF0C\\u5982\\u679C\\u8BD1\\u7A3F\\u4E2D\\u6CA1\\u6709\\u4EE3\\u7801\\uFF0C\\u5219\\\n \\u5E94\\u8BE5\\u4EE5\\u2F00\\u53E5\\u82F1\\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09 \\u2F00\\\n \\u53E5\\u4E2D\\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09\\u7684\\u5F62\\u5F0F\\u7ED9\\u51FA\\\n \\u6CE8\\u91CA\\u3002\\n- \\u5173\\u4E8E\\u6807\\u70B9\\u7B26\\u53F7\\u3002\\u8BD1\\\n \\u7A3F\\u4E2D\\u7684\\u6807\\u70B9\\u7B26\\u53F7\\u8981\\u9075\\u5FAA\\u4E2D\\u2F42\\\n \\u8868\\u8FBE\\u4E60\\u60EF\\u548C\\u4E2D\\u2F42\\u6807 \\u70B9\\u7B26\\u53F7\\u7684\\\n \\u4F7F\\u2F64\\u4E60\\u60EF\\uFF0C\\u4E0D\\u80FD\\u7167\\u642C\\u539F\\u2F42\\u7684\\\n \\u6807\\u70B9\\u7B26\\u53F7\\u3002\\n\\n\\n{{#1717916977413.text#}}\\n\\\n \\n{{#1717916955547.item#}}\\n\\n{{#1717916961837.text#}}\\n\\\n \"\n selected: false\n title: 'Problems '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916984996'\n parentId: '1717916955547'\n position:\n x: 725\n y: 85\n positionAbsolute:\n x: 1363\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 4d7ae758-2d7b-4404-ad9f-d6748ee64439\n role: system\n text: \"\\nBased on the results of the direct translation in the first\\\n \\ step and the problems identified in the second step, re-translate to\\\n \\ achieve a meaning-based interpretation. Ensure the original intent of\\\n \\ the content is preserved while making it easier to understand and more\\\n \\ in line with Chinese expression habits. All the while maintaining the\\\n \\ original format unchanged. 
\\n\\n\\n- inconsistent with chinese\\\n \\ expression habits, clearly indicate where it does not conform\\n- Clumsy\\\n \\ sentences, specify the location, no need to offer suggestions for modification,\\\n \\ which will be fixed during free translation\\n- Obscure and difficult\\\n \\ to understand, attempts to explain may be made\\n- \\u65E0\\u6F0F\\u8BD1\\\n \\uFF08\\u539F\\u2F42\\u4E2D\\u7684\\u5173\\u952E\\u8BCD\\u3001\\u53E5\\u2F26\\u3001\\\n \\u6BB5\\u843D\\u90FD\\u5E94\\u4F53\\u73B0\\u5728\\u8BD1\\u2F42\\u4E2D\\uFF09\\u3002\\\n \\n- \\u2F46\\u9519\\u8BD1\\uFF08\\u770B\\u9519\\u539F\\u2F42\\u3001\\u8BEF\\u89E3\\\n \\u539F\\u2F42\\u610F\\u601D\\u5747\\u7B97\\u9519\\u8BD1\\uFF09\\u3002\\n- \\u2F46\\\n \\u6709\\u610F\\u589E\\u52A0\\u6216\\u8005\\u5220\\u51CF\\u7684\\u539F\\u2F42\\u5185\\\n \\u5BB9\\uFF08\\u7FFB\\u8BD1\\u5E76\\u2FAE\\u521B\\u4F5C\\uFF0C\\u9700\\u5C0A\\u91CD\\\n \\u4F5C\\u8005\\u89C2 \\u70B9\\uFF1B\\u53EF\\u4EE5\\u9002\\u5F53\\u52A0\\u8BD1\\u8005\\\n \\u6CE8\\u8BF4\\u660E\\uFF09\\u3002\\n- \\u8BD1\\u2F42\\u6D41\\u7545\\uFF0C\\u7B26\\\n \\u5408\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\u60EF\\u3002\\n- \\u5173\\u4E8E\\u2F08\\\n \\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6280\\u672F\\u56FE\\u4E66\\u4E2D\\u7684\\u2F08\\\n \\u540D\\u901A\\u5E38\\u4E0D\\u7FFB\\u8BD1\\uFF0C\\u4F46\\u662F\\u2F00\\u4E9B\\u4F17\\\n \\u6240 \\u5468\\u77E5\\u7684\\u2F08\\u540D\\u9700\\u2F64\\u4E2D\\u2F42\\uFF08\\u5982\\\n \\u4E54\\u5E03\\u65AF\\uFF09\\u3002\\n- \\u5173\\u4E8E\\u4E66\\u540D\\u7684\\u7FFB\\\n \\u8BD1\\u3002\\u6709\\u4E2D\\u2F42\\u7248\\u7684\\u56FE\\u4E66\\uFF0C\\u8BF7\\u2F64\\\n \\u4E2D\\u2F42\\u7248\\u4E66\\u540D\\uFF1B\\u2F46\\u4E2D\\u2F42\\u7248 \\u7684\\u56FE\\\n \\u4E66\\uFF0C\\u76F4\\u63A5\\u2F64\\u82F1\\u2F42\\u4E66\\u540D\\u3002\\n- \\u5173\\\n \\u4E8E\\u56FE\\u8868\\u7684\\u7FFB\\u8BD1\\u3002\\u8868\\u683C\\u4E2D\\u7684\\u8868\\\n \\u9898\\u3001\\u8868\\u5B57\\u548C\\u6CE8\\u89E3\\u7B49\\u5747\\u9700\\u7FFB\\u8BD1\\\n \\u3002\\u56FE\\u9898 \\u9700\\u8981\\u7FFB\\u8BD1\\u3002\\u754C\\u2FAF\\u622A\\u56FE\\\n \\u4E0D\\u9700\\u8981\\u7FFB\\u8BD1\\u56FE\\u5B57\\u3002\\u89E3\\u91CA\\u6027\\u56FE\\\n \\u9700\\u8981\\u6309\\u7167\\u4E2D\\u82F1\\u2F42 \\u5BF9\\u7167\\u683C\\u5F0F\\u7ED9\\\n \\u51FA\\u56FE\\u5B57\\u7FFB\\u8BD1\\u3002\\n- \\u5173\\u4E8E\\u82F1\\u2F42\\u672F\\\n \\u8BED\\u7684\\u8868\\u8FF0\\u3002\\u82F1\\u2F42\\u672F\\u8BED\\u2FB8\\u6B21\\u51FA\\\n \\u73B0\\u65F6\\uFF0C\\u5E94\\u8BE5\\u6839\\u636E\\u8BE5\\u672F\\u8BED\\u7684 \\u6D41\\\n \\u2F8F\\u60C5\\u51B5\\uFF0C\\u4F18\\u5148\\u4F7F\\u2F64\\u7B80\\u5199\\u5F62\\u5F0F\\\n \\uFF0C\\u5E76\\u5728\\u5176\\u540E\\u4F7F\\u2F64\\u62EC\\u53F7\\u52A0\\u82F1\\u2F42\\\n \\u3001\\u4E2D\\u2F42 \\u5168\\u79F0\\u6CE8\\u89E3\\uFF0C\\u683C\\u5F0F\\u4E3A\\uFF08\\\n \\u4E3E\\u4F8B\\uFF09\\uFF1AHTML\\uFF08Hypertext Markup Language\\uFF0C\\u8D85\\\n \\u2F42\\u672C\\u6807\\u8BC6\\u8BED\\u2F94\\uFF09\\u3002\\u7136\\u540E\\u5728\\u4E0B\\\n \\u2F42\\u4E2D\\u76F4\\u63A5\\u4F7F\\u2F64\\u7B80\\u5199\\u5F62 \\u5F0F\\u3002\\u5F53\\\n \\u7136\\uFF0C\\u5FC5\\u8981\\u65F6\\u4E5F\\u53EF\\u4EE5\\u6839\\u636E\\u8BED\\u5883\\\n \\u4F7F\\u2F64\\u4E2D\\u3001\\u82F1\\u2F42\\u5168\\u79F0\\u3002\\n- \\u5173\\u4E8E\\\n \\u4EE3\\u7801\\u6E05\\u5355\\u548C\\u4EE3\\u7801\\u2F5A\\u6BB5\\u3002\\u539F\\u4E66\\\n \\u4E2D\\u5305\\u542B\\u7684\\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E0D\\u8981\\u6C42\\u8BD1\\\n \\u8005\\u5F55 \\u2F0A\\uFF0C\\u4F46\\u5E94\\u8BE5\\u4F7F\\u2F64\\u201C\\u539F\\u4E66\\\n P99\\u2EDA\\u4EE3\\u78011\\u201D\\uFF08\\u5373\\u539F\\u4E66\\u7B2C99\\u2EDA\\u4E2D\\\n 
\\u7684\\u7B2C\\u2F00\\u6BB5\\u4EE3 \\u7801\\uFF09\\u7684\\u683C\\u5F0F\\u4F5C\\u51FA\\\n \\u6807\\u6CE8\\u3002\\u540C\\u65F6\\uFF0C\\u8BD1\\u8005\\u5E94\\u8BE5\\u5728\\u6709\\\n \\u6761\\u4EF6\\u7684\\u60C5\\u51B5\\u4E0B\\u68C0\\u6838\\u4EE3 \\u7801\\u7684\\u6B63\\\n \\u786E\\u6027\\uFF0C\\u5BF9\\u53D1\\u73B0\\u7684\\u9519\\u8BEF\\u4EE5\\u8BD1\\u8005\\\n \\u6CE8\\u5F62\\u5F0F\\u8BF4\\u660E\\u3002\\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E2D\\u7684\\\n \\u6CE8 \\u91CA\\u8981\\u6C42\\u7FFB\\u8BD1\\uFF0C\\u5982\\u679C\\u8BD1\\u7A3F\\u4E2D\\\n \\u6CA1\\u6709\\u4EE3\\u7801\\uFF0C\\u5219\\u5E94\\u8BE5\\u4EE5\\u2F00\\u53E5\\u82F1\\\n \\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09 \\u2F00\\u53E5\\u4E2D\\u2F42\\uFF08\\u6CE8\\u91CA\\\n \\uFF09\\u7684\\u5F62\\u5F0F\\u7ED9\\u51FA\\u6CE8\\u91CA\\u3002\\n- \\u5173\\u4E8E\\\n \\u6807\\u70B9\\u7B26\\u53F7\\u3002\\u8BD1\\u7A3F\\u4E2D\\u7684\\u6807\\u70B9\\u7B26\\\n \\u53F7\\u8981\\u9075\\u5FAA\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\u60EF\\u548C\\u4E2D\\\n \\u2F42\\u6807 \\u70B9\\u7B26\\u53F7\\u7684\\u4F7F\\u2F64\\u4E60\\u60EF\\uFF0C\\u4E0D\\\n \\u80FD\\u7167\\u642C\\u539F\\u2F42\\u7684\\u6807\\u70B9\\u7B26\\u53F7\\u3002\\n\\n\\\n \\n{{#1717916977413.text#}}\\n\\n{{#1717916984996.text#}}\\n\\n{{#1711067409646.input_text#}}\\n\\\n \\n{{#1717916961837.text#}}\\n \"\n selected: false\n title: '2nd Translation '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916991709'\n parentId: '1717916955547'\n position:\n x: 1029\n y: 85\n positionAbsolute:\n x: 1667\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: 'Combine all chunks of translation. '\n selected: false\n template: '{{ translated_text | join('' '') }}'\n title: Template\n type: template-transform\n variables:\n - value_selector:\n - '1717916955547'\n - output\n variable: translated_text\n height: 83\n id: '1717917057450'\n position:\n x: 1987\n y: 301.5\n positionAbsolute:\n x: 1987\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n author: Dify\n desc: ''\n height: 186\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Code\n node separates the input_text into chunks with length of token_limit. Each\n chunk overlaps with the next to make sure the texts are consistent. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"The\n code node outputs an array of segmented texts of input_text. 
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 340\n height: 186\n id: '1718990593686'\n position:\n x: 259.3026056936437\n y: 451.6924912936374\n positionAbsolute:\n x: 259.3026056936437\n y: 451.6924912936374\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 340\n - data:\n author: Dify\n desc: ''\n height: 128\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Iterate\n through all the elements in output of the code node and translate each chunk\n using a three steps translation workflow. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 355\n height: 128\n id: '1718991836605'\n position:\n x: 764.3891977435923\n y: 530.8917807505335\n positionAbsolute:\n x: 764.3891977435923\n y: 530.8917807505335\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 355\n - data:\n author: Dify\n desc: ''\n height: 126\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Avoid\n using a high token_limit, LLM''s performance decreases with longer context\n length for gpt-4o. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Recommend\n to use less than or equal to 1000 tokens. 
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: yellow\n title: ''\n type: ''\n width: 351\n height: 126\n id: '1718991882984'\n position:\n x: 304.49115824454367\n y: 148.4042994607805\n positionAbsolute:\n x: 304.49115824454367\n y: 148.4042994607805\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 351\n viewport:\n x: 335.92505067152274\n y: 18.806553508850584\n zoom: 0.8705505632961259\n","icon":"\ud83e\udd16","icon_background":"#FFEAD5","id":"98b87f88-bd22-4d86-8b74-86beba5e0ed4","mode":"workflow","name":"Book Translation "}, + "e9d92058-7d20-4904-892f-75d90bef7587":{"export_data":"app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: advanced-chat\n name: 'Automated Email Reply '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n isInIteration: false\n sourceType: code\n targetType: iteration\n id: 1716909112104-source-1716909114582-target\n source: '1716909112104'\n sourceHandle: source\n target: '1716909114582'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: iteration\n targetType: template-transform\n id: 1716909114582-source-1716913435742-target\n source: '1716909114582'\n sourceHandle: source\n target: '1716913435742'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: template-transform\n targetType: answer\n id: 1716913435742-source-1716806267180-target\n source: '1716913435742'\n sourceHandle: source\n target: '1716806267180'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: start\n targetType: tool\n id: 1716800588219-source-1716946869294-target\n source: '1716800588219'\n sourceHandle: source\n target: '1716946869294'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: tool\n targetType: code\n id: 1716946869294-source-1716909112104-target\n source: '1716946869294'\n sourceHandle: source\n target: '1716909112104'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: tool\n targetType: code\n id: 1716946889408-source-1716909122343-target\n source: '1716946889408'\n sourceHandle: source\n target: '1716909122343'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: code\n targetType: code\n id: 1716909122343-source-1716951357236-target\n source: '1716909122343'\n sourceHandle: source\n target: '1716951357236'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: code\n targetType: llm\n id: 1716951357236-source-1716913272656-target\n source: '1716951357236'\n sourceHandle: source\n target: '1716913272656'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: llm\n id: 
1716951236700-source-1716951159073-target\n source: '1716951236700'\n sourceHandle: source\n target: '1716951159073'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: template-transform\n id: 1716951159073-source-1716952228079-target\n source: '1716951159073'\n sourceHandle: source\n target: '1716952228079'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: tool\n id: 1716952228079-source-1716952912103-target\n source: '1716952228079'\n sourceHandle: source\n target: '1716952912103'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: question-classifier\n id: 1716913272656-source-1716960721611-target\n source: '1716913272656'\n sourceHandle: source\n target: '1716960721611'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: llm\n id: 1716960721611-1-1716909125498-target\n source: '1716960721611'\n sourceHandle: '1'\n target: '1716909125498'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: llm\n id: 1716960721611-2-1716960728136-target\n source: '1716960721611'\n sourceHandle: '2'\n target: '1716960728136'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: variable-aggregator\n id: 1716909125498-source-1716960791399-target\n source: '1716909125498'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: variable-aggregator\n targetType: template-transform\n id: 1716960791399-source-1716951236700-target\n source: '1716960791399'\n sourceHandle: source\n target: '1716951236700'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: question-classifier\n targetType: template-transform\n id: 1716960721611-1716960736883-1716960834468-target\n source: '1716960721611'\n sourceHandle: '1716960736883'\n target: '1716960834468'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: llm\n targetType: variable-aggregator\n id: 1716960728136-source-1716960791399-target\n source: '1716960728136'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1716909114582'\n sourceType: template-transform\n targetType: variable-aggregator\n id: 1716960834468-source-1716960791399-target\n source: '1716960834468'\n sourceHandle: source\n target: '1716960791399'\n targetHandle: target\n type: custom\n zIndex: 1002\n nodes:\n - data:\n desc: ''\n selected: false\n title: Start\n type: start\n variables:\n - label: Your Email\n max_length: 256\n options: []\n required: true\n type: text-input\n variable: email\n - label: Maximum Number of Emails you want to retrieve\n max_length: 256\n options: []\n required: true\n type: number\n variable: maxResults\n height: 115\n id: '1716800588219'\n position:\n x: 30\n y: 445\n 
positionAbsolute:\n x: 30\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n answer: '{{#1716913435742.output#}}'\n desc: ''\n selected: false\n title: Direct Reply\n type: answer\n variables: []\n height: 106\n id: '1716806267180'\n position:\n x: 4700\n y: 445\n positionAbsolute:\n x: 4700\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n code: \"def main(message: str) -> dict:\\n import json\\n \\n # Parse\\\n \\ the JSON string\\n parsed_data = json.loads(message)\\n \\n # Extract\\\n \\ all the \\\"id\\\" values\\n ids = [msg['id'] for msg in parsed_data['messages']]\\n\\\n \\ \\n return {\\n \\\"result\\\": ids\\n }\"\n code_language: python3\n desc: ''\n outputs:\n result:\n children: null\n type: array[string]\n selected: false\n title: 'Code: Extract Email ID'\n type: code\n variables:\n - value_selector:\n - '1716946869294'\n - text\n variable: message\n height: 53\n id: '1716909112104'\n position:\n x: 638\n y: 445\n positionAbsolute:\n x: 638\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n height: 490\n iterator_selector:\n - '1716909112104'\n - result\n output_selector:\n - '1716909125498'\n - text\n output_type: array[string]\n selected: false\n startNodeType: tool\n start_node_id: '1716946889408'\n title: 'Iteration '\n type: iteration\n width: 3393.7520359289056\n height: 490\n id: '1716909114582'\n position:\n x: 942\n y: 445\n positionAbsolute:\n x: 942\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 3394\n zIndex: 1\n - data:\n desc: ''\n isInIteration: true\n isIterationStart: true\n iteration_id: '1716909114582'\n provider_id: e64b4c7f-2795-499c-8d11-a971a7d57fc9\n provider_name: List and Get Gmail\n provider_type: api\n selected: false\n title: getMessage\n tool_configurations: {}\n tool_label: getMessage\n tool_name: getMessage\n tool_parameters:\n format:\n type: mixed\n value: full\n id:\n type: mixed\n value: '{{#1716909114582.item#}}'\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n extent: parent\n height: 53\n id: '1716946889408'\n parentId: '1716909114582'\n position:\n x: 117\n y: 85\n positionAbsolute:\n x: 1059\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1001\n - data:\n code: \"\\ndef main(email_json: dict) -> dict:\\n import json \\n email_dict\\\n \\ = json.loads(email_json)\\n base64_data = email_dict['payload']['parts'][0]['body']['data']\\n\\\n \\n return {\\n \\\"result\\\": base64_data, \\n }\\n\"\n code_language: python3\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n outputs:\n result:\n children: null\n type: string\n selected: false\n title: 'Code: Extract Email Body'\n type: code\n variables:\n - value_selector:\n - '1716946889408'\n - text\n variable: email_json\n extent: parent\n height: 53\n id: '1716909122343'\n parentId: '1716909114582'\n position:\n x: 421\n y: 85\n positionAbsolute:\n x: 1363\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Generate reply. 
'\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 982014aa-702b-4d7c-ae1f-08dbceb6e930\n role: system\n text: \" \\nRespond to the emails. \\n\\n{{#1716913272656.text#}}\\n\\\n \"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 127\n id: '1716909125498'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 85\n positionAbsolute:\n x: 2567\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: fd8de569-c099-4320-955b-61aa4b054789\n role: system\n text: \"\\nYou need to transform the input data (in base64 encoding)\\\n \\ to text. Input base64. Output text. \\n\\n{{#1716909122343.result#}}\\n\\\n \"\n selected: false\n title: 'Base64 Decoder '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: false\n extent: parent\n height: 97\n id: '1716913272656'\n parentId: '1716909114582'\n position:\n x: 1025\n y: 85\n positionAbsolute:\n x: 1967\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 | join(\"\\n\\n -------------------------\\n\\n\") }}'\n title: 'Template '\n type: template-transform\n variables:\n - value_selector:\n - '1716909114582'\n - output\n variable: arg1\n height: 53\n id: '1716913435742'\n position:\n x: 4396\n y: 445\n positionAbsolute:\n x: 4396\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n provider_id: e64b4c7f-2795-499c-8d11-a971a7d57fc9\n provider_name: List and Get Gmail\n provider_type: api\n selected: false\n title: listMessages\n tool_configurations: {}\n tool_label: listMessages\n tool_name: listMessages\n tool_parameters:\n maxResults:\n type: variable\n value:\n - '1716800588219'\n - maxResults\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n height: 53\n id: '1716946869294'\n position:\n x: 334\n y: 445\n positionAbsolute:\n x: 334\n y: 445\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: b7fd0ec5-864a-42c6-9d04-a1958bd4fc0d\n role: system\n text: \"\\nYou need to encode the input data from text to base64. Input\\\n \\ text. Output base64 encoding. 
Output nothing other than base64 encoding.\\\n \\ \\n\\n{{#1716951236700.output#}}\\n \"\n selected: false\n title: Base64 Encoder\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1716951159073'\n parentId: '1716909114582'\n position:\n x: 2525.7520359289056\n y: 85\n positionAbsolute:\n x: 3467.7520359289056\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: Generate MIME email template\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: \"Content-Type: text/plain; charset=\\\"utf-8\\\"\\r\\nContent-Transfer-Encoding:\\\n \\ 7bit\\r\\nMIME-Version: 1.0\\r\\nTo: {{ emailMetadata.recipientEmail }} #\\\n \\ xiaoyi@dify.ai\\r\\nFrom: {{ emailMetadata.senderEmail }} # sxy.hj156@gmail.com\\r\\\n \\nSubject: Re: {{ emailMetadata.subject }} \\r\\n\\r\\n{{ text }}\\r\\n\"\n title: 'Template: Reply Email'\n type: template-transform\n variables:\n - value_selector:\n - '1716951357236'\n - result\n variable: emailMetadata\n - value_selector:\n - '1716960791399'\n - output\n variable: text\n extent: parent\n height: 83\n id: '1716951236700'\n parentId: '1716909114582'\n position:\n x: 2231.269960149744\n y: 85\n positionAbsolute:\n x: 3173.269960149744\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n code: \"def main(email_json: dict) -> dict:\\n import json\\n if isinstance(email_json,\\\n \\ str): \\n email_json = json.loads(email_json)\\n\\n subject = None\\n\\\n \\ recipient_email = None \\n sender_email = None\\n \\n headers\\\n \\ = email_json['payload']['headers']\\n for header in headers:\\n \\\n \\ if header['name'] == 'Subject':\\n subject = header['value']\\n\\\n \\ elif header['name'] == 'To':\\n recipient_email = header['value']\\n\\\n \\ elif header['name'] == 'From':\\n sender_email = header['value']\\n\\\n \\n return {\\n \\\"result\\\": [subject, recipient_email, sender_email]\\n\\\n \\ }\\n\"\n code_language: python3\n desc: \"Recipient, Sender, Subject\\uFF0COutput Array[String]\"\n isInIteration: true\n iteration_id: '1716909114582'\n outputs:\n result:\n children: null\n type: array[string]\n selected: false\n title: Extract Email Metadata\n type: code\n variables:\n - value_selector:\n - '1716946889408'\n - text\n variable: email_json\n extent: parent\n height: 101\n id: '1716951357236'\n parentId: '1716909114582'\n position:\n x: 725\n y: 85\n positionAbsolute:\n x: 1667\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: '{\"raw\": \"{{ encoded_message }}\"}'\n title: \"Template\\uFF1AEmail Request Body\"\n type: template-transform\n variables:\n - value_selector:\n - '1716951159073'\n - text\n variable: encoded_message\n extent: parent\n height: 53\n id: '1716952228079'\n parentId: '1716909114582'\n position:\n x: 2828.4325280181324\n y: 86.31950791077293\n positionAbsolute:\n x: 3770.4325280181324\n y: 531.3195079107729\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n provider_id: 038963aa-43c8-47fc-be4b-0255c19959c1\n provider_name: Draft Gmail\n provider_type: api\n selected: false\n title: createDraft\n 
tool_configurations: {}\n tool_label: createDraft\n tool_name: createDraft\n tool_parameters:\n message:\n type: mixed\n value: '{{#1716952228079.output#}}'\n userId:\n type: mixed\n value: '{{#1716800588219.email#}}'\n type: tool\n extent: parent\n height: 53\n id: '1716952912103'\n parentId: '1716909114582'\n position:\n x: 3133.7520359289056\n y: 85\n positionAbsolute:\n x: 4075.7520359289056\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n classes:\n - id: '1'\n name: 'Technical questions, related to product '\n - id: '2'\n name: Unrelated to technicals, non technical\n - id: '1716960736883'\n name: Other questions\n desc: ''\n instructions: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1716800588219'\n - sys.query\n selected: false\n title: Question Classifier\n topics: []\n type: question-classifier\n extent: parent\n height: 255\n id: '1716960721611'\n parentId: '1716909114582'\n position:\n x: 1325\n y: 85\n positionAbsolute:\n x: 2267\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - id: a639bbf8-bc58-42a2-b477-6748e80ecda2\n role: system\n text: \" \\nRespond to the emails. \\n\\n{{#1716913272656.text#}}\\n\\\n \"\n selected: false\n title: 'LLM - Non technical '\n type: llm\n variables: []\n vision:\n enabled: false\n extent: parent\n height: 97\n id: '1716960728136'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 251\n positionAbsolute:\n x: 2567\n y: 696\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: ''\n isInIteration: true\n iteration_id: '1716909114582'\n output_type: string\n selected: false\n title: Variable Aggregator\n type: variable-aggregator\n variables:\n - - '1716909125498'\n - text\n - - '1716960728136'\n - text\n - - '1716960834468'\n - output\n extent: parent\n height: 164\n id: '1716960791399'\n parentId: '1716909114582'\n position:\n x: 1931.2699601497438\n y: 85\n positionAbsolute:\n x: 2873.269960149744\n y: 530\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: Other questions\n isInIteration: true\n iteration_id: '1716909114582'\n selected: false\n template: 'Sorry, I cannot answer that. This is outside my capabilities. 
'\n title: 'Direct Reply '\n type: template-transform\n variables: []\n extent: parent\n height: 83\n id: '1716960834468'\n parentId: '1716909114582'\n position:\n x: 1625\n y: 385.57142857142856\n positionAbsolute:\n x: 2567\n y: 830.5714285714286\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n author: Dify\n desc: ''\n height: 153\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":3,\"mode\":\"normal\",\"style\":\"font-size:\n 14px;\",\"text\":\"OpenAPI-Swagger for all custom tools: \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":3},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"openapi:\n 3.0.0\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"info:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" title:\n Gmail API\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n OpenAPI schema for Gmail API methods `users.messages.get`, `users.messages.list`,\n and `users.drafts.create`.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" version:\n 1.0.0\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"servers:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n url: https://gmail.googleapis.com\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Gmail API Server\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"paths:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
/gmail/v1/users/{userId}/messages/{id}:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" get:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n Get a message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Retrieves a specific message by ID.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n getMessage\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. 
The special value `me` can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: id\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the message to retrieve.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: format\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n query\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n false\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" enum:\n [full, metadata, minimal, raw]\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" default:\n full\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
description:\n The format to return the message in.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" labelIds:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" snippet:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" historyId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" internalDate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" payload:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" sizeEstimate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" raw:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''403'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Forbidden\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''404'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Not Found\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" /gmail/v1/users/{userId}/messages:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" get:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n List messages.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Lists the messages in the user''s mailbox.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n listMessages\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: 
userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. The special value `me` can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: maxResults\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n query\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" format:\n int32\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" default:\n 100\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Maximum number of messages to return.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" messages:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" nextPageToken:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" resultSizeEstimate:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n integer\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" /gmail/v1/users/{userId}/drafts:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" post:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" summary:\n Creates a new draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" operationId:\n createDraft\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
tags:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n Drafts\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" parameters:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n name: userId\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" in:\n path\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The user''s email address. The special value \\\"me\\\" can be used to indicate\n the authenticated user.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" requestBody:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" required:\n true\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" message:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" raw:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The entire email message in an RFC 2822 formatted and base64url encoded\n string.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" responses:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''200'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Successful response with the created draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" content:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" application/json:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" schema:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n 
object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The immutable ID of the draft.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" message:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n object\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" properties:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" id:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The immutable ID of the message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" threadId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the thread the message belongs to.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" 
labelIds:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n array\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" items:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" snippet:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n A short part of the message text.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" historyId:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n string\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n The ID of the last history record that modified this message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''400'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Bad Request - The request is invalid.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''401'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Unauthorized - Authentication is 
required.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''403'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Forbidden - The user does not have permission to create drafts.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''404'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Not Found - The specified user does not exist.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" ''500'':\",\"type\":\"text\",\"version\":1}],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" description:\n Internal Server Error - An error occurred on the server.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"components:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" securitySchemes:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" OAuth2:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" type:\n oauth2\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" flows:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" authorizationCode:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" authorizationUrl:\n 
https://accounts.google.com/o/oauth2/auth\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" tokenUrl:\n https://oauth2.googleapis.com/token\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" scopes:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://mail.google.com/:\n All access to Gmail.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://www.googleapis.com/auth/gmail.compose:\n Send email on your behalf.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" https://www.googleapis.com/auth/gmail.modify:\n Modify your email.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"security:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n OAuth2:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://mail.google.com/\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://www.googleapis.com/auth/gmail.compose\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\" -\n https://www.googleapis.com/auth/gmail.modify\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: yellow\n title: ''\n type: ''\n width: 367\n height: 153\n id: '1718992681576'\n position:\n x: 321.9646831030669\n y: 538.1642616264143\n positionAbsolute:\n x: 321.9646831030669\n y: 538.1642616264143\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 367\n - data:\n author: Dify\n desc: ''\n height: 158\n selected: false\n showAuthor: true\n text: 
'{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Replace\n custom tools after added this template to your own workspace. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Fill\n in \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"your\n email \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"and\n the \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"\",\"text\":\"maximum\n number of results you want to retrieve from your inbox \",\"type\":\"text\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"to\n get started. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 287\n height: 158\n id: '1718992805687'\n position:\n x: 18.571428571428356\n y: 237.80887395992687\n positionAbsolute:\n x: 18.571428571428356\n y: 237.80887395992687\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 287\n - data:\n author: Dify\n desc: ''\n height: 375\n selected: true\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":1,\"mode\":\"normal\",\"style\":\"font-size:\n 16px;\",\"text\":\"Steps within Iteraction node: \",\"type\":\"text\",\"version\":1},{\"type\":\"linebreak\",\"version\":1},{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"1.\n getMessage: This step retrieves the incoming email message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"2.\n Code: Extract Email Body: Custom code is executed to extract the body of\n the email from the retrieved message.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"3.\n Extract Email Metadata: Extracts metadata from the email, such as the recipient,\n sender, subject, and other relevant information.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"4.\n Base64 Decoder: Decodes the email content from Base64 encoding.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"5.\n Question Classifier (gpt-3.5-turbo): Uses a GPT-3.5-turbo model to classify\n the 
email content into different categories. For each classified question,\n the workflow uses a GPT-4.0 model to generate an appropriate reply:\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"6.\n Template: Reply Email: Uses a template to generate a MIME email format for\n the reply.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"7.\n Base64 Encoder: Encodes the generated reply email content back to Base64.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"8.\n Template: Email Request: Prepares the email request using a template.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"9.\n createDraft: Creates a draft of the email reply.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"This\n workflow automates the process of reading, classifying, and\n drafting replies to incoming emails, leveraging advanced language models\n to generate contextually appropriate responses.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 640\n height: 375\n id: '1718993366836'\n position:\n x: 966.7525290975368\n y: 971.80362905854\n positionAbsolute:\n x: 966.7525290975368\n y: 971.80362905854\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 640\n - data:\n author: Dify\n desc: ''\n height: 400\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":3,\"mode\":\"normal\",\"style\":\"font-size:\n 16px;\",\"text\":\"Preparation\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":3},{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Enable\n Gmail API in Google Cloud Console\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":1},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Configure\n OAuth Client ID, OAuth Client Secrets, and OAuth Consent Screen for the\n Web Application in Google Cloud 
Console\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":2},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Use\n Postman to authorize and obtain the OAuth Access Token (Google''s Access\n Token will expire after 1 hour and cannot be used for a long time)\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"listitem\",\"version\":1,\"value\":3}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"list\",\"version\":1,\"listType\":\"bullet\",\"start\":1,\"tag\":\"ul\"},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Users\n who want to try building an AI auto-reply email can refer to this document\n to use Postman (Postman.com) to obtain all the above keys: https://blog.postman.com/how-to-access-google-apis-using-oauth-in-postman/.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Developers\n who want to use Google OAuth to call the Gmail API to develop corresponding\n plugins can refer to this official document: \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"https://developers.google.com/identity/protocols/oauth2/web-server.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"At\n this stage, it is still a bit difficult to reproduce this example within\n the Dify platform. 
If you have development capabilities, developing the\n corresponding plugin externally and using an external database to automatically\n read and write the user''s Access Token and write the Refresh Token would\n be a better choice.\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 608\n height: 400\n id: '1718993557447'\n position:\n x: 354.0157230378119\n y: -1.2732157979666\n positionAbsolute:\n x: 354.0157230378119\n y: -1.2732157979666\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 608\n viewport:\n x: 147.09446825757777\n y: 101.03530130020579\n zoom: 0.9548416039104178\n","icon":"\ud83e\udd16","icon_background":"#FFEAD5","id":"e9d92058-7d20-4904-892f-75d90bef7587","mode":"advanced-chat","name":"Automated Email Reply "}, + "98b87f88-bd22-4d86-8b74-86beba5e0ed4":{"export_data":"app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: workflow\n name: 'Book Translation '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n isInIteration: false\n sourceType: start\n targetType: code\n id: 1711067409646-source-1717916867969-target\n source: '1711067409646'\n sourceHandle: source\n target: '1717916867969'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: code\n targetType: iteration\n id: 1717916867969-source-1717916955547-target\n source: '1717916867969'\n sourceHandle: source\n target: '1717916955547'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916961837-source-1717916977413-target\n source: '1717916961837'\n sourceHandle: source\n target: '1717916977413'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916977413-source-1717916984996-target\n source: '1717916977413'\n sourceHandle: source\n target: '1717916984996'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: true\n iteration_id: '1717916955547'\n sourceType: llm\n targetType: llm\n id: 1717916984996-source-1717916991709-target\n source: '1717916984996'\n sourceHandle: source\n target: '1717916991709'\n targetHandle: target\n type: custom\n zIndex: 1002\n - data:\n isInIteration: false\n sourceType: iteration\n targetType: template-transform\n id: 1717916955547-source-1717917057450-target\n source: '1717916955547'\n sourceHandle: source\n target: '1717917057450'\n targetHandle: target\n type: custom\n zIndex: 0\n - data:\n isInIteration: false\n sourceType: template-transform\n targetType: end\n id: 1717917057450-source-1711068257370-target\n source: '1717917057450'\n sourceHandle: source\n target: '1711068257370'\n targetHandle: target\n type: custom\n zIndex: 0\n nodes:\n - data:\n desc: ''\n selected: false\n title: Start\n type: start\n variables:\n - label: Input 
Text\n max_length: null\n options: []\n required: true\n type: paragraph\n variable: input_text\n dragging: false\n height: 89\n id: '1711067409646'\n position:\n x: 30\n y: 301.5\n positionAbsolute:\n x: 30\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1717917057450'\n - output\n variable: final\n selected: false\n title: End\n type: end\n height: 89\n id: '1711068257370'\n position:\n x: 2291\n y: 301.5\n positionAbsolute:\n x: 2291\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n code: \"\\ndef main(input_text: str) -> dict:\\n token_limit = 1000\\n overlap\\\n \\ = 100\\n chunk_size = int(token_limit * 6 * (4/3))\\n\\n # Initialize\\\n \\ variables\\n chunks = []\\n start_index = 0\\n text_length = len(input_text)\\n\\\n \\n # Loop until the end of the text is reached\\n while start_index\\\n \\ < text_length:\\n # If we are not at the beginning, adjust the start_index\\\n \\ to ensure overlap\\n if start_index > 0:\\n start_index\\\n \\ -= overlap\\n\\n # Calculate end index for the current chunk\\n \\\n \\ end_index = start_index + chunk_size\\n if end_index > text_length:\\n\\\n \\ end_index = text_length\\n\\n # Add the current chunk\\\n \\ to the list\\n chunks.append(input_text[start_index:end_index])\\n\\\n \\n # Update the start_index for the next chunk\\n start_index\\\n \\ += chunk_size\\n\\n return {\\n \\\"chunks\\\": chunks,\\n }\\n\"\n code_language: python3\n dependencies: []\n desc: 'token_limit = 1000\n\n overlap = 100'\n outputs:\n chunks:\n children: null\n type: array[string]\n selected: false\n title: Code\n type: code\n variables:\n - value_selector:\n - '1711067409646'\n - input_text\n variable: input_text\n height: 101\n id: '1717916867969'\n position:\n x: 336\n y: 301.5\n positionAbsolute:\n x: 336\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n desc: 'Pay attention to the maximum number of iterations. '\n height: 203\n iterator_selector:\n - '1717916867969'\n - chunks\n output_selector:\n - '1717916991709'\n - text\n output_type: array[string]\n selected: false\n startNodeType: llm\n start_node_id: '1717916961837'\n title: Iteration\n type: iteration\n width: 1289\n height: 203\n id: '1717916955547'\n position:\n x: 638\n y: 301.5\n positionAbsolute:\n x: 638\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 1289\n zIndex: 1\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n isIterationStart: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 7261280b-cb27-4f84-8363-b93e09246d16\n role: system\n text: \" Identify the technical terms in the user's input. Use the following\\\n \\ format {XXX} -> {XXX} to show the corresponding technical terms before\\\n \\ and after translation. 
\\n\\n \\n{{#1717916955547.item#}}\\n\\\n \\n\\n| \\u82F1\\u6587 | \\u4E2D\\u6587 |\\n| --- | --- |\\n| Prompt\\\n \\ Engineering | \\u63D0\\u793A\\u8BCD\\u5DE5\\u7A0B |\\n| Text Generation \\_\\\n | \\u6587\\u672C\\u751F\\u6210 |\\n| Token \\_| Token |\\n| Prompt \\_| \\u63D0\\\n \\u793A\\u8BCD |\\n| Meta Prompting \\_| \\u5143\\u63D0\\u793A |\\n| diffusion\\\n \\ models \\_| \\u6269\\u6563\\u6A21\\u578B |\\n| Agent \\_| \\u667A\\u80FD\\u4F53\\\n \\ |\\n| Transformer \\_| Transformer |\\n| Zero Shot \\_| \\u96F6\\u6837\\u672C\\\n \\ |\\n| Few Shot \\_| \\u5C11\\u6837\\u672C |\\n| chat window \\_| \\u804A\\u5929\\\n \\ |\\n| context | \\u4E0A\\u4E0B\\u6587 |\\n| stock photo \\_| \\u56FE\\u5E93\\u7167\\\n \\u7247 |\\n\\n\\n \"\n selected: false\n title: 'Identify Terms '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916961837'\n parentId: '1717916955547'\n position:\n x: 117\n y: 85\n positionAbsolute:\n x: 755\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1001\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 05e03f0d-c1a9-43ab-b4c0-44b55049434d\n role: system\n text: \" You are a professional translator proficient in Simplified\\\n \\ Chinese especially skilled in translating professional academic papers\\\n \\ into easy-to-understand popular science articles. Please help me translate\\\n \\ the following english paragraph into Chinese, in a style similar to\\\n \\ Chinese popular science articles .\\n \\nTranslate directly\\\n \\ based on the English content, maintain the original format and do not\\\n \\ omit any information. \\n \\n{{#1717916955547.item#}}\\n\\\n \\n{{#1717916961837.text#}}\\n \"\n selected: false\n title: 1st Translation\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916977413'\n parentId: '1717916955547'\n position:\n x: 421\n y: 85\n positionAbsolute:\n x: 1059\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 9e6cc050-465e-4632-abc9-411acb255a95\n role: system\n text: \"\\nBased on the results of the direct translation, point out\\\n \\ specific issues it have. 
Accurate descriptions are required, avoiding\\\n \\ vague statements, and there's no need to add content or formats that\\\n \\ were not present in the original text, including but not limited to:\\\n \\ \\n- inconsistent with chinese expression habits, clearly indicate where\\\n \\ it does not conform\\n- Clumsy sentences, specify the location, no need\\\n \\ to offer suggestions for modification, which will be fixed during free\\\n \\ translation\\n- Obscure and difficult to understand, attempts to explain\\\n \\ may be made\\n- \\u65E0\\u6F0F\\u8BD1\\uFF08\\u539F\\u2F42\\u4E2D\\u7684\\u5173\\\n \\u952E\\u8BCD\\u3001\\u53E5\\u2F26\\u3001\\u6BB5\\u843D\\u90FD\\u5E94\\u4F53\\u73B0\\\n \\u5728\\u8BD1\\u2F42\\u4E2D\\uFF09\\u3002\\n- \\u2F46\\u9519\\u8BD1\\uFF08\\u770B\\\n \\u9519\\u539F\\u2F42\\u3001\\u8BEF\\u89E3\\u539F\\u2F42\\u610F\\u601D\\u5747\\u7B97\\\n \\u9519\\u8BD1\\uFF09\\u3002\\n- \\u2F46\\u6709\\u610F\\u589E\\u52A0\\u6216\\u8005\\\n \\u5220\\u51CF\\u7684\\u539F\\u2F42\\u5185\\u5BB9\\uFF08\\u7FFB\\u8BD1\\u5E76\\u2FAE\\\n \\u521B\\u4F5C\\uFF0C\\u9700\\u5C0A\\u91CD\\u4F5C\\u8005\\u89C2 \\u70B9\\uFF1B\\u53EF\\\n \\u4EE5\\u9002\\u5F53\\u52A0\\u8BD1\\u8005\\u6CE8\\u8BF4\\u660E\\uFF09\\u3002\\n-\\\n \\ \\u8BD1\\u2F42\\u6D41\\u7545\\uFF0C\\u7B26\\u5408\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\\n \\u60EF\\u3002\\n- \\u5173\\u4E8E\\u2F08\\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6280\\\n \\u672F\\u56FE\\u4E66\\u4E2D\\u7684\\u2F08\\u540D\\u901A\\u5E38\\u4E0D\\u7FFB\\u8BD1\\\n \\uFF0C\\u4F46\\u662F\\u2F00\\u4E9B\\u4F17\\u6240 \\u5468\\u77E5\\u7684\\u2F08\\u540D\\\n \\u9700\\u2F64\\u4E2D\\u2F42\\uFF08\\u5982\\u4E54\\u5E03\\u65AF\\uFF09\\u3002\\n-\\\n \\ \\u5173\\u4E8E\\u4E66\\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6709\\u4E2D\\u2F42\\u7248\\\n \\u7684\\u56FE\\u4E66\\uFF0C\\u8BF7\\u2F64\\u4E2D\\u2F42\\u7248\\u4E66\\u540D\\uFF1B\\\n \\u2F46\\u4E2D\\u2F42\\u7248 \\u7684\\u56FE\\u4E66\\uFF0C\\u76F4\\u63A5\\u2F64\\u82F1\\\n \\u2F42\\u4E66\\u540D\\u3002\\n- \\u5173\\u4E8E\\u56FE\\u8868\\u7684\\u7FFB\\u8BD1\\\n \\u3002\\u8868\\u683C\\u4E2D\\u7684\\u8868\\u9898\\u3001\\u8868\\u5B57\\u548C\\u6CE8\\\n \\u89E3\\u7B49\\u5747\\u9700\\u7FFB\\u8BD1\\u3002\\u56FE\\u9898 \\u9700\\u8981\\u7FFB\\\n \\u8BD1\\u3002\\u754C\\u2FAF\\u622A\\u56FE\\u4E0D\\u9700\\u8981\\u7FFB\\u8BD1\\u56FE\\\n \\u5B57\\u3002\\u89E3\\u91CA\\u6027\\u56FE\\u9700\\u8981\\u6309\\u7167\\u4E2D\\u82F1\\\n \\u2F42 \\u5BF9\\u7167\\u683C\\u5F0F\\u7ED9\\u51FA\\u56FE\\u5B57\\u7FFB\\u8BD1\\u3002\\\n \\n- \\u5173\\u4E8E\\u82F1\\u2F42\\u672F\\u8BED\\u7684\\u8868\\u8FF0\\u3002\\u82F1\\\n \\u2F42\\u672F\\u8BED\\u2FB8\\u6B21\\u51FA\\u73B0\\u65F6\\uFF0C\\u5E94\\u8BE5\\u6839\\\n \\u636E\\u8BE5\\u672F\\u8BED\\u7684 \\u6D41\\u2F8F\\u60C5\\u51B5\\uFF0C\\u4F18\\u5148\\\n \\u4F7F\\u2F64\\u7B80\\u5199\\u5F62\\u5F0F\\uFF0C\\u5E76\\u5728\\u5176\\u540E\\u4F7F\\\n \\u2F64\\u62EC\\u53F7\\u52A0\\u82F1\\u2F42\\u3001\\u4E2D\\u2F42 \\u5168\\u79F0\\u6CE8\\\n \\u89E3\\uFF0C\\u683C\\u5F0F\\u4E3A\\uFF08\\u4E3E\\u4F8B\\uFF09\\uFF1AHTML\\uFF08\\\n Hypertext Markup Language\\uFF0C\\u8D85\\u2F42\\u672C\\u6807\\u8BC6\\u8BED\\u2F94\\\n \\uFF09\\u3002\\u7136\\u540E\\u5728\\u4E0B\\u2F42\\u4E2D\\u76F4\\u63A5\\u4F7F\\u2F64\\\n \\u7B80\\u5199\\u5F62 \\u5F0F\\u3002\\u5F53\\u7136\\uFF0C\\u5FC5\\u8981\\u65F6\\u4E5F\\\n \\u53EF\\u4EE5\\u6839\\u636E\\u8BED\\u5883\\u4F7F\\u2F64\\u4E2D\\u3001\\u82F1\\u2F42\\\n \\u5168\\u79F0\\u3002\\n- \\u5173\\u4E8E\\u4EE3\\u7801\\u6E05\\u5355\\u548C\\u4EE3\\\n \\u7801\\u2F5A\\u6BB5\\u3002\\u539F\\u4E66\\u4E2D\\u5305\\u542B\\u7684\\u7A0B\\u5E8F\\\n 
\\u4EE3\\u7801\\u4E0D\\u8981\\u6C42\\u8BD1\\u8005\\u5F55 \\u2F0A\\uFF0C\\u4F46\\u5E94\\\n \\u8BE5\\u4F7F\\u2F64\\u201C\\u539F\\u4E66P99\\u2EDA\\u4EE3\\u78011\\u201D\\uFF08\\\n \\u5373\\u539F\\u4E66\\u7B2C99\\u2EDA\\u4E2D\\u7684\\u7B2C\\u2F00\\u6BB5\\u4EE3 \\u7801\\\n \\uFF09\\u7684\\u683C\\u5F0F\\u4F5C\\u51FA\\u6807\\u6CE8\\u3002\\u540C\\u65F6\\uFF0C\\\n \\u8BD1\\u8005\\u5E94\\u8BE5\\u5728\\u6709\\u6761\\u4EF6\\u7684\\u60C5\\u51B5\\u4E0B\\\n \\u68C0\\u6838\\u4EE3 \\u7801\\u7684\\u6B63\\u786E\\u6027\\uFF0C\\u5BF9\\u53D1\\u73B0\\\n \\u7684\\u9519\\u8BEF\\u4EE5\\u8BD1\\u8005\\u6CE8\\u5F62\\u5F0F\\u8BF4\\u660E\\u3002\\\n \\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E2D\\u7684\\u6CE8 \\u91CA\\u8981\\u6C42\\u7FFB\\u8BD1\\\n \\uFF0C\\u5982\\u679C\\u8BD1\\u7A3F\\u4E2D\\u6CA1\\u6709\\u4EE3\\u7801\\uFF0C\\u5219\\\n \\u5E94\\u8BE5\\u4EE5\\u2F00\\u53E5\\u82F1\\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09 \\u2F00\\\n \\u53E5\\u4E2D\\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09\\u7684\\u5F62\\u5F0F\\u7ED9\\u51FA\\\n \\u6CE8\\u91CA\\u3002\\n- \\u5173\\u4E8E\\u6807\\u70B9\\u7B26\\u53F7\\u3002\\u8BD1\\\n \\u7A3F\\u4E2D\\u7684\\u6807\\u70B9\\u7B26\\u53F7\\u8981\\u9075\\u5FAA\\u4E2D\\u2F42\\\n \\u8868\\u8FBE\\u4E60\\u60EF\\u548C\\u4E2D\\u2F42\\u6807 \\u70B9\\u7B26\\u53F7\\u7684\\\n \\u4F7F\\u2F64\\u4E60\\u60EF\\uFF0C\\u4E0D\\u80FD\\u7167\\u642C\\u539F\\u2F42\\u7684\\\n \\u6807\\u70B9\\u7B26\\u53F7\\u3002\\n\\n\\n{{#1717916977413.text#}}\\n\\\n \\n{{#1717916955547.item#}}\\n\\n{{#1717916961837.text#}}\\n\\\n \"\n selected: false\n title: 'Problems '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916984996'\n parentId: '1717916955547'\n position:\n x: 725\n y: 85\n positionAbsolute:\n x: 1363\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: ''\n isInIteration: true\n iteration_id: '1717916955547'\n model:\n completion_params:\n temperature: 0.7\n mode: chat\n name: gpt-4o\n provider: openai\n prompt_template:\n - id: 4d7ae758-2d7b-4404-ad9f-d6748ee64439\n role: system\n text: \"\\nBased on the results of the direct translation in the first\\\n \\ step and the problems identified in the second step, re-translate to\\\n \\ achieve a meaning-based interpretation. Ensure the original intent of\\\n \\ the content is preserved while making it easier to understand and more\\\n \\ in line with Chinese expression habits. All the while maintaining the\\\n \\ original format unchanged. 
\\n\\n\\n- inconsistent with chinese\\\n \\ expression habits, clearly indicate where it does not conform\\n- Clumsy\\\n \\ sentences, specify the location, no need to offer suggestions for modification,\\\n \\ which will be fixed during free translation\\n- Obscure and difficult\\\n \\ to understand, attempts to explain may be made\\n- \\u65E0\\u6F0F\\u8BD1\\\n \\uFF08\\u539F\\u2F42\\u4E2D\\u7684\\u5173\\u952E\\u8BCD\\u3001\\u53E5\\u2F26\\u3001\\\n \\u6BB5\\u843D\\u90FD\\u5E94\\u4F53\\u73B0\\u5728\\u8BD1\\u2F42\\u4E2D\\uFF09\\u3002\\\n \\n- \\u2F46\\u9519\\u8BD1\\uFF08\\u770B\\u9519\\u539F\\u2F42\\u3001\\u8BEF\\u89E3\\\n \\u539F\\u2F42\\u610F\\u601D\\u5747\\u7B97\\u9519\\u8BD1\\uFF09\\u3002\\n- \\u2F46\\\n \\u6709\\u610F\\u589E\\u52A0\\u6216\\u8005\\u5220\\u51CF\\u7684\\u539F\\u2F42\\u5185\\\n \\u5BB9\\uFF08\\u7FFB\\u8BD1\\u5E76\\u2FAE\\u521B\\u4F5C\\uFF0C\\u9700\\u5C0A\\u91CD\\\n \\u4F5C\\u8005\\u89C2 \\u70B9\\uFF1B\\u53EF\\u4EE5\\u9002\\u5F53\\u52A0\\u8BD1\\u8005\\\n \\u6CE8\\u8BF4\\u660E\\uFF09\\u3002\\n- \\u8BD1\\u2F42\\u6D41\\u7545\\uFF0C\\u7B26\\\n \\u5408\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\u60EF\\u3002\\n- \\u5173\\u4E8E\\u2F08\\\n \\u540D\\u7684\\u7FFB\\u8BD1\\u3002\\u6280\\u672F\\u56FE\\u4E66\\u4E2D\\u7684\\u2F08\\\n \\u540D\\u901A\\u5E38\\u4E0D\\u7FFB\\u8BD1\\uFF0C\\u4F46\\u662F\\u2F00\\u4E9B\\u4F17\\\n \\u6240 \\u5468\\u77E5\\u7684\\u2F08\\u540D\\u9700\\u2F64\\u4E2D\\u2F42\\uFF08\\u5982\\\n \\u4E54\\u5E03\\u65AF\\uFF09\\u3002\\n- \\u5173\\u4E8E\\u4E66\\u540D\\u7684\\u7FFB\\\n \\u8BD1\\u3002\\u6709\\u4E2D\\u2F42\\u7248\\u7684\\u56FE\\u4E66\\uFF0C\\u8BF7\\u2F64\\\n \\u4E2D\\u2F42\\u7248\\u4E66\\u540D\\uFF1B\\u2F46\\u4E2D\\u2F42\\u7248 \\u7684\\u56FE\\\n \\u4E66\\uFF0C\\u76F4\\u63A5\\u2F64\\u82F1\\u2F42\\u4E66\\u540D\\u3002\\n- \\u5173\\\n \\u4E8E\\u56FE\\u8868\\u7684\\u7FFB\\u8BD1\\u3002\\u8868\\u683C\\u4E2D\\u7684\\u8868\\\n \\u9898\\u3001\\u8868\\u5B57\\u548C\\u6CE8\\u89E3\\u7B49\\u5747\\u9700\\u7FFB\\u8BD1\\\n \\u3002\\u56FE\\u9898 \\u9700\\u8981\\u7FFB\\u8BD1\\u3002\\u754C\\u2FAF\\u622A\\u56FE\\\n \\u4E0D\\u9700\\u8981\\u7FFB\\u8BD1\\u56FE\\u5B57\\u3002\\u89E3\\u91CA\\u6027\\u56FE\\\n \\u9700\\u8981\\u6309\\u7167\\u4E2D\\u82F1\\u2F42 \\u5BF9\\u7167\\u683C\\u5F0F\\u7ED9\\\n \\u51FA\\u56FE\\u5B57\\u7FFB\\u8BD1\\u3002\\n- \\u5173\\u4E8E\\u82F1\\u2F42\\u672F\\\n \\u8BED\\u7684\\u8868\\u8FF0\\u3002\\u82F1\\u2F42\\u672F\\u8BED\\u2FB8\\u6B21\\u51FA\\\n \\u73B0\\u65F6\\uFF0C\\u5E94\\u8BE5\\u6839\\u636E\\u8BE5\\u672F\\u8BED\\u7684 \\u6D41\\\n \\u2F8F\\u60C5\\u51B5\\uFF0C\\u4F18\\u5148\\u4F7F\\u2F64\\u7B80\\u5199\\u5F62\\u5F0F\\\n \\uFF0C\\u5E76\\u5728\\u5176\\u540E\\u4F7F\\u2F64\\u62EC\\u53F7\\u52A0\\u82F1\\u2F42\\\n \\u3001\\u4E2D\\u2F42 \\u5168\\u79F0\\u6CE8\\u89E3\\uFF0C\\u683C\\u5F0F\\u4E3A\\uFF08\\\n \\u4E3E\\u4F8B\\uFF09\\uFF1AHTML\\uFF08Hypertext Markup Language\\uFF0C\\u8D85\\\n \\u2F42\\u672C\\u6807\\u8BC6\\u8BED\\u2F94\\uFF09\\u3002\\u7136\\u540E\\u5728\\u4E0B\\\n \\u2F42\\u4E2D\\u76F4\\u63A5\\u4F7F\\u2F64\\u7B80\\u5199\\u5F62 \\u5F0F\\u3002\\u5F53\\\n \\u7136\\uFF0C\\u5FC5\\u8981\\u65F6\\u4E5F\\u53EF\\u4EE5\\u6839\\u636E\\u8BED\\u5883\\\n \\u4F7F\\u2F64\\u4E2D\\u3001\\u82F1\\u2F42\\u5168\\u79F0\\u3002\\n- \\u5173\\u4E8E\\\n \\u4EE3\\u7801\\u6E05\\u5355\\u548C\\u4EE3\\u7801\\u2F5A\\u6BB5\\u3002\\u539F\\u4E66\\\n \\u4E2D\\u5305\\u542B\\u7684\\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E0D\\u8981\\u6C42\\u8BD1\\\n \\u8005\\u5F55 \\u2F0A\\uFF0C\\u4F46\\u5E94\\u8BE5\\u4F7F\\u2F64\\u201C\\u539F\\u4E66\\\n P99\\u2EDA\\u4EE3\\u78011\\u201D\\uFF08\\u5373\\u539F\\u4E66\\u7B2C99\\u2EDA\\u4E2D\\\n 
\\u7684\\u7B2C\\u2F00\\u6BB5\\u4EE3 \\u7801\\uFF09\\u7684\\u683C\\u5F0F\\u4F5C\\u51FA\\\n \\u6807\\u6CE8\\u3002\\u540C\\u65F6\\uFF0C\\u8BD1\\u8005\\u5E94\\u8BE5\\u5728\\u6709\\\n \\u6761\\u4EF6\\u7684\\u60C5\\u51B5\\u4E0B\\u68C0\\u6838\\u4EE3 \\u7801\\u7684\\u6B63\\\n \\u786E\\u6027\\uFF0C\\u5BF9\\u53D1\\u73B0\\u7684\\u9519\\u8BEF\\u4EE5\\u8BD1\\u8005\\\n \\u6CE8\\u5F62\\u5F0F\\u8BF4\\u660E\\u3002\\u7A0B\\u5E8F\\u4EE3\\u7801\\u4E2D\\u7684\\\n \\u6CE8 \\u91CA\\u8981\\u6C42\\u7FFB\\u8BD1\\uFF0C\\u5982\\u679C\\u8BD1\\u7A3F\\u4E2D\\\n \\u6CA1\\u6709\\u4EE3\\u7801\\uFF0C\\u5219\\u5E94\\u8BE5\\u4EE5\\u2F00\\u53E5\\u82F1\\\n \\u2F42\\uFF08\\u6CE8\\u91CA\\uFF09 \\u2F00\\u53E5\\u4E2D\\u2F42\\uFF08\\u6CE8\\u91CA\\\n \\uFF09\\u7684\\u5F62\\u5F0F\\u7ED9\\u51FA\\u6CE8\\u91CA\\u3002\\n- \\u5173\\u4E8E\\\n \\u6807\\u70B9\\u7B26\\u53F7\\u3002\\u8BD1\\u7A3F\\u4E2D\\u7684\\u6807\\u70B9\\u7B26\\\n \\u53F7\\u8981\\u9075\\u5FAA\\u4E2D\\u2F42\\u8868\\u8FBE\\u4E60\\u60EF\\u548C\\u4E2D\\\n \\u2F42\\u6807 \\u70B9\\u7B26\\u53F7\\u7684\\u4F7F\\u2F64\\u4E60\\u60EF\\uFF0C\\u4E0D\\\n \\u80FD\\u7167\\u642C\\u539F\\u2F42\\u7684\\u6807\\u70B9\\u7B26\\u53F7\\u3002\\n\\n\\\n \\n{{#1717916977413.text#}}\\n\\n{{#1717916984996.text#}}\\n\\n{{#1711067409646.input_text#}}\\n\\\n \\n{{#1717916961837.text#}}\\n \"\n selected: false\n title: '2nd Translation '\n type: llm\n variables: []\n vision:\n configs:\n detail: high\n enabled: true\n extent: parent\n height: 97\n id: '1717916991709'\n parentId: '1717916955547'\n position:\n x: 1029\n y: 85\n positionAbsolute:\n x: 1667\n y: 386.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n zIndex: 1002\n - data:\n desc: 'Combine all chunks of translation. '\n selected: false\n template: '{{ translated_text | join('' '') }}'\n title: Template\n type: template-transform\n variables:\n - value_selector:\n - '1717916955547'\n - output\n variable: translated_text\n height: 83\n id: '1717917057450'\n position:\n x: 1987\n y: 301.5\n positionAbsolute:\n x: 1987\n y: 301.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 244\n - data:\n author: Dify\n desc: ''\n height: 186\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Code\n node separates the input_text into chunks with length of token_limit. Each\n chunk overlap with each other to make sure the texts are consistent. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"The\n code node outputs an array of segmented texts of input_texts. 
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 340\n height: 186\n id: '1718990593686'\n position:\n x: 259.3026056936437\n y: 451.6924912936374\n positionAbsolute:\n x: 259.3026056936437\n y: 451.6924912936374\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 340\n - data:\n author: Dify\n desc: ''\n height: 128\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Iterate\n through all the elements in output of the code node and translate each chunk\n using a three steps translation workflow. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: blue\n title: ''\n type: ''\n width: 355\n height: 128\n id: '1718991836605'\n position:\n x: 764.3891977435923\n y: 530.8917807505335\n positionAbsolute:\n x: 764.3891977435923\n y: 530.8917807505335\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 355\n - data:\n author: Dify\n desc: ''\n height: 126\n selected: false\n showAuthor: true\n text: '{\"root\":{\"children\":[{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Avoid\n using a high token_limit, LLM''s performance decreases with longer context\n length for gpt-4o. \",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[],\"direction\":null,\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0},{\"children\":[{\"detail\":0,\"format\":0,\"mode\":\"normal\",\"style\":\"\",\"text\":\"Recommend\n to use less than or equal to 1000 tokens. 
\",\"type\":\"text\",\"version\":1}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"paragraph\",\"version\":1,\"textFormat\":0}],\"direction\":\"ltr\",\"format\":\"\",\"indent\":0,\"type\":\"root\",\"version\":1}}'\n theme: yellow\n title: ''\n type: ''\n width: 351\n height: 126\n id: '1718991882984'\n position:\n x: 304.49115824454367\n y: 148.4042994607805\n positionAbsolute:\n x: 304.49115824454367\n y: 148.4042994607805\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom-note\n width: 351\n viewport:\n x: 335.92505067152274\n y: 18.806553508850584\n zoom: 0.8705505632961259\n","icon":"\ud83e\udd16","icon_background":"#FFEAD5","id":"98b87f88-bd22-4d86-8b74-86beba5e0ed4","mode":"workflow","name":"Book Translation "}, "cae337e6-aec5-4c7b-beca-d6f1a808bd5e":{ "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: chat\n name: Python bug fixer\nmodel_config:\n agent_mode:\n enabled: false\n max_iteration: 5\n strategy: function_call\n tools: []\n annotation_reply:\n enabled: false\n chat_prompt_config: {}\n completion_prompt_config: {}\n dataset_configs:\n datasets:\n datasets: []\n retrieval_model: single\n dataset_query_variable: ''\n external_data_tools: []\n file_upload:\n image:\n detail: high\n enabled: false\n number_limits: 3\n transfer_methods:\n - remote_url\n - local_file\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n stop: []\n temperature: 0\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n more_like_this:\n enabled: false\n opening_statement: ''\n pre_prompt: Your task is to analyze the provided Python code snippet, identify any\n bugs or errors present, and provide a corrected version of the code that resolves\n these issues. Explain the problems you found in the original code and how your\n fixes address them. 
The corrected code should be functional, efficient, and adhere\n to best practices in Python programming.\n prompt_type: simple\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n configs: []\n enabled: false\n type: ''\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n user_input_form: []\n", "icon": "🤖", @@ -553,15 +553,15 @@ "name": "AI Front-end interviewer" }, "e9870913-dd01-4710-9f06-15d4180ca1ce": { - "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: advanced-chat\n name: 'Knowledge Retreival + Chatbot '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n sourceType: start\n targetType: knowledge-retrieval\n id: 1711528914102-1711528915811\n source: '1711528914102'\n sourceHandle: source\n target: '1711528915811'\n targetHandle: target\n type: custom\n - data:\n sourceType: knowledge-retrieval\n targetType: llm\n id: 1711528915811-1711528917469\n source: '1711528915811'\n sourceHandle: source\n target: '1711528917469'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: answer\n id: 1711528917469-1711528919501\n source: '1711528917469'\n sourceHandle: source\n target: '1711528919501'\n targetHandle: target\n type: custom\n nodes:\n - data:\n desc: ''\n selected: true\n title: Start\n type: start\n variables: []\n height: 53\n id: '1711528914102'\n position:\n x: 79.5\n y: 2634.5\n positionAbsolute:\n x: 79.5\n y: 2634.5\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n dataset_ids:\n - 6084ed3f-d100-4df2-a277-b40d639ea7c6\n desc: Allows you to query text content related to user questions from the\n Knowledge\n query_variable_selector:\n - '1711528914102'\n - sys.query\n retrieval_mode: single\n selected: false\n single_retrieval_config:\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n title: Knowledge Retrieval\n type: knowledge-retrieval\n dragging: false\n height: 101\n id: '1711528915811'\n position:\n x: 362.5\n y: 2634.5\n positionAbsolute:\n x: 362.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Invoking large language models to answer questions or process natural\n language\n memory:\n role_prefix:\n assistant: ''\n user: ''\n window:\n enabled: false\n size: 50\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: \"You are a helpful assistant. 
\\nUse the following context as your\\\n \\ learned knowledge, inside XML tags.\\n\\n\\\n {{#context#}}\\n\\nWhen answer to user:\\n- If you don't know,\\\n \\ just say that you don't know.\\n- If you don't know when you are not\\\n \\ sure, ask for clarification.\\nAvoid mentioning that you obtained the\\\n \\ information from the context.\\nAnd answer according to the language\\\n \\ of the user's question.\"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n height: 163\n id: '1711528917469'\n position:\n x: 645.5\n y: 2634.5\n positionAbsolute:\n x: 645.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n answer: '{{#1711528917469.text#}}'\n desc: ''\n selected: false\n title: Answer\n type: answer\n variables: []\n height: 105\n id: '1711528919501'\n position:\n x: 928.5\n y: 2634.5\n positionAbsolute:\n x: 928.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n viewport:\n x: 86.31278232100044\n y: -2276.452137533831\n zoom: 0.9753554615276419\n", + "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: advanced-chat\n name: 'Knowledge Retrieval + Chatbot '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n sourceType: start\n targetType: knowledge-retrieval\n id: 1711528914102-1711528915811\n source: '1711528914102'\n sourceHandle: source\n target: '1711528915811'\n targetHandle: target\n type: custom\n - data:\n sourceType: knowledge-retrieval\n targetType: llm\n id: 1711528915811-1711528917469\n source: '1711528915811'\n sourceHandle: source\n target: '1711528917469'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: answer\n id: 1711528917469-1711528919501\n source: '1711528917469'\n sourceHandle: source\n target: '1711528919501'\n targetHandle: target\n type: custom\n nodes:\n - data:\n desc: ''\n selected: true\n title: Start\n type: start\n variables: []\n height: 53\n id: '1711528914102'\n position:\n x: 79.5\n y: 2634.5\n positionAbsolute:\n x: 79.5\n y: 2634.5\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n dataset_ids:\n - 6084ed3f-d100-4df2-a277-b40d639ea7c6\n desc: Allows you to query text content related to user questions from the\n Knowledge\n query_variable_selector:\n - '1711528914102'\n - sys.query\n retrieval_mode: single\n selected: false\n single_retrieval_config:\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n title: Knowledge Retrieval\n type: knowledge-retrieval\n dragging: false\n height: 101\n id: '1711528915811'\n position:\n x: 362.5\n y: 2634.5\n positionAbsolute:\n x: 362.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Invoking large language models to answer questions or process natural\n language\n memory:\n role_prefix:\n assistant: ''\n user: 
''\n window:\n enabled: false\n size: 50\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: \"You are a helpful assistant. \\nUse the following context as your\\\n \\ learned knowledge, inside XML tags.\\n\\n\\\n {{#context#}}\\n\\nWhen answer to user:\\n- If you don't know,\\\n \\ just say that you don't know.\\n- If you don't know when you are not\\\n \\ sure, ask for clarification.\\nAvoid mentioning that you obtained the\\\n \\ information from the context.\\nAnd answer according to the language\\\n \\ of the user's question.\"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n height: 163\n id: '1711528917469'\n position:\n x: 645.5\n y: 2634.5\n positionAbsolute:\n x: 645.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n answer: '{{#1711528917469.text#}}'\n desc: ''\n selected: false\n title: Answer\n type: answer\n variables: []\n height: 105\n id: '1711528919501'\n position:\n x: 928.5\n y: 2634.5\n positionAbsolute:\n x: 928.5\n y: 2634.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n viewport:\n x: 86.31278232100044\n y: -2276.452137533831\n zoom: 0.9753554615276419\n", "icon": "🤖", "icon_background": "#FFEAD5", "id": "e9870913-dd01-4710-9f06-15d4180ca1ce", "mode": "advanced-chat", - "name": "Knowledge Retreival + Chatbot " + "name": "Knowledge Retrieval + Chatbot " }, "dd5b6353-ae9b-4bce-be6a-a681a12cf709":{ - "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: workflow\n name: 'Email Assistant Workflow '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n sourceType: start\n targetType: question-classifier\n id: 1711511281652-1711512802873\n source: '1711511281652'\n sourceHandle: source\n target: '1711512802873'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: question-classifier\n id: 1711512802873-1711512837494\n source: '1711512802873'\n sourceHandle: '1711512813038'\n target: '1711512837494'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512911454\n source: '1711512802873'\n sourceHandle: '1711512811520'\n target: '1711512911454'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512914870\n source: '1711512802873'\n sourceHandle: '1711512812031'\n target: '1711512914870'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512916516\n source: '1711512802873'\n sourceHandle: '1711512812510'\n target: '1711512916516'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512924231\n source: '1711512837494'\n sourceHandle: '1711512846439'\n target: '1711512924231'\n targetHandle: target\n type: custom\n - 
data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512926020\n source: '1711512837494'\n sourceHandle: '1711512847112'\n target: '1711512926020'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512927569\n source: '1711512837494'\n sourceHandle: '1711512847641'\n target: '1711512927569'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512929190\n source: '1711512837494'\n sourceHandle: '1711512848120'\n target: '1711512929190'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512930700\n source: '1711512837494'\n sourceHandle: '1711512848616'\n target: '1711512930700'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512911454-1711513015189\n source: '1711512911454'\n sourceHandle: source\n target: '1711513015189'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512914870-1711513017096\n source: '1711512914870'\n sourceHandle: source\n target: '1711513017096'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512916516-1711513018759\n source: '1711512916516'\n sourceHandle: source\n target: '1711513018759'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512924231-1711513020857\n source: '1711512924231'\n sourceHandle: source\n target: '1711513020857'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512926020-1711513022516\n source: '1711512926020'\n sourceHandle: source\n target: '1711513022516'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512927569-1711513024315\n source: '1711512927569'\n sourceHandle: source\n target: '1711513024315'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512929190-1711513025732\n source: '1711512929190'\n sourceHandle: source\n target: '1711513025732'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512930700-1711513027347\n source: '1711512930700'\n sourceHandle: source\n target: '1711513027347'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513015189-1711513029058\n source: '1711513015189'\n sourceHandle: source\n target: '1711513029058'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513017096-1711513030924\n source: '1711513017096'\n sourceHandle: source\n target: '1711513030924'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513018759-1711513032459\n source: '1711513018759'\n sourceHandle: source\n target: '1711513032459'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513020857-1711513034850\n source: '1711513020857'\n sourceHandle: source\n target: '1711513034850'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513022516-1711513036356\n source: '1711513022516'\n sourceHandle: source\n target: '1711513036356'\n targetHandle: 
target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513024315-1711513037973\n source: '1711513024315'\n sourceHandle: source\n target: '1711513037973'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513025732-1711513039350\n source: '1711513025732'\n sourceHandle: source\n target: '1711513039350'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513027347-1711513041219\n source: '1711513027347'\n sourceHandle: source\n target: '1711513041219'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711513940609\n source: '1711512802873'\n sourceHandle: '1711513927279'\n target: '1711513940609'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711513940609-1711513967853\n source: '1711513940609'\n sourceHandle: source\n target: '1711513967853'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513967853-1711513974643\n source: '1711513967853'\n sourceHandle: source\n target: '1711513974643'\n targetHandle: target\n type: custom\n nodes:\n - data:\n desc: ''\n selected: true\n title: Start\n type: start\n variables:\n - label: Email\n max_length: null\n options: []\n required: true\n type: paragraph\n variable: Input_Text\n - label: What do you need to do? (Summarize / Reply / Write / Improve)\n max_length: 48\n options:\n - Summarize\n - 'Reply '\n - Write a email\n - 'Improve writings '\n required: true\n type: select\n variable: user_request\n - label: 'How do you want it to be polished? (Optional) '\n max_length: 48\n options:\n - 'Imporve writing and clarity '\n - Shorten\n - 'Lengthen '\n - 'Simplify '\n - Rewrite in my voice\n required: false\n type: select\n variable: how_polish\n dragging: false\n height: 141\n id: '1711511281652'\n position:\n x: 79.5\n y: 409.5\n positionAbsolute:\n x: 79.5\n y: 409.5\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n classes:\n - id: '1711512811520'\n name: Summarize\n - id: '1711512812031'\n name: Reply to emails\n - id: '1711512812510'\n name: Help me write the email\n - id: '1711512813038'\n name: Improve writings or polish\n - id: '1711513927279'\n name: Grammer check\n desc: 'Classify users'' demands. '\n instructions: ''\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1711511281652'\n - user_request\n selected: false\n title: 'Question Classifier '\n topics: []\n type: question-classifier\n dragging: false\n height: 333\n id: '1711512802873'\n position:\n x: 362.5\n y: 409.5\n positionAbsolute:\n x: 362.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n classes:\n - id: '1711512846439'\n name: 'Improve writing and clarity '\n - id: '1711512847112'\n name: 'Shorten '\n - id: '1711512847641'\n name: 'Lengthen '\n - id: '1711512848120'\n name: 'Simplify '\n - id: '1711512848616'\n name: Rewrite in my voice\n desc: 'Improve writings. 
'\n instructions: ''\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1711511281652'\n - how_polish\n selected: false\n title: 'Question Classifier '\n topics: []\n type: question-classifier\n dragging: false\n height: 333\n id: '1711512837494'\n position:\n x: 645.5\n y: 409.5\n positionAbsolute:\n x: 645.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Summary\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Summary the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512911454'\n position:\n x: 645.5\n y: 1327.5\n positionAbsolute:\n x: 645.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Reply\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Rely the emails for me, in my own voice. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512914870'\n position:\n x: 645.5\n y: 1518.5\n positionAbsolute:\n x: 645.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Turn idea into email\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Turn my idea into email. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512916516'\n position:\n x: 645.5\n y: 1709.5\n positionAbsolute:\n x: 645.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Improve the clarity. '\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: \" Imporve the clarity of the email for me. \\n{{#1711511281652.Input_Text#}}\\n\\\n \"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512924231'\n position:\n x: 928.5\n y: 409.5\n positionAbsolute:\n x: 928.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Shorten. 
'\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Shorten the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512926020'\n position:\n x: 928.5\n y: 600.5\n positionAbsolute:\n x: 928.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Lengthen '\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Lengthen the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512927569'\n position:\n x: 928.5\n y: 791.5\n positionAbsolute:\n x: 928.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Simplify\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Simplify the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512929190'\n position:\n x: 928.5\n y: 982.5\n positionAbsolute:\n x: 928.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Rewrite in my voice\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Rewrite the email for me. 
{{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512930700'\n position:\n x: 928.5\n y: 1173.5\n positionAbsolute:\n x: 928.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template\n type: template-transform\n variables:\n - value_selector:\n - '1711512911454'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513015189'\n position:\n x: 928.5\n y: 1327.5\n positionAbsolute:\n x: 928.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 2\n type: template-transform\n variables:\n - value_selector:\n - '1711512914870'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513017096'\n position:\n x: 928.5\n y: 1518.5\n positionAbsolute:\n x: 928.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 3\n type: template-transform\n variables:\n - value_selector:\n - '1711512916516'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513018759'\n position:\n x: 928.5\n y: 1709.5\n positionAbsolute:\n x: 928.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 4\n type: template-transform\n variables:\n - value_selector:\n - '1711512924231'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513020857'\n position:\n x: 1211.5\n y: 409.5\n positionAbsolute:\n x: 1211.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 5\n type: template-transform\n variables:\n - value_selector:\n - '1711512926020'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513022516'\n position:\n x: 1211.5\n y: 600.5\n positionAbsolute:\n x: 1211.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 6\n type: template-transform\n variables:\n - value_selector:\n - '1711512927569'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513024315'\n position:\n x: 1211.5\n y: 791.5\n positionAbsolute:\n x: 1211.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 7\n type: template-transform\n variables:\n - value_selector:\n - '1711512929190'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513025732'\n position:\n x: 1211.5\n y: 982.5\n positionAbsolute:\n x: 1211.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 8\n type: template-transform\n variables:\n - value_selector:\n - '1711512930700'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513027347'\n position:\n x: 1211.5\n y: 1173.5\n positionAbsolute:\n x: 
1211.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512911454'\n - text\n variable: text\n selected: false\n title: End\n type: end\n dragging: false\n height: 89\n id: '1711513029058'\n position:\n x: 1211.5\n y: 1327.5\n positionAbsolute:\n x: 1211.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512914870'\n - text\n variable: text\n selected: false\n title: End 2\n type: end\n dragging: false\n height: 89\n id: '1711513030924'\n position:\n x: 1211.5\n y: 1518.5\n positionAbsolute:\n x: 1211.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512916516'\n - text\n variable: text\n selected: false\n title: End 3\n type: end\n dragging: false\n height: 89\n id: '1711513032459'\n position:\n x: 1211.5\n y: 1709.5\n positionAbsolute:\n x: 1211.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512924231'\n - text\n variable: text\n selected: false\n title: End 4\n type: end\n dragging: false\n height: 89\n id: '1711513034850'\n position:\n x: 1494.5\n y: 409.5\n positionAbsolute:\n x: 1494.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512926020'\n - text\n variable: text\n selected: false\n title: End 5\n type: end\n dragging: false\n height: 89\n id: '1711513036356'\n position:\n x: 1494.5\n y: 600.5\n positionAbsolute:\n x: 1494.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512927569'\n - text\n variable: text\n selected: false\n title: End 6\n type: end\n dragging: false\n height: 89\n id: '1711513037973'\n position:\n x: 1494.5\n y: 791.5\n positionAbsolute:\n x: 1494.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512929190'\n - text\n variable: text\n selected: false\n title: End 7\n type: end\n dragging: false\n height: 89\n id: '1711513039350'\n position:\n x: 1494.5\n y: 982.5\n positionAbsolute:\n x: 1494.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512930700'\n - text\n variable: text\n selected: false\n title: End 8\n type: end\n dragging: false\n height: 89\n id: '1711513041219'\n position:\n x: 1494.5\n y: 1173.5\n positionAbsolute:\n x: 1494.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Grammer Check\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: 'Please check grammer of my email and comment on the grammer. 
{{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711513940609'\n position:\n x: 645.5\n y: 1900.5\n positionAbsolute:\n x: 645.5\n y: 1900.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 9\n type: template-transform\n variables:\n - value_selector:\n - '1711513940609'\n - text\n variable: arg1\n height: 53\n id: '1711513967853'\n position:\n x: 928.5\n y: 1900.5\n positionAbsolute:\n x: 928.5\n y: 1900.5\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711513940609'\n - text\n variable: text\n selected: false\n title: End 9\n type: end\n height: 89\n id: '1711513974643'\n position:\n x: 1211.5\n y: 1900.5\n positionAbsolute:\n x: 1211.5\n y: 1900.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n viewport:\n x: 0\n y: 0\n zoom: 0.7\n", + "export_data": "app:\n icon: \"\\U0001F916\"\n icon_background: '#FFEAD5'\n mode: workflow\n name: 'Email Assistant Workflow '\nworkflow:\n features:\n file_upload:\n image:\n enabled: false\n number_limits: 3\n transfer_methods:\n - local_file\n - remote_url\n opening_statement: ''\n retriever_resource:\n enabled: false\n sensitive_word_avoidance:\n enabled: false\n speech_to_text:\n enabled: false\n suggested_questions: []\n suggested_questions_after_answer:\n enabled: false\n text_to_speech:\n enabled: false\n language: ''\n voice: ''\n graph:\n edges:\n - data:\n sourceType: start\n targetType: question-classifier\n id: 1711511281652-1711512802873\n source: '1711511281652'\n sourceHandle: source\n target: '1711512802873'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: question-classifier\n id: 1711512802873-1711512837494\n source: '1711512802873'\n sourceHandle: '1711512813038'\n target: '1711512837494'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512911454\n source: '1711512802873'\n sourceHandle: '1711512811520'\n target: '1711512911454'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512914870\n source: '1711512802873'\n sourceHandle: '1711512812031'\n target: '1711512914870'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711512916516\n source: '1711512802873'\n sourceHandle: '1711512812510'\n target: '1711512916516'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512924231\n source: '1711512837494'\n sourceHandle: '1711512846439'\n target: '1711512924231'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512926020\n source: '1711512837494'\n sourceHandle: '1711512847112'\n target: '1711512926020'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512927569\n source: '1711512837494'\n sourceHandle: '1711512847641'\n target: '1711512927569'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512929190\n source: 
'1711512837494'\n sourceHandle: '1711512848120'\n target: '1711512929190'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512837494-1711512930700\n source: '1711512837494'\n sourceHandle: '1711512848616'\n target: '1711512930700'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512911454-1711513015189\n source: '1711512911454'\n sourceHandle: source\n target: '1711513015189'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512914870-1711513017096\n source: '1711512914870'\n sourceHandle: source\n target: '1711513017096'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512916516-1711513018759\n source: '1711512916516'\n sourceHandle: source\n target: '1711513018759'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512924231-1711513020857\n source: '1711512924231'\n sourceHandle: source\n target: '1711513020857'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512926020-1711513022516\n source: '1711512926020'\n sourceHandle: source\n target: '1711513022516'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512927569-1711513024315\n source: '1711512927569'\n sourceHandle: source\n target: '1711513024315'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512929190-1711513025732\n source: '1711512929190'\n sourceHandle: source\n target: '1711513025732'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711512930700-1711513027347\n source: '1711512930700'\n sourceHandle: source\n target: '1711513027347'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513015189-1711513029058\n source: '1711513015189'\n sourceHandle: source\n target: '1711513029058'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513017096-1711513030924\n source: '1711513017096'\n sourceHandle: source\n target: '1711513030924'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513018759-1711513032459\n source: '1711513018759'\n sourceHandle: source\n target: '1711513032459'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513020857-1711513034850\n source: '1711513020857'\n sourceHandle: source\n target: '1711513034850'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513022516-1711513036356\n source: '1711513022516'\n sourceHandle: source\n target: '1711513036356'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513024315-1711513037973\n source: '1711513024315'\n sourceHandle: source\n target: '1711513037973'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513025732-1711513039350\n source: '1711513025732'\n sourceHandle: source\n target: '1711513039350'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513027347-1711513041219\n source: 
'1711513027347'\n sourceHandle: source\n target: '1711513041219'\n targetHandle: target\n type: custom\n - data:\n sourceType: question-classifier\n targetType: llm\n id: 1711512802873-1711513940609\n source: '1711512802873'\n sourceHandle: '1711513927279'\n target: '1711513940609'\n targetHandle: target\n type: custom\n - data:\n sourceType: llm\n targetType: template-transform\n id: 1711513940609-1711513967853\n source: '1711513940609'\n sourceHandle: source\n target: '1711513967853'\n targetHandle: target\n type: custom\n - data:\n sourceType: template-transform\n targetType: end\n id: 1711513967853-1711513974643\n source: '1711513967853'\n sourceHandle: source\n target: '1711513974643'\n targetHandle: target\n type: custom\n nodes:\n - data:\n desc: ''\n selected: true\n title: Start\n type: start\n variables:\n - label: Email\n max_length: null\n options: []\n required: true\n type: paragraph\n variable: Input_Text\n - label: What do you need to do? (Summarize / Reply / Write / Improve)\n max_length: 48\n options:\n - Summarize\n - 'Reply '\n - Write a email\n - 'Improve writings '\n required: true\n type: select\n variable: user_request\n - label: 'How do you want it to be polished? (Optional) '\n max_length: 48\n options:\n - 'Imporve writing and clarity '\n - Shorten\n - 'Lengthen '\n - 'Simplify '\n - Rewrite in my voice\n required: false\n type: select\n variable: how_polish\n dragging: false\n height: 141\n id: '1711511281652'\n position:\n x: 79.5\n y: 409.5\n positionAbsolute:\n x: 79.5\n y: 409.5\n selected: true\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n classes:\n - id: '1711512811520'\n name: Summarize\n - id: '1711512812031'\n name: Reply to emails\n - id: '1711512812510'\n name: Help me write the email\n - id: '1711512813038'\n name: Improve writings or polish\n - id: '1711513927279'\n name: Grammar check\n desc: 'Classify users'' demands. '\n instructions: ''\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1711511281652'\n - user_request\n selected: false\n title: 'Question Classifier '\n topics: []\n type: question-classifier\n dragging: false\n height: 333\n id: '1711512802873'\n position:\n x: 362.5\n y: 409.5\n positionAbsolute:\n x: 362.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n classes:\n - id: '1711512846439'\n name: 'Improve writing and clarity '\n - id: '1711512847112'\n name: 'Shorten '\n - id: '1711512847641'\n name: 'Lengthen '\n - id: '1711512848120'\n name: 'Simplify '\n - id: '1711512848616'\n name: Rewrite in my voice\n desc: 'Improve writings. 
'\n instructions: ''\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n query_variable_selector:\n - '1711511281652'\n - how_polish\n selected: false\n title: 'Question Classifier '\n topics: []\n type: question-classifier\n dragging: false\n height: 333\n id: '1711512837494'\n position:\n x: 645.5\n y: 409.5\n positionAbsolute:\n x: 645.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Summary\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Summarize the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512911454'\n position:\n x: 645.5\n y: 1327.5\n positionAbsolute:\n x: 645.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Reply\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Reply to the emails for me, in my own voice. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512914870'\n position:\n x: 645.5\n y: 1518.5\n positionAbsolute:\n x: 645.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Turn idea into email\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Turn my idea into an email. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512916516'\n position:\n x: 645.5\n y: 1709.5\n positionAbsolute:\n x: 645.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Improve the clarity. '\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: \" Improve the clarity of the email for me. \\n{{#1711511281652.Input_Text#}}\\n\\\n \"\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512924231'\n position:\n x: 928.5\n y: 409.5\n positionAbsolute:\n x: 928.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Shorten. 
'\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Shorten the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512926020'\n position:\n x: 928.5\n y: 600.5\n positionAbsolute:\n x: 928.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: 'Lengthen '\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Lengthen the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512927569'\n position:\n x: 928.5\n y: 791.5\n positionAbsolute:\n x: 928.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Simplify\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Simplify the email for me. {{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512929190'\n position:\n x: 928.5\n y: 982.5\n positionAbsolute:\n x: 928.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Rewrite in my voice\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: ' Rewrite the email for me. 
{{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711512930700'\n position:\n x: 928.5\n y: 1173.5\n positionAbsolute:\n x: 928.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template\n type: template-transform\n variables:\n - value_selector:\n - '1711512911454'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513015189'\n position:\n x: 928.5\n y: 1327.5\n positionAbsolute:\n x: 928.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 2\n type: template-transform\n variables:\n - value_selector:\n - '1711512914870'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513017096'\n position:\n x: 928.5\n y: 1518.5\n positionAbsolute:\n x: 928.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 3\n type: template-transform\n variables:\n - value_selector:\n - '1711512916516'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513018759'\n position:\n x: 928.5\n y: 1709.5\n positionAbsolute:\n x: 928.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 4\n type: template-transform\n variables:\n - value_selector:\n - '1711512924231'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513020857'\n position:\n x: 1211.5\n y: 409.5\n positionAbsolute:\n x: 1211.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 5\n type: template-transform\n variables:\n - value_selector:\n - '1711512926020'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513022516'\n position:\n x: 1211.5\n y: 600.5\n positionAbsolute:\n x: 1211.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 6\n type: template-transform\n variables:\n - value_selector:\n - '1711512927569'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513024315'\n position:\n x: 1211.5\n y: 791.5\n positionAbsolute:\n x: 1211.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 7\n type: template-transform\n variables:\n - value_selector:\n - '1711512929190'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513025732'\n position:\n x: 1211.5\n y: 982.5\n positionAbsolute:\n x: 1211.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 8\n type: template-transform\n variables:\n - value_selector:\n - '1711512930700'\n - text\n variable: arg1\n dragging: false\n height: 53\n id: '1711513027347'\n position:\n x: 1211.5\n y: 1173.5\n positionAbsolute:\n x: 
1211.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512911454'\n - text\n variable: text\n selected: false\n title: End\n type: end\n dragging: false\n height: 89\n id: '1711513029058'\n position:\n x: 1211.5\n y: 1327.5\n positionAbsolute:\n x: 1211.5\n y: 1327.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512914870'\n - text\n variable: text\n selected: false\n title: End 2\n type: end\n dragging: false\n height: 89\n id: '1711513030924'\n position:\n x: 1211.5\n y: 1518.5\n positionAbsolute:\n x: 1211.5\n y: 1518.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512916516'\n - text\n variable: text\n selected: false\n title: End 3\n type: end\n dragging: false\n height: 89\n id: '1711513032459'\n position:\n x: 1211.5\n y: 1709.5\n positionAbsolute:\n x: 1211.5\n y: 1709.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512924231'\n - text\n variable: text\n selected: false\n title: End 4\n type: end\n dragging: false\n height: 89\n id: '1711513034850'\n position:\n x: 1494.5\n y: 409.5\n positionAbsolute:\n x: 1494.5\n y: 409.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512926020'\n - text\n variable: text\n selected: false\n title: End 5\n type: end\n dragging: false\n height: 89\n id: '1711513036356'\n position:\n x: 1494.5\n y: 600.5\n positionAbsolute:\n x: 1494.5\n y: 600.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512927569'\n - text\n variable: text\n selected: false\n title: End 6\n type: end\n dragging: false\n height: 89\n id: '1711513037973'\n position:\n x: 1494.5\n y: 791.5\n positionAbsolute:\n x: 1494.5\n y: 791.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512929190'\n - text\n variable: text\n selected: false\n title: End 7\n type: end\n dragging: false\n height: 89\n id: '1711513039350'\n position:\n x: 1494.5\n y: 982.5\n positionAbsolute:\n x: 1494.5\n y: 982.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711512930700'\n - text\n variable: text\n selected: false\n title: End 8\n type: end\n dragging: false\n height: 89\n id: '1711513041219'\n position:\n x: 1494.5\n y: 1173.5\n positionAbsolute:\n x: 1494.5\n y: 1173.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n context:\n enabled: false\n variable_selector: []\n desc: Grammar Check\n model:\n completion_params:\n frequency_penalty: 0\n max_tokens: 512\n presence_penalty: 0\n temperature: 0.7\n top_p: 1\n mode: chat\n name: gpt-3.5-turbo\n provider: openai\n prompt_template:\n - role: system\n text: 'Please check the grammar of my email and comment on the grammar. 
{{#1711511281652.Input_Text#}}\n\n '\n selected: false\n title: LLM\n type: llm\n variables: []\n vision:\n enabled: false\n dragging: false\n height: 127\n id: '1711513940609'\n position:\n x: 645.5\n y: 1900.5\n positionAbsolute:\n x: 645.5\n y: 1900.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n selected: false\n template: '{{ arg1 }}'\n title: Template 9\n type: template-transform\n variables:\n - value_selector:\n - '1711513940609'\n - text\n variable: arg1\n height: 53\n id: '1711513967853'\n position:\n x: 928.5\n y: 1900.5\n positionAbsolute:\n x: 928.5\n y: 1900.5\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n - data:\n desc: ''\n outputs:\n - value_selector:\n - '1711513940609'\n - text\n variable: text\n selected: false\n title: End 9\n type: end\n height: 89\n id: '1711513974643'\n position:\n x: 1211.5\n y: 1900.5\n positionAbsolute:\n x: 1211.5\n y: 1900.5\n selected: false\n sourcePosition: right\n targetPosition: left\n type: custom\n width: 243\n viewport:\n x: 0\n y: 0\n zoom: 0.7\n", "icon": "🤖", "icon_background": "#FFEAD5", "id": "dd5b6353-ae9b-4bce-be6a-a681a12cf709", diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 623a1a28eb731e..85380b73304043 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -1,7 +1,9 @@ from contextvars import ContextVar +from typing import TYPE_CHECKING -from core.workflow.entities.variable_pool import VariablePool +if TYPE_CHECKING: + from core.workflow.entities.variable_pool import VariablePool tenant_id: ContextVar[str] = ContextVar("tenant_id") -workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool") +workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py index b28b04f643122b..8b137891791fe9 100644 --- a/api/controllers/__init__.py +++ b/api/controllers/__init__.py @@ -1,3 +1 @@ - - diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py new file mode 100644 index 00000000000000..c71f1ce5a31027 --- /dev/null +++ b/api/controllers/common/errors.py @@ -0,0 +1,6 @@ +from werkzeug.exceptions import HTTPException + + +class FilenameNotExistsError(HTTPException): + code = 400 + description = "The specified filename does not exist." 
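The contexts/__init__.py hunk above moves the VariablePool import behind typing.TYPE_CHECKING and quotes the ContextVar annotation, so the heavy workflow module is only evaluated by type checkers, never at import time. For readers unfamiliar with the idiom, a minimal self-contained sketch; the mypkg.pool module name is a placeholder, not part of this diff:

    from contextvars import ContextVar
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by type checkers only; at runtime this import (and any
        # import cycle or startup cost it would bring) never happens.
        from mypkg.pool import VariablePool  # placeholder module

    # The quoted annotation is a forward reference that mypy/pyright resolve
    # against the TYPE_CHECKING import above.
    workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")

    # At runtime the ContextVar itself is untyped and behaves as usual:
    token = workflow_variable_pool.set(object())
    assert workflow_variable_pool.get() is not None
    workflow_variable_pool.reset(token)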
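Likewise, the new FilenameNotExistsError is a plain werkzeug HTTPException subclass, so raising it from any Flask view should yield a 400 response carrying the class description, with no custom error handler registered. A rough usage sketch under that assumption; the /upload route and the empty-filename check are illustrative, not taken from this diff:

    from flask import Flask
    from werkzeug.exceptions import HTTPException

    class FilenameNotExistsError(HTTPException):
        # Mirrors the new api/controllers/common/errors.py
        code = 400
        description = "The specified filename does not exist."

    app = Flask(__name__)

    @app.post("/upload")  # hypothetical route for illustration
    def upload():
        filename = ""  # stand-in for e.g. request.files["file"].filename
        if not filename:
            # Flask turns HTTPException subclasses into HTTP responses,
            # using `code` for the status and `description` for the body.
            raise FilenameNotExistsError
        return {"result": "success"}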
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py new file mode 100644 index 00000000000000..79869916eda062 --- /dev/null +++ b/api/controllers/common/fields.py @@ -0,0 +1,24 @@ +from flask_restful import fields + +parameters__system_parameters = { + "image_file_size_limit": fields.Integer, + "video_file_size_limit": fields.Integer, + "audio_file_size_limit": fields.Integer, + "file_size_limit": fields.Integer, + "workflow_file_upload_limit": fields.Integer, +} + +parameters_fields = { + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(parameters__system_parameters), +} diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py new file mode 100644 index 00000000000000..2bae2037126835 --- /dev/null +++ b/api/controllers/common/helpers.py @@ -0,0 +1,97 @@ +import mimetypes +import os +import re +import urllib.parse +from collections.abc import Mapping +from typing import Any +from uuid import uuid4 + +import httpx +from pydantic import BaseModel + +from configs import dify_config + + +class FileInfo(BaseModel): + filename: str + extension: str + mimetype: str + size: int + + +def guess_file_info_from_response(response: httpx.Response): + url = str(response.url) + # Try to extract filename from URL + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + + # If filename couldn't be extracted, use Content-Disposition header + if not filename: + content_disposition = response.headers.get("Content-Disposition") + if content_disposition: + filename_match = re.search(r'filename="?(.+)"?', content_disposition) + if filename_match: + filename = filename_match.group(1) + + # If still no filename, generate a unique one + if not filename: + unique_name = str(uuid4()) + filename = f"{unique_name}" + + # Guess MIME type from filename first, then URL + mimetype, _ = mimetypes.guess_type(filename) + if mimetype is None: + mimetype, _ = mimetypes.guess_type(url) + if mimetype is None: + # If guessing fails, use Content-Type from response headers + mimetype = response.headers.get("Content-Type", "application/octet-stream") + + extension = os.path.splitext(filename)[1] + + # Ensure filename has an extension + if not extension: + extension = mimetypes.guess_extension(mimetype) or ".bin" + filename = f"{filename}{extension}" + + return FileInfo( + filename=filename, + extension=extension, + mimetype=mimetype, + size=int(response.headers.get("Content-Length", -1)), + ) + + +def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]): + return { + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": 
features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": { + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + }, + } diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index b2b9d8d4967927..8a5c2e5b8fad13 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -2,9 +2,21 @@ from libs.external_api import ExternalApi -bp = Blueprint('console', __name__, url_prefix='/console/api') +from .files import FileApi, FilePreviewApi, FileSupportTypeApi +from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi + +bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi(bp) +# File +api.add_resource(FileApi, "/files/upload") +api.add_resource(FilePreviewApi, "/files//preview") +api.add_resource(FileSupportTypeApi, "/files/support-type") + +# Remote files +api.add_resource(RemoteFileInfoApi, "/remote-files/") +api.add_resource(RemoteFileUploadApi, "/remote-files/upload") + # Import other controllers from . import admin, apikey, extension, feature, ping, setup, version @@ -37,7 +49,15 @@ from .billing import billing # Import datasets controllers -from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website +from .datasets import ( + data_source, + datasets, + datasets_document, + datasets_segments, + external, + hit_testing, + website, +) # Import explore controllers from .explore import ( diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 028be5de548b7d..a70c4a31c7db94 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,10 +1,10 @@ -import os from functools import wraps from flask import request from flask_restful import Resource, reqparse from werkzeug.exceptions import NotFound, Unauthorized +from configs import dify_config from constants.languages import supported_language from controllers.console import api from controllers.console.wraps import only_edition_cloud @@ -15,24 +15,24 @@ def admin_required(view): @wraps(view) def decorated(*args, **kwargs): - if not os.getenv('ADMIN_API_KEY'): - raise Unauthorized('API key is invalid.') + if not dify_config.ADMIN_API_KEY: + raise Unauthorized("API key is invalid.") - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - if os.getenv('ADMIN_API_KEY') != auth_token: - raise Unauthorized('API key is invalid.') + if dify_config.ADMIN_API_KEY != auth_token: + raise Unauthorized("API key is invalid.") return view(*args, **kwargs) @@ -44,37 +44,33 @@ class InsertExploreAppListApi(Resource): @admin_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, nullable=False, location='json') - parser.add_argument('desc', type=str, location='json') - parser.add_argument('copyright', type=str, location='json') - parser.add_argument('privacy_policy', type=str, location='json') - parser.add_argument('custom_disclaimer', type=str, location='json') - parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json') - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('position', type=int, required=True, nullable=False, location='json') + parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("desc", type=str, location="json") + parser.add_argument("copyright", type=str, location="json") + parser.add_argument("privacy_policy", type=str, location="json") + parser.add_argument("custom_disclaimer", type=str, location="json") + parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = App.query.filter(App.id == args['app_id']).first() + app = App.query.filter(App.id == args["app_id"]).first() if not app: raise NotFound(f'App \'{args["app_id"]}\' is not found') site = app.site if not site: - desc = args['desc'] if args['desc'] else '' - copy_right = args['copyright'] if args['copyright'] else '' - privacy_policy = args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = args["desc"] or "" + copy_right = args["copyright"] or "" + privacy_policy = args["privacy_policy"] or "" + custom_disclaimer = args["custom_disclaimer"] or "" else: - desc = site.description if site.description else \ - args['desc'] if args['desc'] else '' - copy_right = site.copyright if site.copyright else \ - args['copyright'] if args['copyright'] else '' - privacy_policy = site.privacy_policy if site.privacy_policy else \ - args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \ - args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = site.description or args["desc"] or "" + copy_right = site.copyright or args["copyright"] or "" + privacy_policy = site.privacy_policy or args["privacy_policy"] or "" + custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == 
args["app_id"]).first() if not recommended_app: recommended_app = RecommendedApp( @@ -83,9 +79,9 @@ def post(self): copyright=copy_right, privacy_policy=privacy_policy, custom_disclaimer=custom_disclaimer, - language=args['language'], - category=args['category'], - position=args['position'] + language=args["language"], + category=args["category"], + position=args["position"], ) db.session.add(recommended_app) @@ -93,21 +89,21 @@ def post(self): app.is_public = True db.session.commit() - return {'result': 'success'}, 201 + return {"result": "success"}, 201 else: recommended_app.description = desc recommended_app.copyright = copy_right recommended_app.privacy_policy = privacy_policy recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args['language'] - recommended_app.category = args['category'] - recommended_app.position = args['position'] + recommended_app.language = args["language"] + recommended_app.category = args["category"] + recommended_app.position = args["position"] app.is_public = True db.session.commit() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class InsertExploreAppApi(Resource): @@ -116,15 +112,14 @@ class InsertExploreAppApi(Resource): def delete(self, app_id): recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() if not recommended_app: - return {'result': 'success'}, 204 + return {"result": "success"}, 204 app = App.query.filter(App.id == recommended_app.app_id).first() if app: app.is_public = False installed_apps = InstalledApp.query.filter( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id + InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id ).all() for installed_app in installed_apps: @@ -133,8 +128,8 @@ def delete(self, app_id): db.session.delete(recommended_app) db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps') -api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/') +api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") +api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 324b8311752898..953770868904d3 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -10,30 +10,24 @@ from models.model import ApiToken, App from . 
import api -from .setup import setup_required -from .wraps import account_initialization_required +from .wraps import account_initialization_required, setup_required api_key_fields = { - 'id': fields.String, - 'type': fields.String, - 'token': fields.String, - 'last_used_at': TimestampField, - 'created_at': TimestampField + "id": fields.String, + "type": fields.String, + "token": fields.String, + "last_used_at": TimestampField, + "created_at": TimestampField, } -api_key_list = { - 'data': fields.List(fields.Nested(api_key_fields), attribute="items") -} +api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} def _get_resource(resource_id, tenant_id, resource_model): - resource = resource_model.query.filter_by( - id=resource_id, tenant_id=tenant_id - ).first() + resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() if resource is None: - flask_restful.abort( - 404, message=f"{resource_model.__name__} not found.") + flask_restful.abort(404, message=f"{resource_model.__name__} not found.") return resource @@ -50,30 +44,32 @@ class BaseApiKeyListResource(Resource): @marshal_with(api_key_list) def get(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - all() + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .all() + ) return {"items": keys} @marshal_with(api_key_fields) def post(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) - if not current_user.is_admin_or_owner: + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + if not current_user.is_editor: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -97,79 +93,78 @@ class BaseApiKeyResource(Resource): def delete(self, resource_id, api_key_id): resource_id = str(resource_id) api_key_id = str(api_key_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). 
\ - first() + key = ( + db.session.query(ApiToken) + .filter( + getattr(ApiToken, self.resource_id_field) == resource_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' - token_prefix = 'app-' + resource_id_field = "app_id" + token_prefix = "app-" class AppApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' + resource_id_field = "app_id" class DatasetApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' - token_prefix = 'ds-' + resource_id_field = "dataset_id" + token_prefix = "ds-" class DatasetApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' + resource_id_field = "dataset_id" -api.add_resource(AppApiKeyListResource, '/apps//api-keys') -api.add_resource(AppApiKeyResource, - '/apps//api-keys/') -api.add_resource(DatasetApiKeyListResource, - '/datasets//api-keys') -api.add_resource(DatasetApiKeyResource, - '/datasets//api-keys/') +api.add_resource(AppApiKeyListResource, "/apps//api-keys") +api.add_resource(AppApiKeyResource, "/apps//api-keys/") +api.add_resource(DatasetApiKeyListResource, "/datasets//api-keys") +api.add_resource(DatasetApiKeyResource, "/datasets//api-keys/") diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index fa2b3807e82778..c228743fa53591 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,26 +1,24 @@ from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required 
from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService class AdvancedPromptTemplateList(Resource): - @setup_required @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument('app_mode', type=str, required=True, location='args') - parser.add_argument('model_mode', type=str, required=True, location='args') - parser.add_argument('has_context', type=str, required=False, default='true', location='args') - parser.add_argument('model_name', type=str, required=True, location='args') + parser.add_argument("app_mode", type=str, required=True, location="args") + parser.add_argument("model_mode", type=str, required=True, location="args") + parser.add_argument("has_context", type=str, required=False, default="true", location="args") + parser.add_argument("model_name", type=str, required=True, location="args") args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) -api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') \ No newline at end of file + +api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index aee367276c0777..d4334158945e16 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -2,8 +2,7 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.helper import uuid_value from libs.login import login_required from models.model import AppMode @@ -18,15 +17,12 @@ class AgentLogApi(Resource): def get(self, app_model): """Get agent logs""" parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='args') - parser.add_argument('conversation_id', type=uuid_value, required=True, location='args') + parser.add_argument("message_id", type=uuid_value, required=True, location="args") + parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") args = parser.parse_args() - return AgentService.get_agent_logs( - app_model, - args['conversation_id'], - args['message_id'] - ) - -api.add_resource(AgentLogApi, '/apps//agent/logs') \ No newline at end of file + return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) + + +api.add_resource(AgentLogApi, "/apps//agent/logs") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index bc15919a992cfd..fd05cbc19bf04f 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -6,8 +6,11 @@ from controllers.console import api from controllers.console.app.error import NoFileUploadedError from controllers.console.datasets.error import TooManyFilesError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_redis import redis_client from fields.annotation_fields import ( annotation_fields, @@ -21,23 +24,23 @@ class 
AnnotationReplyActionApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id, action): if not current_user.is_editor: raise Forbidden() app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') - parser.add_argument('embedding_provider_name', required=True, type=str, location='json') - parser.add_argument('embedding_model_name', required=True, type=str, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") + parser.add_argument("embedding_provider_name", required=True, type=str, location="json") + parser.add_argument("embedding_model_name", required=True, type=str, location="json") args = parser.parse_args() - if action == 'enable': + if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_id) - elif action == 'disable': + elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) else: - raise ValueError('Unsupported annotation reply action') + raise ValueError("Unsupported annotation reply action") return result, 200 @@ -66,7 +69,7 @@ def post(self, app_id, annotation_setting_id): annotation_setting_id = str(annotation_setting_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") args = parser.parse_args() result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) @@ -77,28 +80,24 @@ class AnnotationReplyActionStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id, action): if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id)) + app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job does not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id)) + error_msg = "" + if job_status == "error": + app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) error_msg = redis_client.get(app_annotation_error_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationListApi(Resource): @@ -109,18 +108,18 @@ def get(self, app_id): if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - keyword = request.args.get('keyword', default=None, type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + keyword = request.args.get("keyword", default=None, type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) response = { - 'data': marshal(annotation_list, 
annotation_fields), - 'has_more': len(annotation_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_list, annotation_fields), + "has_more": len(annotation_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 @@ -135,9 +134,7 @@ def get(self, app_id): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = { - 'data': marshal(annotation_list, annotation_fields) - } + response = {"data": marshal(annotation_list, annotation_fields)} return response, 200 @@ -145,7 +142,7 @@ class AnnotationCreateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id): if not current_user.is_editor: @@ -153,8 +150,8 @@ def post(self, app_id): app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) return annotation @@ -164,7 +161,7 @@ class AnnotationUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id, annotation_id): if not current_user.is_editor: @@ -173,8 +170,8 @@ def post(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) return annotation @@ -189,29 +186,29 @@ def delete(self, app_id, annotation_id): app_id = str(app_id) annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class AnnotationBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id): if not current_user.is_editor: raise Forbidden() app_id = str(app_id) # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. 
Only CSV files are allowed") return AppAnnotationService.batch_import_app_annotations(app_id, file) @@ -220,27 +217,23 @@ class AnnotationBatchImportStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id): if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job does not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + error_msg = "" + if job_status == "error": + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) error_msg = redis_client.get(indexing_error_msg_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationHitHistoryListApi(Resource): @@ -251,30 +244,32 @@ def get(self, app_id, annotation_id): if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) app_id = str(app_id) annotation_id = str(annotation_id) - annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id, - page, limit) + annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( + app_id, annotation_id, page, limit + ) response = { - 'data': marshal(annotation_hit_history_list, annotation_hit_history_fields), - 'has_more': len(annotation_hit_history_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), + "has_more": len(annotation_hit_history_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response -api.add_resource(AnnotationReplyActionApi, '/apps//annotation-reply/') -api.add_resource(AnnotationReplyActionStatusApi, - '/apps//annotation-reply//status/') -api.add_resource(AnnotationListApi, '/apps//annotations') -api.add_resource(AnnotationExportApi, '/apps//annotations/export') -api.add_resource(AnnotationUpdateDeleteApi, '/apps//annotations/') -api.add_resource(AnnotationBatchImportApi, '/apps//annotations/batch-import') -api.add_resource(AnnotationBatchImportStatusApi, '/apps//annotations/batch-import-status/') -api.add_resource(AnnotationHitHistoryListApi, '/apps//annotations//hit-histories') -api.add_resource(AppAnnotationSettingDetailApi, '/apps//annotation-setting') -api.add_resource(AppAnnotationSettingUpdateApi, '/apps//annotation-settings/') +api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply/") +api.add_resource( + AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" +) +api.add_resource(AnnotationListApi, "/apps//annotations") +api.add_resource(AnnotationExportApi, "/apps//annotations/export") +api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") +api.add_resource(AnnotationBatchImportApi, 
"/apps//annotations/batch-import") +api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/") +api.add_resource(AnnotationHitHistoryListApi, "/apps//annotations//hit-histories") +api.add_resource(AppAnnotationSettingDetailApi, "/apps//annotation-setting") +api.add_resource(AppAnnotationSettingUpdateApi, "/apps//annotation-settings/") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2f304b970c6050..36338cbd8a4c59 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -6,8 +6,11 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.ops.ops_trace_manager import OpsTraceManager from fields.app_fields import ( app_detail_fields, @@ -18,27 +21,35 @@ from services.app_dsl_service import AppDslService from services.app_service import AppService -ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion'] +ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] class AppListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): """Get app list""" + def uuid_list(value): try: - return [str(uuid.UUID(v)) for v in value.split(',')] + return [str(uuid.UUID(v)) for v in value.split(",")] except ValueError: abort(400, message="Invalid UUID format in tag_ids.") + parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False) - parser.add_argument('name', type=str, location='args', required=False) - parser.add_argument('tag_ids', type=uuid_list, location='args', required=False) + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "mode", + type=str, + choices=["chat", "workflow", "agent-chat", "channel", "all"], + default="all", + location="args", + required=False, + ) + parser.add_argument("name", type=str, location="args", required=False) + parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) args = parser.parse_args() @@ -46,7 +57,7 @@ def uuid_list(value): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) if not app_pagination: - return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False} + return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return marshal(app_pagination, app_pagination_fields) @@ -54,22 +65,23 @@ def uuid_list(value): @login_required @account_initialization_required @marshal_with(app_detail_fields) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Create app""" parser = reqparse.RequestParser() - parser.add_argument('name', 
type=str, required=True, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if 'mode' not in args or args['mode'] is None: + if "mode" not in args or args["mode"] is None: raise BadRequest("mode is required") app_service = AppService() @@ -83,7 +95,7 @@ class AppImportApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -91,18 +103,16 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=args['data'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user ) return app, 201 @@ -113,7 +123,7 @@ class AppImportFromUrlApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app from url""" # The role of the current user in the ta table must be admin, owner, or editor @@ -121,25 +131,21 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("url", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon", type=str, 
location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app_from_url( - tenant_id=current_user.current_tenant_id, - url=args['url'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user ) return app, 201 class AppApi(Resource): - @setup_required @login_required @account_initialization_required @@ -163,13 +169,15 @@ def put(self, app_model): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') - parser.add_argument('max_active_requests', type=int, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") + parser.add_argument("max_active_requests", type=int, location="json") + parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") args = parser.parse_args() app_service = AppService() @@ -190,7 +198,7 @@ def delete(self, app_model): app_service = AppService() app_service.delete_app(app_model) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppCopyApi(Resource): @@ -206,18 +214,16 @@ def post(self, app_model): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() data = AppDslService.export_dsl(app_model=app_model, include_secret=True) app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=data, - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user ) return app, 201 @@ -236,12 +242,10 @@ def get(self, app_model): # Add include_secret params parser = reqparse.RequestParser() - parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args') + parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") args = parser.parse_args() - return { - "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret']) - } + return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} class AppNameApi(Resource): @@ -254,13 +258,13 @@ def post(self, app_model): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = 
reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get('name')) + app_model = app_service.update_app_name(app_model, args.get("name")) return app_model @@ -275,14 +279,14 @@ def post(self, app_model): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background')) + app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) return app_model @@ -297,13 +301,13 @@ def post(self, app_model): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_site', type=bool, required=True, location='json') + parser.add_argument("enable_site", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) + app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) return app_model @@ -318,13 +322,13 @@ def post(self, app_model): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_api', type=bool, required=True, location='json') + parser.add_argument("enable_api", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) + app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) return app_model @@ -335,9 +339,7 @@ class AppTraceApi(Resource): @account_initialization_required def get(self, app_id): """Get app trace""" - app_trace_config = OpsTraceManager.get_app_tracing_config( - app_id=app_id - ) + app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id) return app_trace_config @@ -349,27 +351,27 @@ def post(self, app_id): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('enabled', type=bool, required=True, location='json') - parser.add_argument('tracing_provider', type=str, required=True, location='json') + parser.add_argument("enabled", type=bool, required=True, location="json") + parser.add_argument("tracing_provider", type=str, required=True, location="json") args = parser.parse_args() OpsTraceManager.update_app_tracing_config( app_id=app_id, - enabled=args['enabled'], - tracing_provider=args['tracing_provider'], + enabled=args["enabled"], + tracing_provider=args["tracing_provider"], ) return {"result": "success"} -api.add_resource(AppListApi, '/apps') -api.add_resource(AppImportApi, '/apps/import') -api.add_resource(AppImportFromUrlApi, '/apps/import/url') -api.add_resource(AppApi, '/apps/') 
-api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
-api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
-api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
-api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
-api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
-api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
-api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
+api.add_resource(AppListApi, "/apps")
+api.add_resource(AppImportApi, "/apps/import")
+api.add_resource(AppImportFromUrlApi, "/apps/import/url")
+api.add_resource(AppApi, "/apps/<uuid:app_id>")
+api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
+api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
+api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name")
+api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon")
+api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable")
+api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable")
+api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace")
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 1de08afa4e08b9..112446613feaac 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -18,8 +18,7 @@
     UnsupportedAudioTypeError,
 )
 from controllers.console.app.wraps import get_app_model
-from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, setup_required
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
@@ -39,7 +38,7 @@ class ChatMessageAudioApi(Resource):
     @account_initialization_required
     @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
     def post(self, app_model):
-        file = request.files['file']
+        file = request.files["file"]

         try:
             response = AudioService.transcript_asr(
@@ -85,31 +84,27 @@ def post(self, app_model):
         try:
             parser = reqparse.RequestParser()
-            parser.add_argument('message_id', type=str, location='json')
-            parser.add_argument('text', type=str, location='json')
-            parser.add_argument('voice', type=str, location='json')
-            parser.add_argument('streaming', type=bool, location='json')
+            parser.add_argument("message_id", type=str, location="json")
+            parser.add_argument("text", type=str, location="json")
+            parser.add_argument("voice", type=str, location="json")
+            parser.add_argument("streaming", type=bool, location="json")
             args = parser.parse_args()

-            message_id = args.get('message_id', None)
-            text = args.get('text', None)
-            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
-                    and app_model.workflow
-                    and app_model.workflow.features_dict):
-                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
-                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            message_id = args.get("message_id", None)
+            text = args.get("text", None)
+            if (
+                app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
+                and app_model.workflow
+                and app_model.workflow.features_dict
+            ):
+                text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                voice = args.get("voice") or text_to_speech.get("voice")
             else:
                 try:
-                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
-                        'voice')
+                    voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
                 except Exception:
                     voice = None
-            response = AudioService.transcript_tts(
-                app_model=app_model,
-                text=text,
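
The registrations above rely on Werkzeug's built-in `uuid` route converter: the path segment is parsed into a `uuid.UUID` before the resource method runs. A self-contained sketch of that mechanism; DemoAppApi is illustrative, and in the real controllers the `get_app_model` decorator exchanges the raw `app_id` for an `app_model`:

import uuid

from flask import Flask
from flask_restful import Api, Resource

app = Flask(__name__)
api = Api(app)

class DemoAppApi(Resource):
    def get(self, app_id: uuid.UUID):
        # Werkzeug's uuid converter has already parsed the path segment.
        return {"app_id": str(app_id)}

api.add_resource(DemoAppApi, "/apps/<uuid:app_id>")

client = app.test_client()
print(client.get(f"/apps/{uuid.uuid4()}").get_json())
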
- message_id=message_id, - voice=voice - ) + response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -145,12 +140,12 @@ class TextModesApi(Resource): def get(self, app_model): try: parser = reqparse.RequestParser() - parser.add_argument('language', type=str, required=True, location='args') + parser.add_argument("language", type=str, required=True, location="args") args = parser.parse_args() response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, - language=args['language'], + language=args["language"], ) return response @@ -179,6 +174,6 @@ def get(self, app_model): raise InternalServerError() -api.add_resource(ChatMessageAudioApi, '/apps//audio-to-text') -api.add_resource(ChatMessageTextApi, '/apps//text-to-audio') -api.add_resource(TextModesApi, '/apps//text-to-audio/voices') +api.add_resource(ChatMessageAudioApi, "/apps//audio-to-text") +api.add_resource(ChatMessageTextApi, "/apps//text-to-audio") +api.add_resource(TextModesApi, "/apps//text-to-audio/voices") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 61582536fdbe1d..9896fcaab8ad36 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -15,8 +15,8 @@ ProviderQuotaExceededError, ) from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -31,37 +31,33 @@ from libs.login import login_required from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.llm import InvokeRateLimitError # define completion message api for user class CompletionMessageApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + 
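
One behavioral note on the voice fallback rewritten above: replacing `x if x else y` with `x or y` is equivalent here because both a missing value (None) and an empty string are falsy. In isolation, with made-up voice names:

def pick_voice(requested, default="default-voice"):
    # None and "" are both falsy, so `or` covers "not sent" and "sent empty".
    return requested or default

print(pick_voice(None), pick_voice(""), pick_voice("echo"))
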
streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -97,7 +93,7 @@ def post(self, app_model, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatMessageApi(Resource): @@ -107,27 +103,24 @@ class ChatMessageApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -144,6 +137,8 @@ def post(self, app_model): raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) except InvokeError as e: raise CompletionRequestError(e.description) except (ValueError, AppInvokeQuotaExceededError) as e: @@ -163,10 +158,10 @@ def post(self, app_model, task_id): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionMessageApi, '/apps//completion-messages') -api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') -api.add_resource(ChatMessageApi, '/apps//chat-messages') -api.add_resource(ChatMessageStopApi, '/apps//chat-messages//stop') +api.add_resource(CompletionMessageApi, "/apps//completion-messages") 
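
In the ChatMessageApi parser above, `conversation_id` and the newly added `parent_message_id` are validated with `uuid_value` from libs.helper. A stand-in with the same contract, assuming only the standard library (the real helper may differ in details; reqparse reports the ValueError as an HTTP 400):

import uuid

def uuid_value(value):
    # Return the canonical string form, raise ValueError otherwise.
    try:
        return str(uuid.UUID(value))
    except (AttributeError, TypeError, ValueError):
        raise ValueError(f"{value!r} is not a valid uuid")

print(uuid_value("123e4567-e89b-12d3-a456-426614174000"))
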
+api.add_resource(CompletionMessageStopApi, "/apps//completion-messages//stop") +api.add_resource(ChatMessageApi, "/apps//chat-messages") +api.add_resource(ChatMessageStopApi, "/apps//chat-messages//stop") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index eb61c83d4626f0..7b78f622b9a72a 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -10,8 +10,7 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( @@ -20,13 +19,13 @@ conversation_pagination_fields, conversation_with_summary_pagination_fields, ) -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required -from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation +from models import Conversation, EndUser, Message, MessageAnnotation +from models.model import AppMode class CompletionConversationApi(Resource): - @setup_required @login_required @account_initialization_required @@ -36,24 +35,23 @@ def get(self, app_model): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") - if args['keyword']: - query = query.join( - Message, Message.conversation_id == Conversation.id - ).filter( + if args["keyword"]: + query = query.join(Message, Message.conversation_id == Conversation.id).filter( or_( - Message.query.ilike('%{}%'.format(args['keyword'])), - Message.answer.ilike('%{}%'.format(args['keyword'])) + Message.query.ilike("%{}%".format(args["keyword"])), + Message.answer.ilike("%{}%".format(args["keyword"])), ) ) @@ -61,8 +59,8 @@ def get(self, app_model): timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + 
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -70,8 +68,8 @@ def get(self, app_model): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -79,29 +77,25 @@ def get(self, app_model): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class CompletionConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @@ -123,8 +117,11 @@ def delete(self, app_model, conversation_id): raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -132,11 +129,10 @@ def delete(self, app_model, conversation_id): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ChatConversationApi(Resource): - @setup_required @login_required @account_initialization_required @@ -146,20 +142,28 @@ def get(self, app_model): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') - parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", 
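
The start/end filters in these conversation listings do two things: validate the raw string with a DatetimeString reqparse type, then interpret the naive value in the account's timezone and convert it to UTC before comparing against stored timestamps. A compact stand-in for both steps; the DatetimeString class here only mimics the libs.helper one, and the timezone is illustrative:

from datetime import datetime

import pytz

class DatetimeString:
    def __init__(self, fmt):
        self.fmt = fmt

    def __call__(self, value):
        # reqparse turns the ValueError from strptime into a 400 response.
        datetime.strptime(value, self.fmt)
        return value

validate = DatetimeString("%Y-%m-%d %H:%M")
raw = validate("2024-06-01 09:30")

tz = pytz.timezone("Asia/Shanghai")
start = datetime.strptime(raw, "%Y-%m-%d %H:%M").replace(second=0)
print(tz.localize(start).astimezone(pytz.utc))  # 2024-06-01 01:30:00+00:00
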
type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") + parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() subquery = ( db.session.query( - Conversation.id.label('conversation_id'), - EndUser.session_id.label('from_end_user_session_id') + Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") ) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() @@ -167,78 +171,96 @@ def get(self, app_model): query = db.select(Conversation).where(Conversation.app_id == app_model.id) - if args['keyword']: - keyword_filter = '%{}%'.format(args['keyword']) - query = query.join( - Message, Message.conversation_id == Conversation.id, - ).join( - subquery, subquery.c.conversation_id == Conversation.id - ).filter( - or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter) - ), + if args["keyword"]: + keyword_filter = "%{}%".format(args["keyword"]) + query = ( + query.join( + Message, + Message.conversation_id == Conversation.id, + ) + .join(subquery, subquery.c.conversation_id == Conversation.id) + .filter( + or_( + Message.query.ilike(keyword_filter), + Message.answer.ilike(keyword_filter), + Conversation.name.ilike(keyword_filter), + Conversation.introduction.ilike(keyword_filter), + subquery.c.from_end_user_session_id.ilike(keyword_filter), + ), + ) + .group_by(Conversation.id) ) account = current_user timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - query = query.where(Conversation.created_at >= start_datetime_utc) + match args["sort_by"]: + case "updated_at" | "-updated_at": + query = query.where(Conversation.updated_at >= start_datetime_utc) + case "created_at" | "-created_at" | _: + query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - query = query.where(Conversation.created_at < end_datetime_utc) + match args["sort_by"]: + case "updated_at" | "-updated_at": + query = query.where(Conversation.updated_at <= end_datetime_utc) + case "created_at" | "-created_at" | _: + query = query.where(Conversation.created_at <= end_datetime_utc) - if 
args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) - if args['message_count_gte'] and args['message_count_gte'] >= 1: + if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( query.options(joinedload(Conversation.messages)) .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) - .having(func.count(Message.id) >= args['message_count_gte']) + .having(func.count(Message.id) >= args["message_count_gte"]) ) if app_model.mode == AppMode.ADVANCED_CHAT.value: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) - query = query.order_by(Conversation.created_at.desc()) + match args["sort_by"]: + case "created_at": + query = query.order_by(Conversation.created_at.asc()) + case "-created_at": + query = query.order_by(Conversation.created_at.desc()) + case "updated_at": + query = query.order_by(Conversation.updated_at.asc()) + case "-updated_at": + query = query.order_by(Conversation.updated_at.desc()) + case _: + query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class ChatConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @@ -260,8 +282,11 @@ def delete(self, app_model, conversation_id): raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -269,18 +294,21 @@ def delete(self, app_model, conversation_id): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(CompletionConversationApi, '/apps//completion-conversations') -api.add_resource(CompletionConversationDetailApi, '/apps//completion-conversations/') -api.add_resource(ChatConversationApi, '/apps//chat-conversations') -api.add_resource(ChatConversationDetailApi, '/apps//chat-conversations/') +api.add_resource(CompletionConversationApi, "/apps//completion-conversations") +api.add_resource(CompletionConversationDetailApi, "/apps//completion-conversations/") +api.add_resource(ChatConversationApi, "/apps//chat-conversations") +api.add_resource(ChatConversationDetailApi, "/apps//chat-conversations/") def _get_conversation(app_model, conversation_id): - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + 
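
The sort_by dispatch in this hunk follows the common "field" / "-field" convention for ascending/descending order, with -updated_at as the default and descending created_at as the fallback for unrecognized values. The same convention on plain dicts, as a runnable illustration rather than the SQLAlchemy code itself:

def sort_conversations(rows, sort_by="-updated_at"):
    match sort_by:
        case "created_at":
            return sorted(rows, key=lambda r: r["created_at"])
        case "-created_at":
            return sorted(rows, key=lambda r: r["created_at"], reverse=True)
        case "updated_at":
            return sorted(rows, key=lambda r: r["updated_at"])
        case "-updated_at":
            return sorted(rows, key=lambda r: r["updated_at"], reverse=True)
        case _:
            # Mirrors the controller's fallback ordering.
            return sorted(rows, key=lambda r: r["created_at"], reverse=True)

rows = [{"id": 1, "updated_at": 5}, {"id": 2, "updated_at": 9}]
print([r["id"] for r in sort_conversations(rows)])  # [2, 1]
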
db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index aa0722ea355ca2..d49f433ba1f575 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -4,8 +4,7 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.conversation_variable_fields import paginated_conversation_variable_fields from libs.login import login_required @@ -21,7 +20,7 @@ class ConversationVariablesApi(Resource): @marshal_with(paginated_conversation_variable_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', type=str, location='args') + parser.add_argument("conversation_id", type=str, location="args") args = parser.parse_args() stmt = ( @@ -29,10 +28,10 @@ def get(self, app_model): .where(ConversationVariable.app_id == app_model.id) .order_by(ConversationVariable.created_at) ) - if args['conversation_id']: - stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id']) + if args["conversation_id"]: + stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) else: - raise ValueError('conversation_id is required') + raise ValueError("conversation_id is required") # NOTE: This is a temporary solution to avoid performance issues. page = 1 @@ -43,14 +42,14 @@ def get(self, app_model): rows = session.scalars(stmt).all() return { - 'page': page, - 'limit': page_size, - 'total': len(rows), - 'has_more': False, - 'data': [ + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ { - 'created_at': row.created_at, - 'updated_at': row.updated_at, + "created_at": row.created_at, + "updated_at": row.updated_at, **row.to_variable().model_dump(), } for row in rows @@ -58,4 +57,4 @@ def get(self, app_model): } -api.add_resource(ConversationVariablesApi, '/apps//conversation-variables') +api.add_resource(ConversationVariablesApi, "/apps//conversation-variables") diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index f6feed12217a85..1559f82d6ea142 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -2,116 +2,128 @@ class AppNotFoundError(BaseHTTPException): - error_code = 'app_not_found' + error_code = "app_not_found" description = "App not found." code = 404 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted Model Provider has been exhausted. 
" \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class DraftWorkflowNotExist(BaseHTTPException): - error_code = 'draft_workflow_not_exist' + error_code = "draft_workflow_not_exist" description = "Draft workflow need to be initialized." code = 400 class DraftWorkflowNotSync(BaseHTTPException): - error_code = 'draft_workflow_not_sync' + error_code = "draft_workflow_not_sync" description = "Workflow graph might have been modified, please refresh and resubmit." code = 400 class TracingConfigNotExist(BaseHTTPException): - error_code = 'trace_config_not_exist' + error_code = "trace_config_not_exist" description = "Trace config not exist." code = 400 class TracingConfigIsExist(BaseHTTPException): - error_code = 'trace_config_is_exist' + error_code = "trace_config_is_exist" description = "Trace config is exist." code = 400 class TracingConfigCheckError(BaseHTTPException): - error_code = 'trace_config_check_error' + error_code = "trace_config_check_error" description = "Invalid Credentials." 
     code = 400
+
+
+class InvokeRateLimitError(BaseHTTPException):
+    """Raised when the Invoke returns rate limit error."""
+
+    error_code = "rate_limit_error"
+    description = "Rate Limit Error"
+    code = 429
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 6803775e20dfb8..9c3cbe4e3e049e 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -10,8 +10,7 @@
     ProviderNotInitializeError,
     ProviderQuotaExceededError,
 )
-from controllers.console.setup import setup_required
-from controllers.console.wraps import account_initialization_required
+from controllers.console.wraps import account_initialization_required, setup_required
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.llm_generator.llm_generator import LLMGenerator
 from core.model_runtime.errors.invoke import InvokeError
@@ -24,21 +23,21 @@ class RuleGenerateApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
-        parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
         args = parser.parse_args()

         account = current_user
-        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512'))
+        PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
         try:
             rules = LLMGenerator.generate_rule_config(
                 tenant_id=account.current_tenant_id,
-                instruction=args['instruction'],
-                model_config=args['model_config'],
-                no_variable=args['no_variable'],
-                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+                no_variable=args["no_variable"],
+                rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -52,4 +51,39 @@ def post(self):
         return rules


-api.add_resource(RuleGenerateApi, '/rule-generate')
+class RuleCodeGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+        parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
+        parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
+        args = parser.parse_args()
+
+        account = current_user
+        CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
+        try:
+            code_result = LLMGenerator.generate_code(
+                tenant_id=account.current_tenant_id,
+                instruction=args["instruction"],
+                model_config=args["model_config"],
+                code_language=args["code_language"],
+                max_tokens=CODE_GENERATION_MAX_TOKENS,
+            )
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise
ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + + return code_result + + +api.add_resource(RuleGenerateApi, "/rule-generate") +api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 056415f19a28c5..b7a4c31a156b80 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -14,8 +14,11 @@ ) from controllers.console.app.wraps import get_app_model from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError @@ -33,9 +36,9 @@ class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_detail_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_detail_fields)), } @setup_required @@ -45,55 +48,69 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - conversation = db.session.query(Conversation).filter( - Conversation.id == args['conversation_id'], - Conversation.app_id == app_model.id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") - if args['first_id']: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first() + if args["first_id"]: + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) + .first() + ) if not first_message: raise NotFound("First message not found") - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + 
.order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) has_more = False - if len(history_messages) == args['limit']: + if len(history_messages) == args["limit"]: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=args['limit'], - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) class MessageFeedbackApi(Resource): @@ -103,49 +120,46 @@ class MessageFeedbackApi(Resource): @get_app_model def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('message_id', required=True, type=uuid_value, location='json') - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("message_id", required=True, type=uuid_value, location="json") + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - message_id = str(args['message_id']) + message_id = str(args["message_id"]) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") feedback = message.admin_feedback - if not args['rating'] and feedback: + if not args["rating"] and feedback: db.session.delete(feedback) - elif args['rating'] and feedback: - feedback.rating = args['rating'] - elif not args['rating'] and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + elif args["rating"] and feedback: + feedback.rating = args["rating"] + elif not args["rating"] and not feedback: + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=args['rating'], - from_source='admin', - from_account_id=current_user.id + rating=args["rating"], + from_source="admin", + from_account_id=current_user.id, ) db.session.add(feedback) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class MessageAnnotationApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @get_app_model @marshal_with(annotation_fields) def post(self, app_model): @@ -153,10 +167,10 @@ def post(self, app_model): raise Forbidden() parser = reqparse.RequestParser() - 
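
The first_id pagination in ChatMessageListApi above is keyset-style: instead of OFFSET, it filters on created_at older than the anchor message, takes the newest `limit` rows, and reverses them so each page reads oldest-first. The core idea on a plain list, as a sketch:

def page_before(messages, first_id=None, limit=20):
    # messages: list of (id, created_at) tuples, ascending by created_at
    if first_id is not None:
        anchor = next(m for m in messages if m[0] == first_id)
        older = [m for m in messages if m[1] < anchor[1]]
    else:
        older = list(messages)
    page = older[-limit:]          # the newest `limit` of the older rows...
    has_more = len(older) > limit  # ...and whether anything remains beyond them
    return page, has_more          # page is already oldest-first

msgs = [(i, i) for i in range(1, 8)]
print(page_before(msgs, first_id=6, limit=3))  # ([(3, 3), (4, 4), (5, 5)], True)
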
parser.add_argument('message_id', required=False, type=uuid_value, location='json') - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') - parser.add_argument('annotation_reply', required=False, type=dict, location='json') + parser.add_argument("message_id", required=False, type=uuid_value, location="json") + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") + parser.add_argument("annotation_reply", required=False, type=dict, location="json") args = parser.parse_args() annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) @@ -169,11 +183,9 @@ class MessageAnnotationCountApi(Resource): @account_initialization_required @get_app_model def get(self, app_model): - count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app_model.id - ).count() + count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() - return {'count': count} + return {"count": count} class MessageSuggestedQuestionApi(Resource): @@ -186,10 +198,7 @@ def get(self, app_model, message_id): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - message_id=message_id, - user=current_user, - invoke_from=InvokeFrom.DEBUGGER + app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER ) except MessageNotExistsError: raise NotFound("Message not found") @@ -209,7 +218,7 @@ def get(self, app_model, message_id): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} class MessageApi(Resource): @@ -221,10 +230,7 @@ class MessageApi(Resource): def get(self, app_model, message_id): message_id = str(message_id) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") @@ -232,9 +238,9 @@ def get(self, app_model, message_id): return message -api.add_resource(MessageSuggestedQuestionApi, '/apps//chat-messages//suggested-questions') -api.add_resource(ChatMessageListApi, '/apps//chat-messages', endpoint='console_chat_messages') -api.add_resource(MessageFeedbackApi, '/apps//feedbacks') -api.add_resource(MessageAnnotationApi, '/apps//annotations') -api.add_resource(MessageAnnotationCountApi, '/apps//annotations/count') -api.add_resource(MessageApi, '/apps//messages/', endpoint='console_message') +api.add_resource(MessageSuggestedQuestionApi, "/apps//chat-messages//suggested-questions") +api.add_resource(ChatMessageListApi, "/apps//chat-messages", endpoint="console_chat_messages") +api.add_resource(MessageFeedbackApi, "/apps//feedbacks") +api.add_resource(MessageAnnotationApi, "/apps//annotations") +api.add_resource(MessageAnnotationCountApi, "/apps//annotations/count") +api.add_resource(MessageApi, "/apps//messages/", endpoint="console_message") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index c8df879a29ca42..8ba195f5a51053 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -6,8 +6,7 @@ from controllers.console import api from controllers.console.app.wraps import 
get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager @@ -19,37 +18,35 @@ class ModelConfigResource(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): - """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, - config=request.json, - app_mode=AppMode.value_of(app_model.mode) + tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) ) new_app_model_config = AppModelConfig( app_id=app_model.id, + created_by=current_user.id, + updated_by=current_user.id, ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app_model.app_model_config_id - ).first() + original_app_model_config: AppModelConfig = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} masked_parameter_map = {} tool_map = {} - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue @@ -66,7 +63,7 @@ def post(self, app_model): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) except Exception as e: continue @@ -79,18 +76,18 @@ def post(self, app_model): parameters = {} masked_parameter = {} - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" masked_parameter_map[key] = masked_parameter parameter_map[key] = parameters tool_map[key] = tool_runtime # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: agent_tool_entity = AgentToolEntity(**tool) # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" if key in tool_map: tool_runtime = tool_map[key] else: @@ -108,7 +105,7 @@ def post(self, app_model): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) manager.delete_tool_parameters_cache() @@ -116,15 +113,17 @@ def post(self, app_model): if agent_tool_entity.tool_parameters: if key not in masked_parameter_map: continue - + for masked_key, masked_value in 
masked_parameter_map[key].items(): - if masked_key in agent_tool_entity.tool_parameters and \ - agent_tool_entity.tool_parameters[masked_key] == masked_value: + if ( + masked_key in agent_tool_entity.tool_parameters + and agent_tool_entity.tool_parameters[masked_key] == masked_value + ): agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) # encrypt parameters if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) # update app model config new_app_model_config.agent_mode = json.dumps(agent_mode) @@ -135,12 +134,9 @@ def post(self, app_model): app_model.app_model_config_id = new_app_model_config.id db.session.commit() - app_model_config_was_updated.send( - app_model, - app_model_config=new_app_model_config - ) + app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ModelConfigResource, '/apps//model-config') +api.add_resource(ModelConfigResource, "/apps//model-config") diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index c0cf7b9e33f32b..47b58396a1a303 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -2,8 +2,7 @@ from controllers.console import api from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService @@ -18,13 +17,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def get(self, app_id): parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - trace_config = OpsService.get_tracing_app_config( - app_id=app_id, tracing_provider=args['tracing_provider'] - ) + trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not trace_config: return {"has_not_configured": True} return trace_config @@ -37,19 +34,17 @@ def get(self, app_id): def post(self, app_id): """Create a new trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.create_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigIsExist() - if result.get('error'): + if result.get("error"): raise TracingConfigCheckError() return result except Exception as e: @@ -61,15 +56,13 @@ def post(self, app_id): def patch(self, app_id): """Update an 
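
The masked-parameter handling in the model_config hunk above boils down to one rule: if the client sends a secret field back still in its masked form, keep the stored value; anything else is treated as a genuinely new value and re-encrypted. That rule in isolation, with illustrative field names:

def restore_secrets(incoming, masked, stored):
    restored = dict(incoming)
    for key, masked_value in masked.items():
        # An unchanged masked value means "keep the stored secret".
        if restored.get(key) == masked_value:
            restored[key] = stored[key]
    return restored

print(restore_secrets({"api_key": "***"}, {"api_key": "***"}, {"api_key": "sk-stored"}))
# {'api_key': 'sk-stored'}
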
existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.update_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigNotExist() @@ -83,14 +76,11 @@ def patch(self, app_id): def delete(self, app_id): """Delete an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - result = OpsService.delete_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'] - ) + result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not result: raise TracingConfigNotExist() return {"result": "success"} @@ -98,4 +88,4 @@ def delete(self, app_id): raise e -api.add_resource(TraceAppConfigApi, '/apps//trace-config') +api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 6aa9f0b475f161..2f5645852fe277 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,3 +1,5 @@ +from datetime import datetime, timezone + from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -5,32 +7,33 @@ from constants.languages import supported_language from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.login import login_required -from models.model import Site +from models import Site def parse_app_site_args(): parser = reqparse.RequestParser() - parser.add_argument('title', type=str, required=False, location='json') - parser.add_argument('icon', type=str, required=False, location='json') - parser.add_argument('icon_background', type=str, required=False, location='json') - parser.add_argument('description', type=str, required=False, location='json') - parser.add_argument('default_language', type=supported_language, required=False, location='json') - parser.add_argument('chat_color_theme', type=str, required=False, location='json') - parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json') - parser.add_argument('customize_domain', type=str, required=False, location='json') - parser.add_argument('copyright', type=str, required=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, location='json') - parser.add_argument('custom_disclaimer', type=str, required=False, location='json') - 
parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'], - required=False, - location='json') - parser.add_argument('prompt_public', type=bool, required=False, location='json') - parser.add_argument('show_workflow_steps', type=bool, required=False, location='json') + parser.add_argument("title", type=str, required=False, location="json") + parser.add_argument("icon_type", type=str, required=False, location="json") + parser.add_argument("icon", type=str, required=False, location="json") + parser.add_argument("icon_background", type=str, required=False, location="json") + parser.add_argument("description", type=str, required=False, location="json") + parser.add_argument("default_language", type=supported_language, required=False, location="json") + parser.add_argument("chat_color_theme", type=str, required=False, location="json") + parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") + parser.add_argument("customize_domain", type=str, required=False, location="json") + parser.add_argument("copyright", type=str, required=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, location="json") + parser.add_argument("custom_disclaimer", type=str, required=False, location="json") + parser.add_argument( + "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" + ) + parser.add_argument("prompt_public", type=bool, required=False, location="json") + parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") + parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json") return parser.parse_args() @@ -47,37 +50,38 @@ def post(self, app_model): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site). \ - filter(Site.app_id == app_model.id). 
\ - one_or_404() + site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ - 'title', - 'icon', - 'icon_background', - 'description', - 'default_language', - 'chat_color_theme', - 'chat_color_theme_inverted', - 'customize_domain', - 'copyright', - 'privacy_policy', - 'custom_disclaimer', - 'customize_token_strategy', - 'prompt_public', - 'show_workflow_steps' + "title", + "icon_type", + "icon", + "icon_background", + "description", + "default_language", + "chat_color_theme", + "chat_color_theme_inverted", + "customize_domain", + "copyright", + "privacy_policy", + "custom_disclaimer", + "customize_token_strategy", + "prompt_public", + "show_workflow_steps", + "use_icon_as_answer_icon", ]: value = args.get(attr_name) if value is not None: setattr(site, attr_name, value) + site.updated_by = current_user.id + site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return site class AppSiteAccessTokenReset(Resource): - @setup_required @login_required @account_initialization_required @@ -94,10 +98,12 @@ def post(self, app_model): raise NotFound site.code = Site.generate_code(16) + site.updated_by = current_user.id + site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return site -api.add_resource(AppSite, '/apps/<uuid:app_id>/site') -api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset') +api.add_resource(AppSite, "/apps/<uuid:app_id>/site") +api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index b882ffef34129e..db5e2824095ca0 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -8,16 +8,71 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required from models.model import AppMode -class DailyConversationStatistic(Resource): +class DailyMessageStatistic(Resource): + @setup_required + @login_required + @account_initialization_required + @get_app_model + def get(self, app_model): + account = current_user + + parser = reqparse.RequestParser() + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + args = parser.parse_args() + + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(*) AS message_count +FROM + messages +WHERE + app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} + + timezone = pytz.timezone(account.timezone) + utc_timezone = pytz.utc + + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") + start_datetime = start_datetime.replace(second=0) + + start_datetime_timezone = timezone.localize(start_datetime) + start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) + + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc + + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") + end_datetime = 
end_datetime.replace(second=0) + end_datetime_timezone = timezone.localize(end_datetime) + end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc + + sql_query += " GROUP BY date ORDER BY date" + + response_data = [] + + with db.engine.begin() as conn: + rs = conn.execute(db.text(sql_query), arg_dict) + for i in rs: + response_data.append({"date": str(i.date), "message_count": i.message_count}) + + return jsonify({"data": response_data}) + + +class DailyConversationStatistic(Resource): @setup_required @login_required @account_initialization_required @@ -26,58 +81,55 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count - FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(DISTINCT messages.conversation_id) AS conversation_count +FROM + messages +WHERE + app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'conversation_count': i.conversation_count - }) + response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class DailyTerminalsStatistic(Resource): - @setup_required @login_required @account_initialization_required @@ -86,54 +138,52 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - 
parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count - FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(DISTINCT messages.from_end_user_id) AS terminal_count +FROM + messages +WHERE + app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class DailyTokenCostStatistic(Resource): @@ -145,58 +195,55 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, - sum(total_price) as total_price - FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count, + SUM(total_price) AS total_price +FROM + messages +WHERE + app_id = :app_id""" 
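Every statistics endpoint in this file repeats the same dance for its `start`/`end` arguments: parse the naive `%Y-%m-%d %H:%M` string, localize it to the account's timezone, then normalize to UTC before binding it against the UTC-stored `created_at` column. A minimal standalone sketch of that pattern (the helper name `to_utc_bound` is illustrative, not from the codebase):

```python
# Sketch of the start/end handling shared by these statistics endpoints.
# `to_utc_bound` is a hypothetical name; the pytz calls mirror the diff.
from datetime import datetime

import pytz


def to_utc_bound(value: str, tz_name: str) -> datetime:
    # Parse the naive "%Y-%m-%d %H:%M" argument and zero out the seconds.
    naive = datetime.strptime(value, "%Y-%m-%d %H:%M").replace(second=0)
    # Interpret it in the account's timezone, then convert to UTC so it
    # compares correctly against the UTC-stored created_at column.
    return pytz.timezone(tz_name).localize(naive).astimezone(pytz.utc)


print(to_utc_bound("2024-01-01 08:30", "Asia/Shanghai"))
# 2024-01-01 00:30:00+00:00 (Shanghai is UTC+8)
```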
+ arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - 'total_price': i.total_price, - 'currency': 'USD' - }) + response_data.append( + {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageSessionInteractionStatistic(Resource): @@ -208,60 +255,72 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, -AVG(subquery.message_count) AS interactions -FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count - FROM conversations c - JOIN messages m ON c.id = m.conversation_id - WHERE c.override_model_configs IS NULL AND c.app_id = :app_id""" - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + sql_query = """SELECT + DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + AVG(subquery.message_count) AS interactions +FROM + ( + SELECT + m.conversation_id, + COUNT(m.id) AS message_count + FROM + conversations c + JOIN + messages m + ON c.id = m.conversation_id + WHERE + c.override_model_configs IS NULL + AND c.app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at >= 
:start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND c.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND c.created_at < :end" + arg_dict["end"] = end_datetime_utc sql_query += """ - GROUP BY m.conversation_id) subquery -LEFT JOIN conversations c on c.id=subquery.conversation_id -GROUP BY date -ORDER BY date""" + GROUP BY m.conversation_id + ) subquery +LEFT JOIN + conversations c + ON c.id = subquery.conversation_id +GROUP BY + date +ORDER BY + date""" response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class UserSatisfactionRateStatistic(Resource): @@ -273,57 +332,61 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count - FROM messages m - LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like' - WHERE m.app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + sql_query = """SELECT + DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(m.id) AS message_count, + COUNT(mf.id) AS feedback_count +FROM + messages m +LEFT JOIN + message_feedbacks mf + ON mf.message_id=m.id AND mf.rating='like' +WHERE + m.app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND m.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) 
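Note how the window filters are appended to the SQL string only when the corresponding argument is present, with the values travelling as named bind parameters rather than being interpolated into the query text. A runnable sketch of that assembly pattern, using an in-memory SQLite engine as a stand-in for `db.engine`:

```python
# Stand-in demonstration of the conditional query assembly used here:
# the WHERE clause only grows an ":end"-bound condition when the
# argument exists, so user input never lands in the SQL text itself.
from datetime import datetime

from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # in-memory stand-in for db.engine

sql = "SELECT COUNT(*) AS n FROM messages WHERE app_id = :app_id"
params = {"app_id": "app-1"}

end_utc = datetime(2024, 1, 1)  # pretend this came from the parsed args
if end_utc is not None:
    sql += " AND created_at < :end"
    params["end"] = end_utc

with engine.begin() as conn:
    conn.execute(text("CREATE TABLE messages (app_id TEXT, created_at TIMESTAMP)"))
    print(conn.execute(text(sql), params).scalar())  # 0, the table is empty
```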
- sql_query += ' and m.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND m.created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), - }) + response_data.append( + { + "date": str(i.date), + "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), + } + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageResponseTimeStatistic(Resource): @@ -335,56 +398,52 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - AVG(provider_response_latency) as latency - FROM messages - WHERE app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + AVG(provider_response_latency) AS latency +FROM + messages +WHERE + app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'latency': round(i.latency * 1000, 4) - }) + response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class TokensPerSecondStatistic(Resource): @@ -396,63 +455,62 @@ def get(self, app_model): account = current_user parser = 
reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - CASE + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + CASE WHEN SUM(provider_response_latency) = 0 THEN 0 ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) END as tokens_per_second -FROM messages -WHERE app_id = :app_id''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} +FROM + messages +WHERE + app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'tps': round(i.tokens_per_second, 4) - }) - - return jsonify({ - 'data': response_data - }) - - -api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations') -api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users') -api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs') -api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions') -api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate') -api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time') -api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second') + response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) + + return jsonify({"data": response_data}) + + +api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages") +api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations") +api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users") +api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs") +api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions") +api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate") +api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time") +api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 6eb97b6c817916..f7027fb22669dd 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -9,18 +9,17 @@ from controllers.console import api from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.segments import factory -from core.errors.error import AppInvokeQuotaExceededError +from factories import variable_factory from fields.workflow_fields import workflow_fields from fields.workflow_run_fields import workflow_run_node_execution_fields from libs import helper from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required -from models.model import App, AppMode +from models import App +from models.model import AppMode from services.app_dsl_service import AppDslService from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowHashNotEqualError @@ -64,51 +63,55 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - - content_type = request.headers.get('Content-Type', '') - if 'application/json' in content_type: + content_type = request.headers.get("Content-Type", "") + + if "application/json" in content_type: parser = reqparse.RequestParser() - parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') - parser.add_argument('features', type=dict, required=True, nullable=False, location='json') - parser.add_argument('hash', type=str, required=False, location='json') + parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") + parser.add_argument("features", type=dict, required=True, nullable=False, location="json") + parser.add_argument("hash", type=str, required=False, location="json") # TODO: set this to required=True after frontend is updated - parser.add_argument('environment_variables', type=list, required=False, location='json') - parser.add_argument('conversation_variables', type=list, required=False, location='json') + parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("conversation_variables", type=list, required=False, location="json") args = parser.parse_args() - elif 'text/plain' in content_type: + elif "text/plain" in content_type: try: - data = json.loads(request.data.decode('utf-8')) - if 'graph' not in data or 'features' not in data: - raise ValueError('graph or features not found in data') + data = json.loads(request.data.decode("utf-8")) + if "graph" not in data or "features" not in data: + raise ValueError("graph or features not found in data") - if not
isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict): - raise ValueError('graph or features is not a dict') + if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): + raise ValueError("graph or features is not a dict") args = { - 'graph': data.get('graph'), - 'features': data.get('features'), - 'hash': data.get('hash'), - 'environment_variables': data.get('environment_variables'), - 'conversation_variables': data.get('conversation_variables'), + "graph": data.get("graph"), + "features": data.get("features"), + "hash": data.get("hash"), + "environment_variables": data.get("environment_variables"), + "conversation_variables": data.get("conversation_variables"), } except json.JSONDecodeError: - return {'message': 'Invalid JSON data'}, 400 + return {"message": "Invalid JSON data"}, 400 else: abort(415) workflow_service = WorkflowService() try: - environment_variables_list = args.get('environment_variables') or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = args.get('conversation_variables') or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] + environment_variables_list = args.get("environment_variables") or [] + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = args.get("conversation_variables") or [] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=args['graph'], - features=args['features'], - unique_hash=args.get('hash'), + graph=args["graph"], + features=args["features"], + unique_hash=args.get("hash"), account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, @@ -119,7 +122,7 @@ def post(self, app_model: App): return { "result": "success", "hash": workflow.unique_hash, - "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), } @@ -138,13 +141,11 @@ def post(self, app_model: App): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") args = parser.parse_args() workflow = AppDslService.import_and_overwrite_workflow( - app_model=app_model, - data=args['data'], - account=current_user + app_model=app_model, data=args["data"], account=current_user ) return workflow @@ -162,21 +163,19 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') - parser.add_argument('query', type=str, required=True, location='json', default='') - parser.add_argument('files', type=list, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') + parser.add_argument("inputs", type=dict, location="json") + parser.add_argument("query", type=str, required=True, location="json", default="") + parser.add_argument("files", type=list, location="json") + 
parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -190,6 +189,7 @@ def post(self, app_model: App): logging.exception("internal server error.") raise InternalServerError() + class AdvancedChatDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -202,18 +202,14 @@ def post(self, app_model: App, node_id: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -227,6 +223,7 @@ def post(self, app_model: App, node_id: str): logging.exception("internal server error.") raise InternalServerError() + class WorkflowDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -239,18 +236,14 @@ def post(self, app_model: App, node_id: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -264,6 +257,7 @@ def post(self, app_model: App, node_id: str): logging.exception("internal server error.") raise InternalServerError() + class DraftWorkflowRunApi(Resource): @setup_required @login_required @@ -276,27 +270,21 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() - try: - response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True - ) + response = AppGenerateService.generate( + app_model=app_model, + user=current_user, + args=args, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) - return helper.compact_generate_response(response) - except (ValueError, AppInvokeQuotaExceededError) as e: - raise e - except Exception as e: - 
logging.exception("internal server error.") - raise InternalServerError() + return helper.compact_generate_response(response) class WorkflowTaskStopApi(Resource): @@ -311,12 +299,10 @@ def post(self, app_model: App, task_id: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) - return { - "result": "success" - } + return {"result": "success"} class DraftWorkflowNodeRunApi(Resource): @@ -332,24 +318,20 @@ def post(self, app_model: App, node_id: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() workflow_service = WorkflowService() workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, - node_id=node_id, - user_inputs=args.get('inputs'), - account=current_user + app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user ) return workflow_node_execution class PublishedWorkflowApi(Resource): - @setup_required @login_required @account_initialization_required @@ -362,7 +344,7 @@ def get(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # fetch published workflow by app_model workflow_service = WorkflowService() workflow = workflow_service.get_published_workflow(app_model=app_model) @@ -381,14 +363,11 @@ def post(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + workflow_service = WorkflowService() workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) - return { - "result": "success", - "created_at": TimestampField().format(workflow.created_at) - } + return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} class DefaultBlockConfigsApi(Resource): @@ -403,7 +382,7 @@ def get(self, app_model: App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # Get default block configs workflow_service = WorkflowService() return workflow_service.get_default_block_configs() @@ -421,24 +400,21 @@ def get(self, app_model: App, block_type: str): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('q', type=str, location='args') + parser.add_argument("q", type=str, location="args") args = parser.parse_args() filters = None - if args.get('q'): + if args.get("q"): try: - filters = json.loads(args.get('q')) + filters = json.loads(args.get("q")) except json.JSONDecodeError: - raise ValueError('Invalid filters') + raise ValueError("Invalid filters") # Get default block configs workflow_service = WorkflowService() - return workflow_service.get_default_block_config( - node_type=block_type, - filters=filters - ) + return workflow_service.get_default_block_config(node_type=block_type, filters=filters) class ConvertToWorkflowApi(Resource): @@ -455,40 +431,43 @@ def post(self, app_model: 
App): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + if request.data: parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") args = parser.parse_args() else: args = {} # convert to workflow mode workflow_service = WorkflowService() - new_app_model = workflow_service.convert_to_workflow( - app_model=app_model, - account=current_user, - args=args - ) + new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args) # return app id return { - 'new_app_id': new_app_model.id, + "new_app_id": new_app_model.id, } -api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft') -api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import') -api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run') -api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run') -api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop') -api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run') -api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run') -api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run') -api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish') -api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs') -api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs' - '/<string:block_type>') -api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow') +api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") +api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import") +api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") +api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") +api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") +api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run") +api.add_resource( + AdvancedChatDraftRunIterationNodeApi, + "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run", +) +api.add_resource( + WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run" +) +api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish") +api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs") +api.add_resource( + DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>" +) +api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 6d1709ed8e65d9..2940556f84ef4e 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -3,11 +3,11 
@@ from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs.login import login_required -from models.model import App, AppMode +from models import App +from models.model import AppMode from services.workflow_app_service import WorkflowAppService @@ -22,20 +22,19 @@ def get(self, app_model: App): Get workflow app logs """ parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() # get paginate workflow app logs workflow_app_service = WorkflowAppService() workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( - app_model=app_model, - args=args + app_model=app_model, args=args ) return workflow_app_log_pagination -api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs') +api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 35d982e37ce4e3..08ab61bbb9c97e 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -3,8 +3,7 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_run_fields import ( advanced_chat_workflow_run_pagination_fields, workflow_run_detail_fields, @@ -13,7 +12,8 @@ ) from libs.helper import uuid_value from libs.login import login_required -from models.model import App, AppMode +from models import App +from models.model import AppMode from services.workflow_run_service import WorkflowRunService @@ -28,15 +28,12 @@ def get(self, app_model: App): Get advanced chat app workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_advanced_chat_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) return result @@ -52,15 +49,12 @@ def get(self, app_model: App): """ Get workflow run list """ 
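Both run-list endpoints here share the same cursor-style pagination contract: `last_id` is a UUID cursor into the previous page and `limit` is clamped to 1..100 with a default of 20. A self-contained sketch of that parser (here `uuid_value` is re-implemented inline for illustration; the controllers import it from `libs.helper`):

```python
# Standalone version of the pagination argument parsing used by the
# run-list endpoints; reqparse turns a ValueError into a 400 response.
import uuid

from flask_restful import reqparse
from flask_restful.inputs import int_range


def uuid_value(value: str) -> str:
    # Reject anything that is not a valid UUID cursor.
    return str(uuid.UUID(value))


parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args")
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
```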
parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) return result @@ -98,12 +92,10 @@ def get(self, app_model: App, run_id): workflow_run_service = WorkflowRunService() node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) - return { - 'data': node_executions - } + return {"data": node_executions} -api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs') -api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs') -api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>') -api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions') +api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs") +api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs") +api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>") +api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 1d7dc395ff3b18..6c7c73707bb204 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -8,13 +8,12 @@ from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db -from libs.helper import datetime_string +from libs.helper import DatetimeString from libs.login import login_required +from models.enums import WorkflowRunTriggeredFrom from models.model import AppMode -from models.workflow import WorkflowRunTriggeredFrom class WorkflowDailyRunsStatistic(Resource): @@ -26,56 +25,58 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs - FROM workflow_runs - WHERE app_id = :app_id - AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(id) AS runs +FROM + workflow_runs +WHERE + app_id = :app_id + AND triggered_from = :triggered_from""" + arg_dict = { + "tz": 
account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'runs': i.runs - }) + response_data.append({"date": str(i.date), "runs": i.runs}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTerminalsStatistic(Resource): @setup_required @@ -86,56 +87,58 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count - FROM workflow_runs - WHERE app_id = :app_id - AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + COUNT(DISTINCT workflow_runs.created_by) AS terminal_count +FROM + workflow_runs +WHERE + app_id = :app_id + AND triggered_from = :triggered_from""" + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - 
end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTokenCostStatistic(Resource): @setup_required @@ -146,58 +149,63 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' - SELECT - date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - SUM(workflow_runs.total_tokens) as token_count - FROM workflow_runs - WHERE app_id = :app_id - AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + sql_query = """SELECT + DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + SUM(workflow_runs.total_tokens) AS token_count +FROM + workflow_runs +WHERE + app_id = :app_id + AND triggered_from = :triggered_from""" + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " AND created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " AND created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date ORDER BY date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), 
arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - }) + response_data.append( + { + "date": str(i.date), + "token_count": i.token_count, + } + ) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowAverageAppInteractionStatistic(Resource): @setup_required @@ -208,71 +216,79 @@ def get(self, app_model): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """ - SELECT - AVG(sub.interactions) as interactions, - sub.date - FROM - (SELECT - date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - c.created_by, - COUNT(c.id) AS interactions - FROM workflow_runs c - WHERE c.app_id = :app_id - AND c.triggered_from = :triggered_from - {{start}} - {{end}} - GROUP BY date, c.created_by) sub - GROUP BY sub.date - """ - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + sql_query = """SELECT + AVG(sub.interactions) AS interactions, + sub.date +FROM + ( + SELECT + DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + c.created_by, + COUNT(c.id) AS interactions + FROM + workflow_runs c + WHERE + c.app_id = :app_id + AND c.triggered_from = :triggered_from + {{start}} + {{end}} + GROUP BY + date, c.created_by + ) sub +GROUP BY + sub.date""" + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start') - arg_dict['start'] = start_datetime_utc + sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") + arg_dict["start"] = start_datetime_utc else: - sql_query = sql_query.replace('{{start}}', '') + sql_query = sql_query.replace("{{start}}", "") - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end') - arg_dict['end'] = end_datetime_utc + sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end") + arg_dict["end"] = end_datetime_utc else: - sql_query = sql_query.replace('{{end}}', '') + sql_query = sql_query.replace("{{end}}", "") response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) 
- - return jsonify({ - 'data': response_data - }) - -api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations') -api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals') -api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs') -api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions') + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) + + return jsonify({"data": response_data}) + + +api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations") +api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals") +api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs") +api.add_resource( + WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions" +) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index d61ab6d6ae8f28..c71ee8e5dfea1d 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -5,27 +5,27 @@ from controllers.console.app.error import AppNotFoundError from extensions.ext_database import db from libs.login import current_user -from models.model import App, AppMode +from models import App +from models.model import AppMode -def get_app_model(view: Optional[Callable] = None, *, - mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - if not kwargs.get('app_id'): - raise ValueError('missing app_id in path parameters') + if not kwargs.get("app_id"): + raise ValueError("missing app_id in path parameters") - app_id = kwargs.get('app_id') + app_id = kwargs.get("app_id") app_id = str(app_id) - del kwargs['app_id'] + del kwargs["app_id"] - app_model = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app_model = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app_model: raise AppNotFoundError() @@ -44,9 +44,10 @@ def decorated_view(*args, **kwargs): mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model return view_func(*args, **kwargs) + return decorated_view if view is None: diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8efb55cdb64edc..be353cefac1a19 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,76 +1,77 @@ -import base64 import datetime -import secrets +from flask import request from flask_restful import Resource, reqparse from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db -from libs.helper import email, str_len, timezone -from libs.password import hash_password, valid_password -from models.account import AccountStatus -from services.account_service import RegisterService +from libs.helper import StrLen, email,
extract_remote_ip, timezone +from models.account import AccountStatus, Tenant +from services.account_service import AccountService, RegisterService class ActivateCheckApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args') - parser.add_argument('email', type=email, required=False, nullable=True, location='args') - parser.add_argument('token', type=str, required=True, nullable=False, location='args') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") + parser.add_argument("email", type=email, required=False, nullable=True, location="args") + parser.add_argument("token", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - workspaceId = args['workspace_id'] - reg_email = args['email'] - token = args['token'] + workspaceId = args["workspace_id"] + reg_email = args["email"] + token = args["token"] invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) - - return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None} + if invitation: + data = invitation.get("data", {}) + tenant: Tenant = invitation.get("tenant", None) + workspace_name = tenant.name if tenant else None + workspace_id = tenant.id if tenant else None + invitee_email = data.get("email") if data else None + return { + "is_valid": invitation is not None, + "data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email}, + } + else: + return {"is_valid": False} class ActivateApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json') - parser.add_argument('email', type=email, required=False, nullable=True, location='json') - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json') - parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('interface_language', type=supported_language, required=True, nullable=False, - location='json') - parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") + parser.add_argument("email", type=email, required=False, nullable=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") + parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" + ) + parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") args = parser.parse_args() - invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token']) + invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args['workspace_id'], args['email'], args['token']) - - account = invitation['account'] - account.name = args['name'] + RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) - # generate 
password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() + account = invitation["account"] + account.name = args["name"] - # encrypt password with salt - password_hashed = hash_password(args['password'], salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) + + return {"result": "success", "data": token_pair.model_dump()} -api.add_resource(ActivateCheckApi, '/activate/check') -api.add_resource(ActivateApi, '/activate') +api.add_resource(ActivateCheckApi, "/activate/check") +api.add_resource(ActivateApi, "/activate") diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index f79b93b74f6df3..465c44e9b6dc2f 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -7,8 +7,7 @@ from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService -from ..setup import setup_required -from ..wraps import account_initialization_required +from ..wraps import account_initialization_required, setup_required class ApiKeyAuthDataSource(Resource): @@ -19,18 +18,19 @@ def get(self): data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) if data_source_api_key_bindings: return { - 'sources': [{ - 'id': data_source_api_key_binding.id, - 'category': data_source_api_key_binding.category, - 'provider': data_source_api_key_binding.provider, - 'disabled': data_source_api_key_binding.disabled, - 'created_at': int(data_source_api_key_binding.created_at.timestamp()), - 'updated_at': int(data_source_api_key_binding.updated_at.timestamp()), - } - for data_source_api_key_binding in - data_source_api_key_bindings] + "sources": [ + { + "id": data_source_api_key_binding.id, + "category": data_source_api_key_binding.category, + "provider": data_source_api_key_binding.provider, + "disabled": data_source_api_key_binding.disabled, + "created_at": int(data_source_api_key_binding.created_at.timestamp()), + "updated_at": int(data_source_api_key_binding.updated_at.timestamp()), + } + for data_source_api_key_binding in data_source_api_key_bindings + ] } - return {'sources': []} + return {"sources": []} class ApiKeyAuthDataSourceBinding(Resource): @@ -42,16 +42,16 @@ def post(self): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + 
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() ApiKeyAuthService.validate_api_key_auth_args(args) try: ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) except Exception as e: raise ApiKeyAuthFailedError(str(e)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ApiKeyAuthDataSourceBindingDelete(Resource): @@ -65,9 +65,9 @@ def delete(self, binding_id): ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source') -api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding') -api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/') +api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") +api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding") +api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 45cfa9d7ebcb1b..3c3f45260a54b3 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -11,19 +11,18 @@ from libs.login import login_required from libs.oauth_data_source import NotionOAuth -from ..setup import setup_required -from ..wraps import account_initialization_required +from ..wraps import account_initialization_required, setup_required def get_oauth_providers(): with current_app.app_context(): - notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID, - client_secret=dify_config.NOTION_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion') + notion_oauth = NotionOAuth( + client_id=dify_config.NOTION_CLIENT_ID, + client_secret=dify_config.NOTION_CLIENT_SECRET, + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", + ) - OAUTH_PROVIDERS = { - 'notion': notion_oauth - } + OAUTH_PROVIDERS = {"notion": notion_oauth} return OAUTH_PROVIDERS @@ -37,18 +36,16 @@ def get(self, provider: str): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if dify_config.NOTION_INTEGRATION_TYPE == 'internal': + return {"error": "Invalid provider"}, 400 + if dify_config.NOTION_INTEGRATION_TYPE == "internal": internal_secret = dify_config.NOTION_INTERNAL_SECRET if not internal_secret: - return {'error': 'Internal secret is not set'}, + return ({"error": "Internal secret is not set"},) oauth_provider.save_internal_access_token(internal_secret) - return { 'data': '' } + return {"data": ""} else: auth_url = oauth_provider.get_authorization_url() - return { 'data': auth_url }, 200 - - + return {"data": auth_url}, 200 class OAuthDataSourceCallback(Resource): @@ -57,18 +54,18 @@ def get(self, provider: str): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}') - elif 'error' in request.args: 
- error = request.args.get('error') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") + elif "error" in request.args: + error = request.args.get("error") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") else: - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied') - + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") + class OAuthDataSourceBinding(Resource): def get(self, provider: str): @@ -76,17 +73,18 @@ def get(self, provider: str): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") try: oauth_provider.get_access_token(code) except requests.exceptions.HTTPError as e: logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" + ) + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class OAuthDataSourceSync(Resource): @@ -100,18 +98,17 @@ def get(self, provider, binding_id): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 try: oauth_provider.sync_data_source(binding_id) except requests.exceptions.HTTPError as e: - logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>') -api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>') -api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>') -api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync') +api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>") +api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>") +api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>") +api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 53dab3298fbff3..e6e30c3c0b015f 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -2,31 +2,54 @@ class ApiKeyAuthFailedError(BaseHTTPException): - error_code = 'auth_failed' + error_code = "auth_failed" description = "{message}" code = 500 class InvalidEmailError(BaseHTTPException): - error_code = 'invalid_email' + error_code = "invalid_email" description = "The email address is not valid." code = 400 class PasswordMismatchError(BaseHTTPException): - error_code = 'password_mismatch' + error_code = "password_mismatch" description = "The passwords do not match."
code = 400 class InvalidTokenError(BaseHTTPException): - error_code = 'invalid_or_expired_token' + error_code = "invalid_or_expired_token" description = "The token is invalid or has expired." code = 400 class PasswordResetRateLimitExceededError(BaseHTTPException): - error_code = 'password_reset_rate_limit_exceeded' - description = "Password reset rate limit exceeded. Try again later." + error_code = "password_reset_rate_limit_exceeded" + description = "Too many password reset emails have been sent. Please try again in 1 minute." code = 429 + +class EmailCodeError(BaseHTTPException): + error_code = "email_code_error" + description = "Email code is invalid or expired." + code = 400 + + +class EmailOrPasswordMismatchError(BaseHTTPException): + error_code = "email_or_password_mismatch" + description = "The email or password is incorrect." + code = 400 + + +class EmailPasswordLoginLimitError(BaseHTTPException): + error_code = "email_code_login_limit" + description = "Too many incorrect password attempts. Please try again later." + code = 429 + + +class EmailCodeLoginRateLimitExceededError(BaseHTTPException): + error_code = "email_code_login_rate_limit_exceeded" + description = "Too many login emails have been sent. Please try again in 5 minutes." + code = 429 diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d78be770abd094..735edae5f63038 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -1,86 +1,100 @@ import base64 -import logging import secrets +from flask import request from flask_restful import Resource, reqparse +from constants.languages import languages from controllers.console import api from controllers.console.auth.error import ( + EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError, - PasswordResetRateLimitExceededError, ) -from controllers.console.setup import setup_required +from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister +from controllers.console.wraps import setup_required +from events.tenant_event import tenant_was_created from extensions.ext_database import db -from libs.helper import email as email_validate +from libs.helper import email, extract_remote_ip from libs.password import hash_password, valid_password from models.account import Account -from services.account_service import AccountService -from services.errors.account import RateLimitExceededError +from services.account_service import AccountService, TenantService +from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.feature_service import FeatureService class ForgotPasswordSendEmailApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('email', type=str, required=True, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - email = args['email'] - - if not email_validate(email): - raise InvalidEmailError() - - account = Account.query.filter_by(email=email).first() + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() - if account: - try: - AccountService.send_reset_password_email(account=account) - except RateLimitExceededError: - logging.warning(f"Rate limit exceeded for email: {account.email}") - raise
PasswordResetRateLimitExceededError() + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" else: - # Return success to avoid revealing email registration status - logging.warning(f"Attempt to reset password for unregistered email: {email}") + language = "en-US" + + account = Account.query.filter_by(email=args["email"]).first() + token = None + if account is None: + if FeatureService.get_system_features().is_allow_register: + token = AccountService.send_reset_password_email(email=args["email"], language=language) + return {"result": "fail", "data": token, "code": "account_not_found"} + else: + raise NotAllowedRegister() + else: + token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) - return {"result": "success"} + return {"result": "success", "data": token} class ForgotPasswordCheckApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - token = args['token'] - reset_data = AccountService.get_reset_password_data(token) + user_email = args["email"] - if reset_data is None: - return {'is_valid': False, 'email': None} - return {'is_valid': True, 'email': reset_data.get('email')} + token_data = AccountService.get_reset_password_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + raise EmailCodeError() + + return {"is_valid": True, "email": token_data.get("email")} -class ForgotPasswordResetApi(Resource): +class ForgotPasswordResetApi(Resource): @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") args = parser.parse_args() - new_password = args['new_password'] - password_confirm = args['password_confirm'] + new_password = args["new_password"] + password_confirm = args["password_confirm"] if str(new_password).strip() != str(password_confirm).strip(): raise PasswordMismatchError() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: @@ -94,14 +108,31 @@ def post(self): password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = Account.query.filter_by(email=reset_data.get('email')).first() - account.password = base64_password_hashed - account.password_salt = base64_salt - db.session.commit() + account = Account.query.filter_by(email=reset_data.get("email")).first() + if account: + account.password = base64_password_hashed + 
account.password_salt = base64_salt + db.session.commit() + tenant = TenantService.get_join_tenants(account) + if not tenant and not FeatureService.get_system_features().is_allow_create_workspace: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + else: + try: + account = AccountService.create_account_and_tenant( + email=reset_data.get("email"), + name=reset_data.get("email"), + password=password_confirm, + interface_language=languages[0], + ) + except WorkSpaceNotAllowedCreateError: + pass - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password') -api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity') -api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets') +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index c135ece67ef86c..e2e8f849208171 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -5,12 +5,29 @@ from flask_restful import Resource, reqparse import services +from constants.languages import languages from controllers.console import api -from controllers.console.setup import setup_required -from libs.helper import email, get_remote_ip +from controllers.console.auth.error import ( + EmailCodeError, + EmailOrPasswordMismatchError, + EmailPasswordLoginLimitError, + InvalidEmailError, + InvalidTokenError, +) +from controllers.console.error import ( + AccountBannedError, + EmailSendIpLimitError, + NotAllowedCreateWorkspace, + NotAllowedRegister, +) +from controllers.console.wraps import setup_required +from events.tenant_event import tenant_was_created +from libs.helper import email, extract_remote_ip from libs.password import valid_password from models.account import Account -from services.account_service import AccountService, TenantService +from services.account_service import AccountService, RegisterService, TenantService +from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.feature_service import FeatureService class LoginApi(Resource): @@ -20,89 +37,186 @@ class LoginApi(Resource): def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() - parser.add_argument('email', type=email, required=True, location='json') - parser.add_argument('password', type=valid_password, required=True, location='json') - parser.add_argument('remember_me', type=bool, required=False, default=False, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") + parser.add_argument("invite_token", type=str, required=False, default=None, location="json") + parser.add_argument("language", type=str, required=False, default="en-US", location="json") args = parser.parse_args() - # todo: Verify the recaptcha + is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"]) + if is_login_error_rate_limit: + raise EmailPasswordLoginLimitError() - try: - account = 
AccountService.authenticate(args['email'], args['password']) - except services.errors.account.AccountLoginError as e: - return {'code': 'unauthorized', 'message': str(e)}, 401 + invitation = args["invite_token"] + if invitation: + invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation) + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + try: + if invitation: + data = invitation.get("data", {}) + invitee_email = data.get("email") if data else None + if invitee_email != args["email"]: + raise InvalidEmailError() + account = AccountService.authenticate(args["email"], args["password"], args["invite_token"]) + else: + account = AccountService.authenticate(args["email"], args["password"]) + except services.errors.account.AccountLoginError: + raise AccountBannedError() + except services.errors.account.AccountPasswordError: + AccountService.add_login_error_rate_limit(args["email"]) + raise EmailOrPasswordMismatchError() + except services.errors.account.AccountNotFoundError: + if FeatureService.get_system_features().is_allow_register: + token = AccountService.send_reset_password_email(email=args["email"], language=language) + return {"result": "fail", "data": token, "code": "account_not_found"} + else: + raise NotAllowedRegister() # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: - return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} + return { + "result": "fail", + "data": "workspace not found, please contact system admin to invite you to join in a workspace", + } - token = AccountService.login(account, ip_address=get_remote_ip(request)) - - return {'result': 'success', 'data': token} + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(args["email"]) + return {"result": "success", "data": token_pair.model_dump()} class LogoutApi(Resource): - @setup_required def get(self): account = cast(Account, flask_login.current_user) - token = request.headers.get('Authorization', '').split(' ')[1] - AccountService.logout(account=account, token=token) + if isinstance(account, flask_login.AnonymousUserMixin): + return {"result": "success"} + AccountService.logout(account=account) flask_login.logout_user() - return {'result': 'success'} + return {"result": "success"} -class ResetPasswordApi(Resource): +class ResetPasswordSendEmailApi(Resource): @setup_required - def get(self): - # parser = reqparse.RequestParser() - # parser.add_argument('email', type=email, required=True, location='json') - # args = parser.parse_args() - - # import mailchimp_transactional as MailchimpTransactional - # from mailchimp_transactional.api_client import ApiClientError - - # account = {'email': args['email']} - # account = AccountService.get_by_email(args['email']) - # if account is None: - # raise ValueError('Email not found') - # new_password = AccountService.generate_password() - # AccountService.update_password(account, new_password) - - # todo: Send email - # MAILCHIMP_API_KEY = dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY - # mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY) - - # message = { - # 'from_email': 'noreply@example.com', - # 'to': [{'email': account['email']}], - # 'subject': 'Reset your Dify password', - # 'html': """ - #
<p>Dear User,</p> - # <p>The Dify team has generated a new password for you, details as follows:</p> - # <p>{new_password}</p> - # <p>Please change your password to log in as soon as possible.</p> - # <p>Regards,</p> - # <p>The Dify Team</p>
- # """ - # } - - # response = mailchimp.messages.send({ - # 'message': message, - # # required for transactional email - # ' settings': { - # 'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE, - # }, - # }) - - # Check if MSG was sent - # if response.status_code != 200: - # # handle error - # pass - - return {'result': 'success'} - - -api.add_resource(LoginApi, '/login') -api.add_resource(LogoutApi, '/logout') + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + account = AccountService.get_user_through_email(args["email"]) + if account is None: + if FeatureService.get_system_features().is_allow_register: + token = AccountService.send_reset_password_email(email=args["email"], language=language) + else: + raise NotAllowedRegister() + else: + token = AccountService.send_reset_password_email(account=account, language=language) + + return {"result": "success", "data": token} + + +class EmailCodeLoginSendEmailApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + account = AccountService.get_user_through_email(args["email"]) + if account is None: + if FeatureService.get_system_features().is_allow_register: + token = AccountService.send_email_code_login_email(email=args["email"], language=language) + else: + raise NotAllowedRegister() + else: + token = AccountService.send_email_code_login_email(account=account, language=language) + + return {"result": "success", "data": token} + + +class EmailCodeLoginApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, location="json") + args = parser.parse_args() + + user_email = args["email"] + + token_data = AccountService.get_email_code_login_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if token_data["email"] != args["email"]: + raise InvalidEmailError() + + if token_data["code"] != args["code"]: + raise EmailCodeError() + + AccountService.revoke_email_code_login_token(args["token"]) + account = AccountService.get_user_through_email(user_email) + if account: + tenant = TenantService.get_join_tenants(account) + if not tenant: + if not FeatureService.get_system_features().is_allow_create_workspace: + raise NotAllowedCreateWorkspace() + else: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + + if account is None: + try: + account = AccountService.create_account_and_tenant( + email=user_email, name=user_email, interface_language=languages[0] + ) + except 
WorkSpaceNotAllowedCreateError: + raise NotAllowedCreateWorkspace() + token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) + AccountService.reset_login_error_rate_limit(args["email"]) + return {"result": "success", "data": token_pair.model_dump()} + + +class RefreshTokenApi(Resource): + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("refresh_token", type=str, required=True, location="json") + args = parser.parse_args() + + try: + new_token_pair = AccountService.refresh_token(args["refresh_token"]) + return {"result": "success", "data": new_token_pair.model_dump()} + except Exception as e: + return {"result": "fail", "data": str(e)}, 401 + + +api.add_resource(LoginApi, "/login") +api.add_resource(LogoutApi, "/logout") +api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") +api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") +api.add_resource(ResetPasswordSendEmailApi, "/reset-password") +api.add_resource(RefreshTokenApi, "/refresh-token") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 4a651bfe7b009e..d27e3353c90165 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -5,14 +5,20 @@ import requests from flask import current_app, redirect, request from flask_restful import Resource +from werkzeug.exceptions import Unauthorized from configs import dify_config from constants.languages import languages +from events.tenant_event import tenant_was_created from extensions.ext_database import db -from libs.helper import get_remote_ip +from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo -from models.account import Account, AccountStatus +from models import Account +from models.account import AccountStatus from services.account_service import AccountService, RegisterService, TenantService +from services.errors.account import AccountNotFoundError +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError +from services.feature_service import FeatureService from ..
import api @@ -25,7 +31,7 @@ def get_oauth_providers(): github_oauth = GitHubOAuth( client_id=dify_config.GITHUB_CLIENT_ID, client_secret=dify_config.GITHUB_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github", ) if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: google_oauth = None @@ -33,23 +39,24 @@ def get_oauth_providers(): google_oauth = GoogleOAuth( client_id=dify_config.GOOGLE_CLIENT_ID, client_secret=dify_config.GOOGLE_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", ) - OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth} + OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} return OAUTH_PROVIDERS class OAuthLogin(Resource): def get(self, provider: str): + invite_token = request.args.get("invite_token") or None OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 - auth_url = oauth_provider.get_authorization_url() + auth_url = oauth_provider.get_authorization_url(invite_token=invite_token) return redirect(auth_url) @@ -59,31 +66,67 @@ def get(self, provider: str): with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 + + code = request.args.get("code") + state = request.args.get("state") + invite_token = None + if state: + invite_token = state - code = request.args.get('code') try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) except requests.exceptions.HTTPError as e: - logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}') - return {'error': 'OAuth process failed'}, 400 + logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + return {"error": "OAuth process failed"}, 400 + + if invite_token and RegisterService.is_valid_invite_token(invite_token): + invitation = RegisterService._get_invitation_by_token(token=invite_token) + if invitation: + invitation_email = invitation.get("email", None) + if invitation_email != user_info.email: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.") + + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}") + + try: + account = _generate_account(provider, user_info) + except AccountNotFoundError: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.") + except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError): + return redirect( + f"{dify_config.CONSOLE_WEB_URL}/signin" + "?message=Workspace not found, please contact system admin to invite you to join in a workspace." 
+ ) - account = _generate_account(provider, user_info) # Check account status - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - return {'error': 'Account is banned or closed.'}, 403 + if account.status == AccountStatus.BANNED.value: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.") if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() - TenantService.create_owner_tenant_if_not_exist(account) + try: + TenantService.create_owner_tenant_if_not_exist(account) + except Unauthorized: + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.") + except WorkSpaceNotAllowedCreateError: + return redirect( + f"{dify_config.CONSOLE_WEB_URL}/signin" + "?message=Workspace not found, please contact system admin to invite you to join in a workspace." + ) - token = AccountService.login(account, ip_address=get_remote_ip(request)) + token_pair = AccountService.login( + account=account, + ip_address=extract_remote_ip(request), + ) - return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}') + return redirect( + f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" + ) def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: @@ -99,9 +142,21 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): # Get account by openid or email. account = _get_account_by_openid_or_email(provider, user_info) + if account: + tenant = TenantService.get_join_tenants(account) + if not tenant: + if not FeatureService.get_system_features().is_allow_create_workspace: + raise WorkSpaceNotAllowedCreateError() + else: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + if not account: - # Create account - account_name = user_info.name if user_info.name else 'Dify' + if not FeatureService.get_system_features().is_allow_register: + raise AccountNotFoundError() + account_name = user_info.name or "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider ) @@ -121,5 +176,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account -api.add_resource(OAuthLogin, '/oauth/login/<provider>') -api.add_resource(OAuthCallback, '/oauth/authorize/<provider>') +api.add_resource(OAuthLogin, "/oauth/login/<provider>") +api.add_resource(OAuthCallback, "/oauth/authorize/<provider>") diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 72a6129efa3e4d..4b0c82ae6c90c2 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -2,35 +2,30 @@ from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, only_edition_cloud +from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from libs.login import login_required from services.billing_service import BillingService class Subscription(Resource): - @setup_required @login_required @account_initialization_required @only_edition_cloud def
get(self): - parser = reqparse.RequestParser() - parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) - parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) + parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) + parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) args = parser.parse_args() BillingService.is_tenant_owner_or_admin(current_user) - return BillingService.get_subscription(args['plan'], - args['interval'], - current_user.email, - current_user.current_tenant_id) + return BillingService.get_subscription( + args["plan"], args["interval"], current_user.email, current_user.current_tenant_id + ) class Invoices(Resource): - @setup_required @login_required @account_initialization_required @@ -40,5 +35,5 @@ def get(self): return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) -api.add_resource(Subscription, '/billing/subscription') -api.add_resource(Invoices, '/billing/invoices') +api.add_resource(Subscription, "/billing/subscription") +api.add_resource(Invoices, "/billing/invoices") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 0ca0f0a85653dc..ef1e87905a1b38 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -7,34 +7,35 @@ from werkzeug.exceptions import NotFound from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.indexing_runner import IndexingRunner from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from libs.login import login_required -from models.dataset import Document -from models.source import DataSourceOauthBinding +from models import DataSourceOauthBinding, Document from services.dataset_service import DatasetService, DocumentService from tasks.document_indexing_sync_task import document_indexing_sync_task class DataSourceApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = db.session.query(DataSourceOauthBinding).filter( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.disabled == False - ).all() + data_source_integrates = ( + db.session.query(DataSourceOauthBinding) + .filter( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.disabled == False, + ) + .all() + ) - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" providers = ["notion"] @@ -44,26 +45,30 @@ def get(self): existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates) if existing_integrates: for existing_integrate in list(existing_integrates): - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 
'is_bound': True, - 'disabled': existing_integrate.disabled, - 'source_info': existing_integrate.source_info, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "disabled": existing_integrate.disabled, + "source_info": existing_integrate.source_info, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'source_info': None, - 'is_bound': False, - 'disabled': None, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) - return {'data': integrate_data}, 200 + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "source_info": None, + "is_bound": False, + "disabled": None, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data}, 200 @setup_required @login_required @@ -71,92 +76,82 @@ def get(self): def patch(self, binding_id, action): binding_id = str(binding_id) action = str(action) - data_source_binding = DataSourceOauthBinding.query.filter_by( - id=binding_id - ).first() + data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() if data_source_binding is None: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") # enable binding - if action == 'enable': + if action == "enable": if data_source_binding.disabled: data_source_binding.disabled = False data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is not disabled.') + raise ValueError("Data source is not disabled.") # disable binding - if action == 'disable': + if action == "disable": if not data_source_binding.disabled: data_source_binding.disabled = True data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is disabled.') - return {'result': 'success'}, 200 + raise ValueError("Data source is disabled.") + return {"result": "success"}, 200 class DataSourceNotionListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_notion_info_list_fields) def get(self): - dataset_id = request.args.get('dataset_id', default=None, type=str) + dataset_id = request.args.get("dataset_id", default=None, type=str) exist_page_ids = [] # import notion in the exist dataset if dataset_id: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') - if dataset.data_source_type != 'notion_import': - raise ValueError('Dataset is not notion type.') + raise NotFound("Dataset not found.") + if dataset.data_source_type != "notion_import": + raise ValueError("Dataset is not notion type.") documents = Document.query.filter_by( dataset_id=dataset_id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) + exist_page_ids.append(data_source_info["notion_page_id"]) # get all authorized pages 
data_source_bindings = DataSourceOauthBinding.query.filter_by( - tenant_id=current_user.current_tenant_id, - provider='notion', - disabled=False + tenant_id=current_user.current_tenant_id, provider="notion", disabled=False ).all() if not data_source_bindings: - return { - 'notion_info': [] - }, 200 + return {"notion_info": []}, 200 pre_import_info_list = [] for data_source_binding in data_source_bindings: source_info = data_source_binding.source_info - pages = source_info['pages'] + pages = source_info["pages"] # Filter out already bound pages for page in pages: - if page['page_id'] in exist_page_ids: - page['is_bound'] = True + if page["page_id"] in exist_page_ids: + page["is_bound"] = True else: - page['is_bound'] = False + page["is_bound"] = False pre_import_info = { - 'workspace_name': source_info['workspace_name'], - 'workspace_icon': source_info['workspace_icon'], - 'workspace_id': source_info['workspace_id'], - 'pages': pages, + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, } pre_import_info_list.append(pre_import_info) - return { - 'notion_info': pre_import_info_list - }, 200 + return {"notion_info": pre_import_info_list}, 200 class DataSourceNotionApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,64 +161,67 @@ def get(self, workspace_id, page_id, page_type): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") extractor = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) text_docs = extractor.extract() - return { - 'content': "\n".join([doc.page_content for doc in text_docs]) - }, 200 + return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200 @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') + parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) - notion_info_list = 
args['notion_info_list'] + notion_info_list = args["notion_info_list"] extract_settings = [] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) indexing_runner = IndexingRunner() - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + ) return response, 200 class DataSourceNotionDatasetSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -240,7 +238,6 @@ def get(self, dataset_id): class DataSourceNotionDocumentSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -258,10 +255,14 @@ def get(self, dataset_id, document_id): return 200 -api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>') -api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') -api.add_resource(DataSourceNotionApi, - '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview', - '/datasets/notion-indexing-estimate') -api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync') -api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync') +api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>") +api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages") +api.add_resource( + DataSourceNotionApi, + "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview", + "/datasets/notion-indexing-estimate", +) +api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync") +api.add_resource( + DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync" +) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index a5bc2dd86a905d..82163a32eebd4d 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -10,66 +10,60 @@ from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.extract_setting import ExtractSetting -from core.rag.retrieval.retrival_methods
import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields from libs.login import login_required -from models.dataset import Dataset, Document, DocumentSegment -from models.model import ApiToken, UploadFile +from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile +from models.dataset import DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name def _validate_description_length(description): if len(description) > 400: - raise ValueError('Description cannot exceed 400 characters.') + raise ValueError("Description cannot exceed 400 characters.") return description class DatasetListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - ids = request.args.getlist('ids') - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + ids = request.args.getlist("ids") + # provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") if ids: datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) else: - datasets, total = DatasetService.get_datasets(page, limit, provider, - current_user.current_tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets( + page, limit, current_user.current_tenant_id, current_user, search, tag_ids + ) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -77,28 +71,22 @@ def get(self): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True + item["embedding_available"] = True - if item.get('permission') == 'partial_members': - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id']) - item.update({'partial_member_list': part_users_list}) + if item.get("permission") == 
"partial_members": + part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) + item.update({"partial_member_list": part_users_list}) else: - item.update({'partial_member_list': []}) + item.update({"partial_member_list": []}) - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 @setup_required @@ -106,13 +94,48 @@ def get(self): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) + parser.add_argument( + "external_knowledge_api_id", + type=str, + nullable=True, + required=False, + ) + parser.add_argument( + "provider", + type=str, + nullable=True, + choices=Dataset.PROVIDER_LIST, + required=False, + default="vendor", + ) + parser.add_argument( + "external_knowledge_id", + type=str, + nullable=True, + required=False, + ) args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -122,9 +145,14 @@ def post(self): try: dataset = DatasetService.create_empty_dataset( tenant_id=current_user.current_tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], - account=current_user + name=args["name"], + description=args["description"], + indexing_technique=args["indexing_technique"], + account=current_user, + permission=DatasetPermissionEnum.ONLY_ME, + provider=args["provider"], + external_knowledge_api_id=args["external_knowledge_api_id"], + external_knowledge_id=args["external_knowledge_id"], ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -142,42 +170,36 @@ def get(self, dataset_id): if dataset is None: raise NotFound("Dataset not found.") try: - DatasetService.check_dataset_permission( - dataset, current_user) + DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = marshal(dataset, dataset_detail_fields) - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - 
only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data['indexing_technique'] == 'high_quality': + if data["indexing_technique"] == "high_quality": item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: - data['embedding_available'] = True + data["embedding_available"] = True else: - data['embedding_available'] = False + data["embedding_available"] = False else: - data['embedding_available'] = True + data["embedding_available"] = True - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) return data, 200 @@ -191,42 +213,76 @@ def patch(self, dataset_id): raise NotFound("Dataset not found.") parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('description', - location='json', store_missing=False, - type=_validate_description_length) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') - parser.add_argument('permission', type=str, location='json', choices=( - 'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.' - ) - parser.add_argument('embedding_model', type=str, - location='json', help='Invalid embedding model.') - parser.add_argument('embedding_model_provider', type=str, - location='json', help='Invalid embedding model provider.') - parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') - parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.') + parser.add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + ) + parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") + parser.add_argument( + "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." 
+ ) + parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") + parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") + + parser.add_argument( + "external_retrieval_model", + type=dict, + required=False, + nullable=True, + location="json", + help="Invalid external retrieval model.", + ) + + parser.add_argument( + "external_knowledge_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge id.", + ) + + parser.add_argument( + "external_knowledge_api_id", + type=str, + required=False, + nullable=True, + location="json", + help="Invalid external knowledge api id.", + ) args = parser.parse_args() data = request.get_json() # check embedding model setting - if data.get('indexing_technique') == 'high_quality': - DatasetService.check_embedding_model_setting(dataset.tenant_id, - data.get('embedding_model_provider'), - data.get('embedding_model') - ) + if data.get("indexing_technique") == "high_quality": + DatasetService.check_embedding_model_setting( + dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") + ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( - current_user, dataset, data.get('permission'), data.get('partial_member_list') + current_user, dataset, data.get("permission"), data.get("partial_member_list") ) - dataset = DatasetService.update_dataset( - dataset_id_str, args, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) if dataset is None: raise NotFound("Dataset not found.") @@ -234,16 +290,19 @@ def patch(self, dataset_id): result_data = marshal(dataset, dataset_detail_fields) tenant_id = current_user.current_tenant_id - if data.get('partial_member_list') and data.get('permission') == 'partial_members': + if data.get("partial_member_list") and data.get("permission") == "partial_members": DatasetPermissionService.update_partial_member_list( - tenant_id, dataset_id_str, data.get('partial_member_list') + tenant_id, dataset_id_str, data.get("partial_member_list") ) # clear partial member list when permission is only_me or all_team_members - elif data.get('permission') == 'only_me' or data.get('permission') == 'all_team_members': + elif ( + data.get("permission") == DatasetPermissionEnum.ONLY_ME + or data.get("permission") == DatasetPermissionEnum.ALL_TEAM + ): DatasetPermissionService.clear_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - result_data.update({'partial_member_list': partial_member_list}) + result_data.update({"partial_member_list": partial_member_list}) return result_data, 200 @@ -260,12 +319,13 @@ def delete(self, dataset_id): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() + class DatasetUseCheckApi(Resource): @setup_required @login_required @@ -274,10 +334,10 @@ def get(self, dataset_id): dataset_id_str = str(dataset_id) dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) - return {'is_using': dataset_is_using}, 200 + return {"is_using": dataset_is_using}, 200 -class DatasetQueryApi(Resource): 
+class DatasetQueryApi(Resource): @setup_required @login_required @account_initialization_required @@ -292,51 +352,53 @@ def get(self, dataset_id): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) - dataset_queries, total = DatasetService.get_dataset_queries( - dataset_id=dataset.id, - page=page, - per_page=limit - ) + dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) response = { - 'data': marshal(dataset_queries, dataset_query_detail_fields), - 'has_more': len(dataset_queries) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(dataset_queries, dataset_query_detail_fields), + "has_more": len(dataset_queries) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 class DatasetIndexingEstimateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('indexing_technique', type=str, required=True, - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') + parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument( + "indexing_technique", + type=str, + required=True, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + location="json", + ) + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) extract_settings = [] - if args['info_list']['data_source_type'] == 'upload_file': - file_ids = args['info_list']['file_info_list']['file_ids'] - file_details = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id.in_(file_ids) - ).all() + if args["info_list"]["data_source_type"] == "upload_file": + file_ids = args["info_list"]["file_info_list"]["file_ids"] + file_details = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) + .all() + ) if file_details is None: raise NotFound("File not found.") @@ -344,55 +406,58 @@ def post(self): if file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=args['doc_form'] + datasource_type="upload_file", upload_file=file_detail, 
document_model=args["doc_form"] ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'notion_import': - notion_info_list = args['info_list']['notion_info_list'] + elif args["info_list"]["data_source_type"] == "notion_import": + notion_info_list = args["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'website_crawl': - website_info_list = args['info_list']['website_info_list'] - for url in website_info_list['urls']: + elif args["info_list"]["data_source_type"] == "website_crawl": + website_info_list = args["info_list"]["website_info_list"] + for url in website_info_list["urls"]: extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": website_info_list['provider'], - "job_id": website_info_list['job_id'], + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], "url": url, "tenant_id": current_user.current_tenant_id, - "mode": 'crawl', - "only_main_content": website_info_list['only_main_content'] + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + args["dataset_id"], + args["indexing_technique"], + ) except LLMBadRequestError: raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." 
+ ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -402,7 +467,6 @@ def post(self): class DatasetRelatedAppListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -426,52 +490,52 @@ def get(self, dataset_id): if app_model: related_apps.append(app_model) - return { - 'data': related_apps, - 'total': len(related_apps) - }, 200 + return {"data": related_apps, "total": len(related_apps)}, 200 class DatasetIndexingStatusApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) + .all() + ) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DatasetApiKeyApi(Resource): max_keys = 10 - token_prefix = 'dataset-' - resource_type = 'dataset' + token_prefix = "dataset-" + resource_type = "dataset" @setup_required @login_required @account_initialization_required @marshal_with(api_key_list) def get(self): - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - all() + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .all() + ) return {"items": keys} @setup_required @@ -483,15 +547,17 @@ def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). 
\ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -505,7 +571,7 @@ def post(self): class DatasetApiDeleteApi(Resource): - resource_type = 'dataset' + resource_type = "dataset" @setup_required @login_required @@ -517,18 +583,23 @@ def delete(self, api_key_id): if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, - ApiToken.id == api_key_id). \ - first() + key = ( + db.session.query(ApiToken) + .filter( + ApiToken.tenant_id == current_user.current_tenant_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DatasetApiBaseUrlApi(Resource): @@ -536,10 +607,7 @@ class DatasetApiBaseUrlApi(Resource): @login_required @account_initialization_required def get(self): - return { - 'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' - } + return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"} class DatasetRetrievalSettingApi(Resource): @@ -549,15 +617,35 @@ class DatasetRetrievalSettingApi(Resource): def get(self): vector_type = dify_config.VECTOR_STORE match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: - return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.PGVECTOR + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + | VectorType.PGVECTO_RS + | VectorType.BAIDU + | VectorType.VIKINGDB + | VectorType.UPSTASH + | VectorType.OCEANBASE + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + | VectorType.PGVECTOR + | VectorType.TIDB_ON_QDRANT + | VectorType.LINDORM + | VectorType.COUCHBASE + ): return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -573,15 +661,33 @@ class DatasetRetrievalSettingMockApi(Resource): @account_initialization_required def get(self, vector_type): match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: - return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | 
VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + | VectorType.PGVECTO_RS + | VectorType.BAIDU + | VectorType.VIKINGDB + | VectorType.UPSTASH + | VectorType.OCEANBASE + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + | VectorType.COUCHBASE + | VectorType.PGVECTOR + | VectorType.LINDORM + ): return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -591,7 +697,6 @@ def get(self, vector_type): raise ValueError(f"Unsupported vector db type {vector_type}.") - class DatasetErrorDocs(Resource): @setup_required @login_required @@ -603,10 +708,7 @@ def get(self, dataset_id): raise NotFound("Dataset not found.") results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) - return { - 'data': [marshal(item, document_status_fields) for item in results], - 'total': len(results) - }, 200 + return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 class DatasetPermissionUserListApi(Resource): @@ -626,21 +728,21 @@ def get(self, dataset_id): partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) return { - 'data': partial_members_list, + "data": partial_members_list, }, 200 -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>') -api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check') -api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries') -api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs') -api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') -api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps') -api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status') -api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') -api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>') -api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') -api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') -api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>') -api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users') +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") +api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") +api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries") +api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs") +api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") +api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps") +api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status") +api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") +api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>") +api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") +api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") +api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") +api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") diff --git
a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 976b97660ae293..60848039c5282c 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -24,8 +24,11 @@ InvalidActionError, InvalidMetadataError, ) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.errors.error import ( LLMBadRequestError, ModelCurrentlyNotSupportError, @@ -46,8 +49,7 @@ document_with_segments_fields, ) from libs.login import login_required -from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment -from models.model import UploadFile +from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task @@ -57,7 +59,7 @@ class DocumentResource(Resource): def get_document(self, dataset_id: str, document_id: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -67,17 +69,17 @@ def get_document(self, dataset_id: str, document_id: str) -> Document: document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') + raise Forbidden("No permission.") return document def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -87,7 +89,7 @@ def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") return documents @@ -99,11 +101,11 @@ class GetProcessRuleApi(Resource): def get(self): req_data = request.args - document_id = req_data.get('document_id') + document_id = req_data.get("document_id") # get default rules - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] if document_id: # get the latest process rule document = Document.query.get_or_404(document_id) @@ -111,7 +113,7 @@ def get(self): dataset = DatasetService.get_dataset(document.dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -119,19 +121,18 @@ def get(self): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == document.dataset_id). 
\ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == document.dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict - return { - 'mode': mode, - 'rules': rules - } + return {"mode": mode, "rules": rules} class DatasetDocumentListApi(Resource): @@ -140,49 +141,48 @@ class DatasetDocumentListApi(Resource): @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - sort = request.args.get('sort', default='-created_at', type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + sort = request.args.get("sort", default="-created_at", type=str) # "yes", "true", "t", "y", "1" convert to True, while others convert to False. try: - fetch = string_to_bool(request.args.get('fetch', default='false')) + fetch = string_to_bool(request.args.get("fetch", default="false")) except (ArgumentTypeError, ValueError, Exception) as e: fetch = False dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) - if sort.startswith('-'): + if sort.startswith("-"): sort_logic = desc sort = sort[1:] else: sort_logic = asc - if sort == 'hit_count': - sub_query = db.select(DocumentSegment.document_id, - db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \ - .group_by(DocumentSegment.document_id) \ + if sort == "hit_count": + sub_query = ( + db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + .group_by(DocumentSegment.document_id) .subquery() + ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \ - .order_by( - sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), - sort_logic(Document.position), - ) - elif sort == 'created_at': + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + elif sort == "created_at": query = query.order_by( sort_logic(Document.created_at), sort_logic(Document.position), @@ -193,48 +193,47 @@ def get(self, dataset_id): desc(Document.position), ) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - 
DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments data = marshal(documents, document_with_segments_fields) else: data = marshal(documents, document_fields) response = { - 'data': data, - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": data, + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response - documents_and_batch_fields = { - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String - } + documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} @setup_required @login_required @account_initialization_required @marshal_with(documents_and_batch_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self, dataset_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_dataset_editor: @@ -246,21 +245,22 @@ def post(self, dataset_id): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('data_source', type=dict, required=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, location='json') - parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("data_source", type=dict, required=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, location="json") + parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, 
nullable=False, location="json") args = parser.parse_args() - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") # validate args DocumentService.document_create_args_validate(args) @@ -274,51 +274,60 @@ def post(self, dataset_id): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return { - 'documents': documents, - 'batch': batch - } + return {"documents": documents, "batch": batch} class DatasetInitApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(dataset_and_document_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True, - nullable=False, location='json') - parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", + type=str, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + required=True, + nullable=False, + location="json", + ) + parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") + parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - if args['indexing_technique'] == 'high_quality': + if args["indexing_technique"] == "high_quality": + if args["embedding_model"] is None or args["embedding_model_provider"] is None: + raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: model_manager = ModelManager() - model_manager.get_default_model_instance( + model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + provider=args["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=args["embedding_model"], ) except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. 
Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -327,9 +336,7 @@ def post(self): try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, - document_data=args, - account=current_user + tenant_id=current_user.current_tenant_id, document_data=args, account=current_user ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -338,17 +345,12 @@ def post(self): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = { - 'dataset': dataset, - 'documents': documents, - 'batch': batch - } + response = {"dataset": dataset, "documents": documents, "batch": batch} return response class DocumentIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -357,50 +359,49 @@ def get(self, dataset_id, document_id): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == document.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: - raise NotFound('File not found.') + raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting], - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + [extract_setting], + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." 
+ ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -410,7 +411,6 @@ def get(self, dataset_id, document_id): class DocumentBatchIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -418,13 +418,7 @@ def get(self, dataset_id, batch): dataset_id = str(dataset_id) batch = str(batch) documents = self.get_batch_documents(dataset_id, batch) - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} if not documents: return response data_process_rule = documents[0].dataset_process_rule @@ -432,82 +426,83 @@ def get(self, dataset_id, batch): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] info_list.append(file_id) # format document notion info - elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info: + elif ( + data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info + ): pages = [] - page = { - 'page_id': data_source_info['notion_page_id'], - 'type': data_source_info['type'] - } + page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]} pages.append(page) - notion_info = { - 'workspace_id': data_source_info['notion_workspace_id'], - 'pages': pages - } + notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages} info_list.append(notion_info) - if document.data_source_type == 'upload_file': - file_id = data_source_info['upload_file_id'] - file_detail = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id == file_id - ).first() + if document.data_source_type == "upload_file": + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) + .first() + ) if file_detail is None: raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) - elif document.data_source_type == 'notion_import': + elif document.data_source_type == "notion_import": extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], - "tenant_id": current_user.current_tenant_id + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=document.doc_form + 
document_model=document.doc_form, ) extract_settings.append(extract_setting) - elif document.data_source_type == 'website_crawl': + elif document.data_source_type == "website_crawl": extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], - "url": data_source_info['url'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], "tenant_id": current_user.current_tenant_id, - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -516,7 +511,6 @@ def get(self, dataset_id, batch): class DocumentBatchIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -526,24 +520,24 @@ def get(self, dataset_id, batch): documents = self.get_batch_documents(dataset_id, batch) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DocumentIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -552,25 +546,24 @@ def get(self, dataset_id, document_id): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - completed_segments = DocumentSegment.query \ - .filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() - total_segments = 
DocumentSegment.query \ - .filter(DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" return marshal(document, document_status_fields) class DocumentDetailApi(DocumentResource): - METADATA_CHOICES = {'all', 'only', 'without'} + METADATA_CHOICES = {"all", "only", "without"} @setup_required @login_required @@ -580,77 +573,75 @@ def get(self, dataset_id, document_id): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - metadata = request.args.get('metadata', 'all') + metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: - raise InvalidMetadataError(f'Invalid metadata value: {metadata}') + raise InvalidMetadataError(f"Invalid metadata value: {metadata}") - if metadata == 'only': - response = { - 'id': document.id, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata - } - elif metadata == 'without': + if metadata == "only": + response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} + elif metadata == "without": process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if 
document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, + "doc_language": document.doc_language, } else: process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "doc_type": document.doc_type, + "doc_metadata": document.doc_metadata, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, + "doc_language": document.doc_language, } return response, 200 @@ -671,7 +662,7 @@ def patch(self, dataset_id, document_id, action): if action == "pause": if document.indexing_status != "indexing": - raise InvalidActionError('Document not in indexing state.') + raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ 
-679,8 +670,8 @@ def patch(self, dataset_id, document_id, action): db.session.commit() elif action == "resume": - if document.indexing_status not in ["paused", "error"]: - raise InvalidActionError('Document not in paused or error state.') + if document.indexing_status not in {"paused", "error"}: + raise InvalidActionError("Document not in paused or error state.") document.paused_by = None document.paused_at = None @@ -689,7 +680,7 @@ def patch(self, dataset_id, document_id, action): else: raise InvalidActionError() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentDeleteApi(DocumentResource): @@ -710,9 +701,9 @@ def delete(self, dataset_id, document_id): try: DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentMetadataApi(DocumentResource): @@ -726,26 +717,26 @@ def put(self, dataset_id, document_id): req_data = request.get_json() - doc_type = req_data.get('doc_type') - doc_metadata = req_data.get('doc_metadata') + doc_type = req_data.get("doc_type") + doc_metadata = req_data.get("doc_metadata") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() if doc_type is None or doc_metadata is None: - raise ValueError('Both doc_type and doc_metadata must be provided.') + raise ValueError("Both doc_type and doc_metadata must be provided.") if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: - raise ValueError('Invalid doc_type.') + raise ValueError("Invalid doc_type.") if not isinstance(doc_metadata, dict): - raise ValueError('doc_metadata must be a dictionary.') + raise ValueError("doc_metadata must be a dictionary.") metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] document.doc_metadata = {} - if doc_type == 'others': + if doc_type == "others": document.doc_metadata = doc_metadata else: for key, value_type in metadata_schema.items(): @@ -757,14 +748,14 @@ def put(self, dataset_id, document_id): document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success', 'message': 'Document metadata updated.'}, 200 + return {"result": "success", "message": "Document metadata updated."}, 200 class DocumentStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) document_id = str(document_id) @@ -784,14 +775,14 @@ def patch(self, dataset_id, document_id, action): document = self.get_document(dataset_id, document_id) - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") if action == "enable": if document.enabled: - raise InvalidActionError('Document already enabled.') + raise InvalidActionError("Document already enabled.") document.enabled = True document.disabled_at = None @@ -804,13 +795,13 @@ def patch(self, dataset_id, document_id, action): add_document_to_index_task.delay(document_id) - 
return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": - if not document.completed_at or document.indexing_status != 'completed': - raise InvalidActionError('Document is not completed.') + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError("Document is not completed.") if not document.enabled: - raise InvalidActionError('Document already disabled.') + raise InvalidActionError("Document already disabled.") document.enabled = False document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -823,11 +814,11 @@ def patch(self, dataset_id, document_id, action): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "archive": if document.archived: - raise InvalidActionError('Document already archived.') + raise InvalidActionError("Document already archived.") document.archived = True document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -841,10 +832,10 @@ def patch(self, dataset_id, document_id, action): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "un_archive": if not document.archived: - raise InvalidActionError('Document is not archived.') + raise InvalidActionError("Document is not archived.") document.archived = False document.archived_at = None @@ -857,13 +848,12 @@ def patch(self, dataset_id, document_id, action): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() class DocumentPauseApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -874,7 +864,7 @@ def patch(self, dataset_id, document_id): dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) @@ -890,9 +880,9 @@ def patch(self, dataset_id, document_id): # pause document DocumentService.pause_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot pause completed document.') + raise DocumentIndexingError("Cannot pause completed document.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRecoverApi(DocumentResource): @@ -905,7 +895,7 @@ def patch(self, dataset_id, document_id): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) # 404 if document not found @@ -919,9 +909,9 @@ def patch(self, dataset_id, document_id): # pause document DocumentService.recover_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Document is not in paused status.') + raise DocumentIndexingError("Document is not in paused status.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRetryApi(DocumentResource): @@ -932,15 +922,14 @@ def post(self, dataset_id): """retry document.""" parser = reqparse.RequestParser() - parser.add_argument('document_ids', type=list, required=True, nullable=False, - location='json') + parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") args = 
parser.parse_args() dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) retry_documents = [] if not dataset: - raise NotFound('Dataset not found.') - for document_id in args['document_ids']: + raise NotFound("Dataset not found.") + for document_id in args["document_ids"]: try: document_id = str(document_id) @@ -955,16 +944,16 @@ def post(self, dataset_id): raise ArchivedDocumentImmutableError() # 400 if document is completed - if document.indexing_status == 'completed': + if document.indexing_status == "completed": raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception as e: - logging.error(f"Document {document_id} retry failed: {str(e)}") + logging.exception(f"Document {document_id} retry failed: {str(e)}") continue # retry document DocumentService.retry_document(dataset_id, retry_documents) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRenameApi(DocumentResource): @@ -979,13 +968,13 @@ def post(self, dataset_id, document_id): dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_operator_permission(current_user, dataset) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - document = DocumentService.rename_document(dataset_id, document_id, args['name']) + document = DocumentService.rename_document(dataset_id, document_id, args["name"]) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") return document @@ -999,51 +988,43 @@ def get(self, dataset_id, document_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') - if document.data_source_type != 'website_crawl': - raise ValueError('Document is not a website document.') + raise Forbidden("No permission.") + if document.data_source_type != "website_crawl": + raise ValueError("Document is not a website document.") # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() # sync document DocumentService.sync_website_document(dataset_id, document) - return {'result': 'success'}, 200 - - -api.add_resource(GetProcessRuleApi, '/datasets/process-rule') -api.add_resource(DatasetDocumentListApi, - '/datasets/<uuid:dataset_id>/documents') -api.add_resource(DatasetInitApi, - '/datasets/init') -api.add_resource(DocumentIndexingEstimateApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate') -api.add_resource(DocumentBatchIndexingEstimateApi, - '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate') -api.add_resource(DocumentBatchIndexingStatusApi, - '/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status') -api.add_resource(DocumentIndexingStatusApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status') -api.add_resource(DocumentDetailApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>') -api.add_resource(DocumentProcessingApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>') -api.add_resource(DocumentDeleteApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
-api.add_resource(DocumentMetadataApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata') -api.add_resource(DocumentStatusApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>') -api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause') -api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume') -api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry') -api.add_resource(DocumentRenameApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename') - -api.add_resource(WebsiteDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync') + return {"result": "success"}, 200 + + +api.add_resource(GetProcessRuleApi, "/datasets/process-rule") +api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents") +api.add_resource(DatasetInitApi, "/datasets/init") +api.add_resource( + DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate" +) +api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate") +api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status") +api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status") +api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") +api.add_resource( + DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>" +) +api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") +api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata") +api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>") +api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause") +api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume") +api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry") +api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename") + +api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index a4210d5a0c26ff..5d8d664e414b94 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -11,11 +11,11 @@ from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError -from controllers.console.setup import setup_required from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_resource_check, + setup_required, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager @@ -24,7 +24,7 @@ from extensions.ext_redis import redis_client from fields.segment_fields import segment_fields from libs.login import login_required -from models.dataset import DocumentSegment +from models import DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task @@ -40,7 +40,7 @@ def get(self, dataset_id, document_id): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -50,37 +50,33 @@ def get(self, dataset_id, document_id): document = 
DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument('last_id', type=str, default=None, location='args') - parser.add_argument('limit', type=int, default=20, location='args') - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('hit_count_gte', type=int, - default=None, location='args') - parser.add_argument('enabled', type=str, default='all', location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("last_id", type=str, default=None, location="args") + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("hit_count_gte", type=int, default=None, location="args") + parser.add_argument("enabled", type=str, default="all", location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - last_id = args['last_id'] - limit = min(args['limit'], 100) - status_list = args['status'] - hit_count_gte = args['hit_count_gte'] - keyword = args['keyword'] + last_id = args["last_id"] + limit = min(args["limit"], 100) + status_list = args["status"] + hit_count_gte = args["hit_count_gte"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if last_id is not None: last_segment = db.session.get(DocumentSegment, str(last_id)) if last_segment: - query = query.filter( - DocumentSegment.position > last_segment.position) + query = query.filter(DocumentSegment.position > last_segment.position) else: - return {'data': [], 'has_more': False, 'limit': limit}, 200 + return {"data": [], "has_more": False, "limit": limit}, 200 if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -89,12 +85,12 @@ def get(self, dataset_id, document_id): query = query.filter(DocumentSegment.hit_count >= hit_count_gte) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) - if args['enabled'].lower() != 'all': - if args['enabled'].lower() == 'true': + if args["enabled"].lower() != "all": + if args["enabled"].lower() == "true": query = query.filter(DocumentSegment.enabled == True) - elif args['enabled'].lower() == 'false': + elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) total = query.count() @@ -106,11 +102,11 @@ def get(self, dataset_id, document_id): segments = segments[:-1] return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'has_more': has_more, - 'limit': limit, - 'total': total + "data": marshal(segments, segment_fields), + "doc_form": document.doc_form, + "has_more": has_more, + "limit": limit, + "total": total, }, 200 @@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, segment_id, action): dataset_id = str(dataset_id) dataset = 
DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -134,7 +130,7 @@ def patch(self, dataset_id, segment_id, action): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -142,32 +138,32 @@ def patch(self, dataset_id, segment_id, action): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable or disable function is not allowed') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable or disable function is not allowed") - document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Segment is being indexed, please try again later") @@ -186,7 +182,7 @@ def patch(self, dataset_id, segment_id, action): enable_segment_to_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": if not segment.enabled: raise InvalidActionError("Segment is already disabled.") @@ -201,7 +197,7 @@ def patch(self, dataset_id, segment_id, action): disable_segment_from_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() @@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset 
not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if not current_user.is_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) try: @@ -247,37 +244,34 @@ def post(self, dataset_id, document_id): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.create_segment(args, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 class DatasetDocumentSegmentUpdateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -285,22 +279,22 @@ def patch(self, dataset_id, document_id, segment_id): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." 
+ ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -310,16 +304,13 @@ def patch(self, dataset_id, document_id, segment_id): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.update_segment(args, segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @login_required @@ -329,22 +320,21 @@ def delete(self, dataset_id, document_id, segment_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin or owner if not current_user.is_editor: raise Forbidden() @@ -353,36 +343,36 @@ def delete(self, dataset_id, document_id, segment_id): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise 
NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -390,51 +380,47 @@ def post(self, dataset_id, document_id): df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - if document.doc_form == 'qa_model': - data = {'content': row[0], 'answer': row[1]} + if document.doc_form == "qa_model": + data = {"content": row[0], "answer": row[1]} else: - data = {'content': row[0]} + data = {"content": row[0]} result.append(data) if len(result) == 0: raise ValueError("The CSV file is empty.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "segment_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_create_segment_to_index_task.delay( + str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return {'error': str(e)}, 500 - return { - 'job_id': job_id, - 'job_status': 'waiting' - }, 200 + return {"error": str(e)}, 500 + return {"job_id": job_id, "job_status": "waiting"}, 200 @setup_required @login_required @account_initialization_required def get(self, job_id): job_id = str(job_id) - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") - return { - 'job_id': job_id, - 'job_status': cache_result.decode() - }, 200 + return {"job_id": job_id, "job_status": cache_result.decode()}, 200 -api.add_resource(DatasetDocumentSegmentListApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments') -api.add_resource(DatasetDocumentSegmentApi, - '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>') -api.add_resource(DatasetDocumentSegmentAddApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment') -api.add_resource(DatasetDocumentSegmentUpdateApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>') -api.add_resource(DatasetDocumentSegmentBatchImportApi, - '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import', - '/datasets/batch_import_status/<uuid:job_id>') +api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") +api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>") +api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment") +api.add_resource( + DatasetDocumentSegmentUpdateApi, + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>", +) +api.add_resource( + DatasetDocumentSegmentBatchImportApi, + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import", + "/datasets/batch_import_status/<uuid:job_id>", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 9270b610c28e8c..6a7a3971a8b33f 100644 --- 
a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -2,90 +2,90 @@ class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class WebsiteCrawlError(BaseHTTPException): - error_code = 'crawl_failed' + error_code = "crawl_failed" description = "{message}" code = 500 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." 
code = 409 class IndexingEstimateError(BaseHTTPException): - error_code = 'indexing_estimate_error' + error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py new file mode 100644 index 00000000000000..bc6e3687c1c99d --- /dev/null +++ b/api/controllers/console/datasets/external.py @@ -0,0 +1,262 @@ +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from controllers.console import api +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.wraps import account_initialization_required, setup_required +from fields.dataset_fields import dataset_detail_fields +from libs.login import login_required +from services.dataset_service import DatasetService +from services.external_knowledge_service import ExternalDatasetService +from services.hit_testing_service import HitTestingService +from services.knowledge_service import ExternalDatasetTestService + + +def _validate_name(name): + if not name or len(name) < 1 or len(name) > 100: + raise ValueError("Name must be between 1 and 100 characters.") + return name + + +def _validate_description_length(description): + if description and len(description) > 400: + raise ValueError("Description cannot exceed 400 characters.") + return description + + +class ExternalApiTemplateListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + + external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis( + page, limit, current_user.current_tenant_id, search + ) + response = { + "data": [item.to_dict() for item in external_knowledge_apis], + "has_more": len(external_knowledge_apis) == limit, + "limit": limit, + "total": total, + "page": page, + } + return response, 200 + + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name is required. 
Name must be between 1 and 100 characters.", + type=_validate_name, + ) + parser.add_argument( + "settings", + type=dict, + location="json", + nullable=False, + required=True, + ) + args = parser.parse_args() + + ExternalDatasetService.validate_api_list(args["settings"]) + + # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator + if not current_user.is_dataset_editor: + raise Forbidden() + + try: + external_knowledge_api = ExternalDatasetService.create_external_knowledge_api( + tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args + ) + except services.errors.dataset.DatasetNameDuplicateError: + raise DatasetNameDuplicateError() + + return external_knowledge_api.to_dict(), 201 + + +class ExternalApiTemplateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, external_knowledge_api_id): + external_knowledge_api_id = str(external_knowledge_api_id) + external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id) + if external_knowledge_api is None: + raise NotFound("API template not found.") + + return external_knowledge_api.to_dict(), 200 + + @setup_required + @login_required + @account_initialization_required + def patch(self, external_knowledge_api_id): + external_knowledge_api_id = str(external_knowledge_api_id) + + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name is required. Name must be between 1 and 100 characters.", + type=_validate_name, + ) + parser.add_argument( + "settings", + type=dict, + location="json", + nullable=False, + required=True, + ) + args = parser.parse_args() + ExternalDatasetService.validate_api_list(args["settings"]) + + external_knowledge_api = ExternalDatasetService.update_external_knowledge_api( + tenant_id=current_user.current_tenant_id, + user_id=current_user.id, + external_knowledge_api_id=external_knowledge_api_id, + args=args, + ) + + return external_knowledge_api.to_dict(), 200 + + @setup_required + @login_required + @account_initialization_required + def delete(self, external_knowledge_api_id): + external_knowledge_api_id = str(external_knowledge_api_id) + + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor or current_user.is_dataset_operator: + raise Forbidden() + + ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id) + return {"result": "success"}, 200 + + +class ExternalApiUseCheckApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, external_knowledge_api_id): + external_knowledge_api_id = str(external_knowledge_api_id) + + external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check( + external_knowledge_api_id + ) + return {"is_using": external_knowledge_api_is_using, "count": count}, 200 + + +class ExternalDatasetCreateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + # The role of the current user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json") + parser.add_argument( 
"name", + nullable=False, + required=True, + help="Name is required. Name must be between 1 and 100 characters.", + type=_validate_name, + ) + parser.add_argument("description", type=str, required=False, nullable=True, location="json") + parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") + + args = parser.parse_args() + + # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator + if not current_user.is_dataset_editor: + raise Forbidden() + + try: + dataset = ExternalDatasetService.create_external_dataset( + tenant_id=current_user.current_tenant_id, + user_id=current_user.id, + args=args, + ) + except services.errors.dataset.DatasetNameDuplicateError: + raise DatasetNameDuplicateError() + + return marshal(dataset, dataset_detail_fields), 201 + + +class ExternalKnowledgeHitTestingApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, dataset_id): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + parser = reqparse.RequestParser() + parser.add_argument("query", type=str, location="json") + parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") + args = parser.parse_args() + + HitTestingService.hit_testing_args_check(args) + + try: + response = HitTestingService.external_retrieve( + dataset=dataset, + query=args["query"], + account=current_user, + external_retrieval_model=args["external_retrieval_model"], + ) + + return response + except Exception as e: + raise InternalServerError(str(e)) + + +class BedrockRetrievalApi(Resource): + # this api is only for internal testing + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") + parser.add_argument( + "query", + nullable=False, + required=True, + type=str, + ) + parser.add_argument("knowledge_id", nullable=False, required=True, type=str) + args = parser.parse_args() + + # Call the knowledge retrieval service + result = ExternalDatasetTestService.knowledge_retrieval( + args["retrieval_setting"], args["query"], args["knowledge_id"] + ) + return result, 200 + + +api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing") +api.add_resource(ExternalDatasetCreateApi, "/datasets/external") +api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api") +api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>") +api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check") +# this api is only for internal test
api.add_resource(BedrockRetrievalApi, "/test/retrieval") diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py deleted file mode 100644 index 3b2083bcc3351c..00000000000000 --- a/api/controllers/console/datasets/file.py +++ /dev/null @@ -1,87 +0,0 @@ -from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with - -import services -from configs import dify_config -from controllers.console import api -from controllers.console.datasets.error import ( - FileTooLargeError, - NoFileUploadedError, - TooManyFilesError, - UnsupportedFileTypeError, -) 
-from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from fields.file_fields import file_fields, upload_config_fields -from libs.login import login_required -from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService - -PREVIEW_WORDS_LIMIT = 3000 - - -class FileApi(Resource): - - @setup_required - @login_required - @account_initialization_required - @marshal_with(upload_config_fields) - def get(self): - file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT - batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT - image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - return { - 'file_size_limit': file_size_limit, - 'batch_count_limit': batch_count_limit, - 'image_file_size_limit': image_file_size_limit - }, 200 - - @setup_required - @login_required - @account_initialization_required - @marshal_with(file_fields) - @cloud_edition_billing_resource_check(resource='documents') - def post(self): - - # get file from request - file = request.files['file'] - - # check file - if 'file' not in request.files: - raise NoFileUploadedError() - - if len(request.files) > 1: - raise TooManyFilesError() - try: - upload_file = FileService.upload_file(file, current_user) - except services.errors.file.FileTooLargeError as file_too_large_error: - raise FileTooLargeError(file_too_large_error.description) - except services.errors.file.UnsupportedFileTypeError: - raise UnsupportedFileTypeError() - - return upload_file, 201 - - -class FilePreviewApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self, file_id): - file_id = str(file_id) - text = FileService.get_file_preview(file_id) - return {'content': text} - - -class FileSupportTypeApi(Resource): - @setup_required - @login_required - @account_initialization_required - def get(self): - etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS - return {'allowed_extensions': allowed_extensions} - - -api.add_resource(FileApi, '/files/upload') -api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview') -api.add_resource(FileSupportTypeApi, '/files/support-type') diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 8771bf909ed650..495f511275b4b9 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,86 +1,23 @@ -import logging +from flask_restful import Resource -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound - -import services from controllers.console import api -from controllers.console.app.error import ( - CompletionRequestError, - ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, - ProviderQuotaExceededError, -) -from controllers.console.datasets.error import DatasetNotInitializedError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required -from core.errors.error import ( - LLMBadRequestError, - ModelCurrentlyNotSupportError, - ProviderTokenNotInitError, - QuotaExceededError, -) -from core.model_runtime.errors.invoke import InvokeError -from fields.hit_testing_fields import hit_testing_record_fields +from controllers.console.datasets.hit_testing_base 
import DatasetsHitTestingBase +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required -from services.dataset_service import DatasetService -from services.hit_testing_service import HitTestingService - -class HitTestingApi(Resource): +class HitTestingApi(Resource, DatasetsHitTestingBase): @setup_required @login_required @account_initialization_required def post(self, dataset_id): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) - if dataset is None: - raise NotFound("Dataset not found.") - - try: - DatasetService.check_dataset_permission(dataset, current_user) - except services.errors.account.NoPermissionError as e: - raise Forbidden(str(e)) - - parser = reqparse.RequestParser() - parser.add_argument('query', type=str, location='json') - parser.add_argument('retrieval_model', type=dict, required=False, location='json') - args = parser.parse_args() - - HitTestingService.hit_testing_args_check(args) - - try: - response = HitTestingService.retrieve( - dataset=dataset, - query=args['query'], - account=current_user, - retrieval_model=args['retrieval_model'], - limit=10 - ) + dataset = self.get_and_validate_dataset(dataset_id_str) + args = self.parse_args() + self.hit_testing_args_check(args) - return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} - except services.errors.index.IndexNotInitializedError: - raise DatasetNotInitializedError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model or Reranking Model available. 
Please configure a valid provider " - "in the Settings -> Model Provider.") - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise ValueError(str(e)) - except Exception as e: - logging.exception("Hit testing failed.") - raise InternalServerError(str(e)) + return self.perform_hit_testing(dataset, args) -api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing') +api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py new file mode 100644 index 00000000000000..3b4c07686361d0 --- /dev/null +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -0,0 +1,85 @@ +import logging + +from flask_login import current_user +from flask_restful import marshal, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services.dataset_service +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.datasets.error import DatasetNotInitializedError +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from core.model_runtime.errors.invoke import InvokeError +from fields.hit_testing_fields import hit_testing_record_fields +from services.dataset_service import DatasetService +from services.hit_testing_service import HitTestingService + + +class DatasetsHitTestingBase: + @staticmethod + def get_and_validate_dataset(dataset_id: str): + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + return dataset + + @staticmethod + def hit_testing_args_check(args): + HitTestingService.hit_testing_args_check(args) + + @staticmethod + def parse_args(): + parser = reqparse.RequestParser() + + parser.add_argument("query", type=str, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, location="json") + parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") + return parser.parse_args() + + @staticmethod + def perform_hit_testing(dataset, args): + try: + response = HitTestingService.retrieve( + dataset=dataset, + query=args["query"], + account=current_user, + retrieval_model=args["retrieval_model"], + external_retrieval_model=args["external_retrieval_model"], + limit=10, + ) + return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} + except services.errors.index.IndexNotInitializedError: + raise DatasetNotInitializedError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model or Reranking Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise ValueError(str(e)) + except Exception as e: + logging.exception("Hit testing failed.") + raise InternalServerError(str(e)) diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index bbd91256f1c29c..9127c8af455f6c 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -2,23 +2,22 @@ from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.website_service import WebsiteService class WebsiteCrawlApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], - required=True, nullable=True, location='json') - parser.add_argument('url', type=str, required=True, nullable=True, location='json') - parser.add_argument('options', type=dict, required=True, nullable=True, location='json') + parser.add_argument( + "provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json" + ) + parser.add_argument("url", type=str, required=True, nullable=True, location="json") + parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() WebsiteService.document_create_args_validate(args) # crawl url @@ -35,15 +34,15 @@ class WebsiteCrawlStatusApi(Resource): @account_initialization_required def get(self, job_id: str): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args') + parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args") args = parser.parse_args() # get crawl status try: - result = WebsiteService.get_crawl_status(job_id, args['provider']) + result = WebsiteService.get_crawl_status(job_id, args["provider"]) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 -api.add_resource(WebsiteCrawlApi, '/website/crawl') -api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>') +api.add_resource(WebsiteCrawlApi, "/website/crawl") +api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>") diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 888dad83ccda84..e0630ca66cc7c0 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -2,35 +2,87 @@ class AlreadySetupError(BaseHTTPException): - error_code = 'already_setup' + error_code = "already_setup" description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage." code = 403 class NotSetupError(BaseHTTPException): - error_code = 'not_setup' - description = "Dify has not been initialized and installed yet. " \ - "Please proceed with the initialization and installation process first." + error_code = "not_setup" + description = ( + "Dify has not been initialized and installed yet. " + "Please proceed with the initialization and installation process first."
+ ) code = 401 + class NotInitValidateError(BaseHTTPException): - error_code = 'not_init_validated' - description = "Init validation has not been completed yet. " \ - "Please proceed with the init validation process first." + error_code = "not_init_validated" + description = "Init validation has not been completed yet. Please proceed with the init validation process first." code = 401 + class InitValidateFailedError(BaseHTTPException): - error_code = 'init_validate_failed' + error_code = "init_validate_failed" description = "Init validation failed. Please check the password and try again." code = 401 + class AccountNotLinkTenantError(BaseHTTPException): - error_code = 'account_not_link_tenant' + error_code = "account_not_link_tenant" description = "Account not link tenant." code = 403 class AlreadyActivateError(BaseHTTPException): - error_code = 'already_activate' + error_code = "already_activate" description = "Auth Token is invalid or account already activated, please check again." code = 403 + + +class NotAllowedCreateWorkspace(BaseHTTPException): + error_code = "not_allowed_create_workspace" + description = "Workspace not found, please contact system admin to invite you to join in a workspace." + code = 400 + + +class AccountBannedError(BaseHTTPException): + error_code = "account_banned" + description = "Account is banned." + code = 400 + + +class NotAllowedRegister(BaseHTTPException): + error_code = "unauthorized" + description = "Account not found." + code = 400 + + +class EmailSendIpLimitError(BaseHTTPException): + error_code = "email_send_ip_limit" + description = "Too many emails have been sent from this IP address recently. Please try again later." + code = 429 + + +class FileTooLargeError(BaseHTTPException): + error_code = "file_too_large" + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = "unsupported_file_type" + description = "File type not allowed." + code = 415 + + +class TooManyFilesError(BaseHTTPException): + error_code = "too_many_files" + description = "Only one file is allowed." + code = 400 + + +class NoFileUploadedError(BaseHTTPException): + error_code = "no_file_uploaded" + description = "Please upload your file." 
+ code = 400 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 27cc83042a7822..9690677f61b1c2 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=None - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -76,30 +72,27 @@ def post(self, installed_app): app_model = installed_app.app try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") or text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - voice=voice, - text=text - ) + response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -127,7 +120,7 @@ def post(self, installed_app): raise InternalServerError() -api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio') -api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text') +api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text") # api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id', # endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 869b56e13bf939..125bc1af8c41ec 100644 --- a/api/controllers/console/explore/completion.py +++ 
b/api/controllers/console/explore/completion.py @@ -30,33 +30,28 @@ # define completion api for user class CompletionApi(InstalledAppResource): - def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) return helper.compact_generate_response(response) @@ -85,41 +80,38 @@ def post(self, installed_app): class CompletionStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - args['auto_generate_name'] = 
False + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -149,15 +141,27 @@ class ChatStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion') -api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion') -api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion') -api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion') +api.add_resource( + CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop", + endpoint="installed_app_stop_chat_completion", +) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index ea0fa4e17e9d44..6f9d7769b942ce 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -16,30 +16,29 @@ class ConversationListApi(InstalledAppResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") args = parser.parse_args() pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=current_user, - last_id=args['last_id'], - limit=args['limit'], + 
last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.EXPLORE, pinned=pinned, ) @@ -51,7 +50,7 @@ class ConversationApi(InstalledAppResource): def delete(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -65,39 +64,33 @@ def delete(self, installed_app, c_id): class ConversationRenameApi(InstalledAppResource): - @marshal_with(simple_conversation_fields) def post(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: return ConversationService.rename( - app_model, - conversation_id, - current_user, - args['name'], - args['auto_generate'] + app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(InstalledAppResource): - def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -114,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -123,8 +116,26 @@ def patch(self, installed_app, c_id): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename') -api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations') -api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation') -api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin') -api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin') +api.add_resource( + ConversationRenameApi, + "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + 
"/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 9c3216ecc8c178..18221b7797cdb0 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -2,24 +2,24 @@ class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Not Completion App" code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "App mode is invalid." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Only support workflow app." code = 400 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "Function Suggested questions after answer disabled." code = 403 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index ec7bbed3074ad6..d72715a38c5b04 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -11,7 +11,7 @@ from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields from libs.login import login_required -from models.model import App, InstalledApp, RecommendedApp +from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService @@ -21,72 +21,72 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): current_tenant_id = current_user.current_tenant_id - installed_apps = db.session.query(InstalledApp).filter( - InstalledApp.tenant_id == current_tenant_id - ).all() + installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_apps = [ { - 'id': installed_app.id, - 'app': installed_app.app, - 'app_owner_tenant_id': installed_app.app_owner_tenant_id, - 'is_pinned': installed_app.is_pinned, - 'last_used_at': installed_app.last_used_at, - 'editable': current_user.role in ["owner", "admin"], - 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id + "id": installed_app.id, + "app": installed_app.app, + "app_owner_tenant_id": installed_app.app_owner_tenant_id, + "is_pinned": installed_app.is_pinned, + "last_used_at": installed_app.last_used_at, + "editable": current_user.role in {"owner", "admin"}, + "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps + if installed_app.app is not None ] - installed_apps.sort(key=lambda app: (-app['is_pinned'], - app['last_used_at'] is None, - -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0)) + installed_apps.sort( + key=lambda app: ( + -app["is_pinned"], + app["last_used_at"] is None, + -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0, + ) + ) - return {'installed_apps': installed_apps} + return {"installed_apps": installed_apps} @login_required @account_initialization_required - 
@cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') + parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: - raise NotFound('App not found') + raise NotFound("App not found") current_tenant_id = current_user.current_tenant_id - app = db.session.query(App).filter( - App.id == args['app_id'] - ).first() + app = db.session.query(App).filter(App.id == args["app_id"]).first() if app is None: - raise NotFound('App not found') + raise NotFound("App not found") if not app.is_public: - raise Forbidden('You can\'t install a non-public app') + raise Forbidden("You can't install a non-public app") - installed_app = InstalledApp.query.filter(and_( - InstalledApp.app_id == args['app_id'], - InstalledApp.tenant_id == current_tenant_id - )).first() + installed_app = InstalledApp.query.filter( + and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) + ).first() if installed_app is None: # todo: position recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args['app_id'], + app_id=args["app_id"], tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.now(timezone.utc).replace(tzinfo=None) + last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(new_installed_app) db.session.commit() - return {'message': 'App installed successfully'} + return {"message": "App installed successfully"} class InstalledAppApi(InstalledAppResource): @@ -94,30 +94,31 @@ class InstalledAppApi(InstalledAppResource): update and delete an installed app use InstalledAppResource to apply default decorators and get installed_app """ + def delete(self, installed_app): if installed_app.app_owner_tenant_id == current_user.current_tenant_id: - raise BadRequest('You can\'t uninstall an app owned by the current tenant') + raise BadRequest("You can't uninstall an app owned by the current tenant") db.session.delete(installed_app) db.session.commit() - return {'result': 'success', 'message': 'App uninstalled successfully'} + return {"result": "success", "message": "App uninstalled successfully"} def patch(self, installed_app): parser = reqparse.RequestParser() - parser.add_argument('is_pinned', type=inputs.boolean) + parser.add_argument("is_pinned", type=inputs.boolean) args = parser.parse_args() commit_args = False - if 'is_pinned' in args: - installed_app.is_pinned = args['is_pinned'] + if "is_pinned" in args: + installed_app.is_pinned = args["is_pinned"] commit_args = True if commit_args: db.session.commit() - return {'result': 'success', 'message': 'App info updated successfully'} + return {"result": "success", "message": "App info updated successfully"} -api.add_resource(InstalledAppsListApi, '/installed-apps') -api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>') +api.add_resource(InstalledAppsListApi, "/installed-apps") +api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3523a869003a32..3d221ff30a6599 100644 --- a/api/controllers/console/explore/message.py +++ 
b/api/controllers/console/explore/message.py @@ -40,23 +40,25 @@ def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, current_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc" + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") + class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = installed_app.app @@ -64,30 +66,32 @@ def post(self, installed_app, message_id): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args['rating']) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -95,7 +99,7 @@ def get(self, installed_app, message_id): user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) except MessageNotExistsError: @@ -121,17 +125,14 @@ class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: 
raise NotChatAppError() message_id = str(message_id) try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.EXPLORE + app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) except MessageNotExistsError: raise NotFound("Message not found") @@ -151,10 +152,22 @@ def get(self, installed_app, message_id): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages') -api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback') -api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this') -api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question') +api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages") +api.add_resource( + MessageFeedbackApi, + "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions", + endpoint="installed_app_suggested_question", +) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 0a168d6306c83e..fee52248a698e0 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,7 +1,7 @@ +from flask_restful import marshal_with -from flask_restful import fields, marshal_with - -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as controller_helpers from controllers.console import api from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource @@ -11,41 +11,16 @@ class AppParameterApi(InstalledAppResource): """Resource for app variables.""" - variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) - } - - system_parameters_fields = { - 'image_file_size_limit': fields.String - } - - parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) - } - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model is None: + raise AppUnavailableError() + + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = 
app_model.workflow if workflow is None: raise AppUnavailableError() @@ -54,33 +29,16 @@ def get(self, installed_app: InstalledApp): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) - return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } - } + return controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class ExploreAppMetaApi(InstalledAppResource): @@ -90,6 +48,7 @@ def get(self, installed_app: InstalledApp): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters', - endpoint='installed_app_parameters') -api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta') +api.add_resource( + AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters" +) +api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 6e10e2ec92671e..5daaa1e7c38ba8 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -8,28 +8,28 @@ from services.recommended_app_service import RecommendedAppService app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon": fields.String, + "icon_background": fields.String, } recommended_app_fields = { - 'app': fields.Nested(app_fields, attribute='app'), - 'app_id': fields.String, - 'description': fields.String(attribute='description'), - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'category': fields.String, - 'position': fields.Integer, - 'is_listed': fields.Boolean + "app": fields.Nested(app_fields, attribute="app"), + "app_id": fields.String, + "description": fields.String(attribute="description"), + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "category": fields.String, + "position": 
fields.Integer, + "is_listed": fields.Boolean, } recommended_app_list_fields = { - 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)), - 'categories': fields.List(fields.String) + "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), + "categories": fields.List(fields.String), } @@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource): def get(self): # language args parser = reqparse.RequestParser() - parser.add_argument('language', type=str, location='args') + parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get('language') and args.get('language') in languages: - language_prefix = args.get('language') + if args.get("language") and args.get("language") in languages: + language_prefix = args.get("language") elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: @@ -61,5 +61,5 @@ def get(self, app_id): return RecommendedAppService.get_recommend_app_detail(app_id) -api.add_resource(RecommendedAppListApi, '/explore/apps') -api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>') +api.add_resource(RecommendedAppListApi, "/explore/apps") +api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index cf86b2fee1cead..0fc963747981e1 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -11,56 +11,54 @@ from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields)), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(InstalledAppResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, 
installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, current_user, args['message_id']) + SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(InstalledAppResource): @@ -69,13 +67,21 @@ def delete(self, installed_app, message_id): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, current_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages') -api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message') +api.add_resource( + SavedMessageListApi, + "/installed-apps/<uuid:installed_app_id>/saved-messages", + endpoint="installed_app_saved_messages", +) +api.add_resource( + SavedMessageApi, + "/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", + endpoint="installed_app_saved_message", +) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 7c5e211d4788bc..45f99b1db9fa9e 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -35,17 +35,13 @@ def post(self, installed_app: InstalledApp): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -76,10 +72,10 @@ def post(self, installed_app: InstalledApp, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run') -api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop') +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop" +) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 84890f1b46471a..49ea81a8a0f86d 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -7,36 +7,40 @@ from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.login import login_required -from models.model import InstalledApp 
+from models import InstalledApp def installed_app_required(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - if not kwargs.get('installed_app_id'): - raise ValueError('missing installed_app_id in path parameters') + if not kwargs.get("installed_app_id"): + raise ValueError("missing installed_app_id in path parameters") - installed_app_id = kwargs.get('installed_app_id') + installed_app_id = kwargs.get("installed_app_id") installed_app_id = str(installed_app_id) - del kwargs['installed_app_id'] + del kwargs["installed_app_id"] - installed_app = db.session.query(InstalledApp).filter( - InstalledApp.id == str(installed_app_id), - InstalledApp.tenant_id == current_user.current_tenant_id - ).first() + installed_app = ( + db.session.query(InstalledApp) + .filter( + InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id + ) + .first() + ) if installed_app is None: - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") if not installed_app.app: db.session.delete(installed_app) db.session.commit() - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") return view(installed_app, *args, **kwargs) + return decorated if view: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index fe73bcb98572a5..4ac0aa497e0866 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -3,8 +3,7 @@ from constants import HIDDEN_VALUE from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import login_required from models.api_based_extension import APIBasedExtension @@ -13,23 +12,18 @@ class CodeBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('module', type=str, required=True, location='args') + parser.add_argument("module", type=str, required=True, location="args") args = parser.parse_args() - return { - 'module': args['module'], - 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) - } + return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} class APIBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -44,23 +38,22 @@ def get(self): @marshal_with(api_based_extension_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() extension_data = APIBasedExtension( tenant_id=current_user.current_tenant_id, - name=args['name'], - api_endpoint=args['api_endpoint'], - api_key=args['api_key'] + name=args["name"], + api_endpoint=args["api_endpoint"], + api_key=args["api_key"], ) return APIBasedExtensionService.save(extension_data) 
class APIBasedExtensionDetailAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -82,16 +75,16 @@ def post(self, id): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() - extension_data_from_db.name = args['name'] - extension_data_from_db.api_endpoint = args['api_endpoint'] + extension_data_from_db.name = args["name"] + extension_data_from_db.api_endpoint = args["api_endpoint"] - if args['api_key'] != HIDDEN_VALUE: - extension_data_from_db.api_key = args['api_key'] + if args["api_key"] != HIDDEN_VALUE: + extension_data_from_db.api_key = args["api_key"] return APIBasedExtensionService.save(extension_data_from_db) @@ -106,10 +99,10 @@ def delete(self, id): APIBasedExtensionService.delete(extension_data_from_db) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') +api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") -api.add_resource(APIBasedExtensionAPI, '/api-based-extension') -api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>') +api.add_resource(APIBasedExtensionAPI, "/api-based-extension") +api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 8475cd848822a7..70ab4ff865cb48 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -5,12 +5,10 @@ from services.feature_service import FeatureService from . 
import api -from .setup import setup_required -from .wraps import account_initialization_required, cloud_utm_record +from .wraps import account_initialization_required, cloud_utm_record, setup_required class FeatureApi(Resource): - @setup_required @login_required @account_initialization_required @@ -24,5 +22,5 @@ def get(self): return FeatureService.get_system_features().model_dump() -api.add_resource(FeatureApi, '/features') -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(FeatureApi, "/features") +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py new file mode 100644 index 00000000000000..946d3db37f587b --- /dev/null +++ b/api/controllers/console/files.py @@ -0,0 +1,95 @@ +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal_with + +import services +from configs import dify_config +from constants import DOCUMENT_EXTENSIONS +from controllers.common.errors import FilenameNotExistsError +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) +from fields.file_fields import file_fields, upload_config_fields +from libs.login import login_required +from services.file_service import FileService + +from .error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) + +PREVIEW_WORDS_LIMIT = 3000 + + +class FileApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(upload_config_fields) + def get(self): + return { + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT, + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + }, 200 + + @setup_required + @login_required + @account_initialization_required + @marshal_with(file_fields) + @cloud_edition_billing_resource_check("documents") + def post(self): + file = request.files["file"] + source = request.form.get("source") + + if "file" not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + + if not file.filename: + raise FilenameNotExistsError + + if source not in ("datasets", None): + source = None + + try: + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source=source, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return upload_file, 201 + + +class FilePreviewApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, file_id): + file_id = str(file_id) + text = FileService.get_file_preview(file_id) + return {"content": text} + + +class FileSupportTypeApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + return {"allowed_extensions": DOCUMENT_EXTENSIONS} diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 6feb1003a975f6..ae759bb752a30e 100644 --- 
a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import str_len +from libs.helper import StrLen from models.model import DifySetup from services.account_service import TenantService @@ -14,12 +14,11 @@ class InitValidateAPI(Resource): - def get(self): init_status = get_init_validate_status() if init_status: - return { 'status': 'finished' } - return {'status': 'not_started' } + return {"status": "finished"} + return {"status": "not_started"} @only_edition_self_hosted def post(self): @@ -29,22 +28,23 @@ def post(self): raise AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument('password', type=str_len(30), - required=True, location='json') - input_password = parser.parse_args()['password'] + parser.add_argument("password", type=StrLen(30), required=True, location="json") + input_password = parser.parse_args()["password"] - if input_password != os.environ.get('INIT_PASSWORD'): - session['is_init_validated'] = False + if input_password != os.environ.get("INIT_PASSWORD"): + session["is_init_validated"] = False raise InitValidateFailedError() - - session['is_init_validated'] = True - return {'result': 'success'}, 201 + + session["is_init_validated"] = True + return {"result": "success"}, 201 + def get_init_validate_status(): - if dify_config.EDITION == 'SELF_HOSTED': - if os.environ.get('INIT_PASSWORD'): - return session.get('is_init_validated') or DifySetup.query.first() - + if dify_config.EDITION == "SELF_HOSTED": + if os.environ.get("INIT_PASSWORD"): + return session.get("is_init_validated") or DifySetup.query.first() + return True -api.add_resource(InitValidateAPI, '/init') + +api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 7664ba8c165db8..cd28cc946ee288 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -4,14 +4,11 @@ class PingApi(Resource): - def get(self): """ For connection health check """ - return { - "result": "pong" - } + return {"result": "pong"} -api.add_resource(PingApi, '/ping') +api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py new file mode 100644 index 00000000000000..9b899bef641e8f --- /dev/null +++ b/api/controllers/console/remote_files.py @@ -0,0 +1,81 @@ +import urllib.parse +from typing import cast + +import httpx +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse + +import services +from controllers.common import helpers +from core.file import helpers as file_helpers +from core.helper import ssrf_proxy +from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from models.account import Account +from services.file_service import FileService + +from .error import ( + FileTooLargeError, + UnsupportedFileTypeError, +) + + +class RemoteFileInfoApi(Resource): + @marshal_with(remote_file_info_fields) + def get(self, url): + decoded_url = urllib.parse.unquote(url) + resp = ssrf_proxy.head(decoded_url) + if resp.status_code != httpx.codes.OK: + # failed back to get method + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return { + "file_type": resp.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(resp.headers.get("Content-Length", 0)), + } + + +class RemoteFileUploadApi(Resource): + 
@marshal_with(file_fields_with_signed_url) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("url", type=str, required=True, help="URL is required") + args = parser.parse_args() + + url = args["url"] + + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3) + resp.raise_for_status() + + file_info = helpers.guess_file_info_from_response(resp) + + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + raise FileTooLargeError + + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + + try: + user = cast(Account, current_user) + upload_file = FileService.upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=user, + source_url=url, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at, + }, 201 diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index ef7cc6bc03ce3a..e0b728d97739d3 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,32 +1,26 @@ -from functools import wraps - from flask import request from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import email, get_remote_ip, str_len +from libs.helper import StrLen, email, extract_remote_ip from libs.password import valid_password from models.model import DifySetup from services.account_service import RegisterService, TenantService from . 
import api -from .error import AlreadySetupError, NotInitValidateError, NotSetupError +from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted class SetupApi(Resource): - def get(self): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() if setup_status: - return { - 'step': 'finished', - 'setup_at': setup_status.setup_at.isoformat() - } - return {'step': 'not_started'} - return {'step': 'finished'} + return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + return {"step": "not_started"} + return {"step": "finished"} @only_edition_self_hosted def post(self): @@ -38,49 +32,28 @@ def post(self): tenant_count = TenantService.get_tenant_count() if tenant_count > 0: raise AlreadySetupError() - + if not get_init_validate_status(): raise NotInitValidateError() parser = reqparse.RequestParser() - parser.add_argument('email', type=email, - required=True, location='json') - parser.add_argument('name', type=str_len( - 30), required=True, location='json') - parser.add_argument('password', type=valid_password, - required=True, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("name", type=StrLen(30), required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() # setup RegisterService.setup( - email=args['email'], - name=args['name'], - password=args['password'], - ip_address=get_remote_ip(request) + email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request) ) - return {'result': 'success'}, 201 - - -def setup_required(view): - @wraps(view) - def decorated(*args, **kwargs): - # check setup - if not get_init_validate_status(): - raise NotInitValidateError() - - elif not get_setup_status(): - raise NotSetupError() - - return view(*args, **kwargs) - - return decorated + return {"result": "success"}, 201 def get_setup_status(): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": return DifySetup.query.first() - else: - return True + return True + -api.add_resource(SetupApi, '/setup') +api.add_resource(SetupApi, "/setup")
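# setup.py loses its local setup_required definition in this patch; every
# import site now pulls it from controllers.console.wraps instead. Assuming
# the relocated copy keeps the behavior of the deleted body (wraps.py itself
# is not shown in this diff), it would look roughly like:
from functools import wraps

def setup_required(view):
    @wraps(view)
    def decorated(*args, **kwargs):
        # Imported lazily to avoid a circular import with setup.py, which
        # itself imports only_edition_self_hosted from wraps.py.
        from controllers.console.error import NotInitValidateError, NotSetupError
        from controllers.console.init_validate import get_init_validate_status
        from controllers.console.setup import get_setup_status

        # Same gate as the deleted copy: init validation first, then setup status.
        if not get_init_validate_status():
            raise NotInitValidateError()
        elif not get_setup_status():
            raise NotSetupError()
        return view(*args, **kwargs)

    return decorated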
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 004afaa531a086..ccd3293a6266fc 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -4,8 +4,7 @@ from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.tag_fields import tag_fields from libs.login import login_required from models.model import Tag @@ -13,20 +12,19 @@ def _validate_name(name): - if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 50 characters.') + if not name or len(name) < 1 or len(name) > 50: + raise ValueError("Name must be between 1 and 50 characters.") return name class TagListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get('type', type=str) - keyword = request.args.get('keyword', default=None, type=str) + tag_type = request.args.get("type", type=str) + keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) return tags, 200 @@ -40,28 +38,21 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 and 50 characters.", type=_validate_name + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() tag = TagService.save_tags(args) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': 0 - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 class TagUpdateDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -72,20 +63,15 @@ def patch(self, tag_id): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 and 50 characters.", type=_validate_name + ) args = parser.parse_args() tag = TagService.update_tags(args, tag_id) binding_count = TagService.get_tag_binding_count(tag_id) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': binding_count - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} return response, 200 @@ -104,7 +90,6 @@ def delete(self, tag_id): class TagBindingCreateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -114,14 +99,15 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json', - help='Tag IDs is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, location='json', - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs are required." + ) + parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type."
+ ) args = parser.parse_args() TagService.save_tag_binding(args) @@ -129,7 +115,6 @@ def post(self): class TagBindingDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -139,21 +124,18 @@ def post(self): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_id', type=str, nullable=False, required=True, - help='Tag ID is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() TagService.delete_tag_binding(args) return 200 -api.add_resource(TagListApi, '/tags') -api.add_resource(TagUpdateDeleteApi, '/tags/') -api.add_resource(TagBindingCreateApi, '/tag-bindings/create') -api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove') +api.add_resource(TagListApi, "/tags") +api.add_resource(TagUpdateDeleteApi, "/tags/") +api.add_resource(TagBindingCreateApi, "/tag-bindings/create") +api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 1fcf4bdc00e5b1..7dea8e554edd7a 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,9 +1,9 @@ - import json import logging import requests from flask_restful import Resource, reqparse +from packaging import version from configs import dify_config @@ -11,42 +11,52 @@ class VersionApi(Resource): - def get(self): parser = reqparse.RequestParser() - parser.add_argument('current_version', type=str, required=True, location='args') + parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() check_update_url = dify_config.CHECK_UPDATE_URL result = { - 'version': dify_config.CURRENT_VERSION, - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False, - 'features': { - 'can_replace_logo': dify_config.CAN_REPLACE_LOGO, - 'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED - } + "version": dify_config.CURRENT_VERSION, + "release_date": "", + "release_notes": "", + "can_auto_update": False, + "features": { + "can_replace_logo": dify_config.CAN_REPLACE_LOGO, + "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, + }, } if not check_update_url: return result try: - response = requests.get(check_update_url, { - 'current_version': args.get('current_version') - }) + response = requests.get(check_update_url, {"current_version": args.get("current_version")}) except Exception as error: logging.warning("Check update version error: {}.".format(str(error))) - result['version'] = args.get('current_version') + result["version"] = args.get("current_version") return result content = json.loads(response.content) - result['version'] = content['version'] - result['release_date'] = content['releaseDate'] - result['release_notes'] = content['releaseNotes'] - result['can_auto_update'] = content['canAutoUpdate'] + if _has_new_version(latest_version=content["version"], current_version=f"{args.get('current_version')}"): + result["version"] = content["version"] + 
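# _has_new_version, defined just below, only lets the payload from
# CHECK_UPDATE_URL overwrite the reported version when it parses as strictly
# newer than the running release. The comparison semantics come from
# packaging.version, e.g.:
from packaging import version

assert version.parse("0.11.0") > version.parse("0.9.2")    # numeric, not lexicographic
assert version.parse("1.0.0rc1") < version.parse("1.0.0")  # pre-releases sort before finals
try:
    version.parse("not-a-version")
except version.InvalidVersion:
    pass  # _has_new_version logs this case and answers False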
result["release_date"] = content["releaseDate"] + result["release_notes"] = content["releaseNotes"] + result["can_auto_update"] = content["canAutoUpdate"] return result -api.add_resource(VersionApi, '/version') +def _has_new_version(*, latest_version: str, current_version: str) -> bool: + try: + latest = version.parse(latest_version) + current = version.parse(current_version) + + # Compare versions + return latest > current + except version.InvalidVersion: + logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}") + return False + + +api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 1056d5eb62faf2..aabc4177595d67 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -8,70 +8,70 @@ from configs import dify_config from constants.languages import supported_language from controllers.console import api -from controllers.console.setup import setup_required from controllers.console.workspace.error import ( AccountAlreadyInitedError, CurrentPasswordIncorrectError, InvalidInvitationCodeError, RepeatPasswordNotMatchError, ) -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.member_fields import account_fields from libs.helper import TimestampField, timezone from libs.login import login_required -from models.account import AccountIntegrate, InvitationCode +from models import AccountIntegrate, InvitationCode from services.account_service import AccountService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError class AccountInitApi(Resource): - @setup_required @login_required def post(self): account = current_user - if account.status == 'active': + if account.status == "active": raise AccountAlreadyInitedError() parser = reqparse.RequestParser() - if dify_config.EDITION == 'CLOUD': - parser.add_argument('invitation_code', type=str, location='json') + if dify_config.EDITION == "CLOUD": + parser.add_argument("invitation_code", type=str, location="json") - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') - parser.add_argument('timezone', type=timezone, - required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") + parser.add_argument("timezone", type=timezone, required=True, location="json") args = parser.parse_args() - if dify_config.EDITION == 'CLOUD': - if not args['invitation_code']: - raise ValueError('invitation_code is required') + if dify_config.EDITION == "CLOUD": + if not args["invitation_code"]: + raise ValueError("invitation_code is required") # check invitation code - invitation_code = db.session.query(InvitationCode).filter( - InvitationCode.code == args['invitation_code'], - InvitationCode.status == 'unused', - ).first() + invitation_code = ( + db.session.query(InvitationCode) + .filter( + InvitationCode.code == args["invitation_code"], + InvitationCode.status == "unused", + ) + .first() + ) if not invitation_code: raise InvalidInvitationCodeError() - invitation_code.status = 'used' + invitation_code.status = "used" invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_by_tenant_id = 
account.current_tenant_id invitation_code.used_by_account_id = account.id - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' - account.status = 'active' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" + account.status = "active" account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class AccountProfileApi(Resource): @@ -90,15 +90,14 @@ class AccountNameApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() # Validate account name length - if len(args['name']) < 3 or len(args['name']) > 30: - raise ValueError( - "Account name must be between 3 and 30 characters.") + if len(args["name"]) < 3 or len(args["name"]) > 30: + raise ValueError("Account name must be between 3 and 30 characters.") - updated_account = AccountService.update_account(current_user, name=args['name']) + updated_account = AccountService.update_account(current_user, name=args["name"]) return updated_account @@ -110,10 +109,10 @@ class AccountAvatarApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('avatar', type=str, required=True, location='json') + parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, avatar=args['avatar']) + updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) return updated_account @@ -125,11 +124,10 @@ class AccountInterfaceLanguageApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_language=args['interface_language']) + updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) return updated_account @@ -141,11 +139,10 @@ class AccountInterfaceThemeApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('interface_theme', type=str, choices=[ - 'light', 'dark'], required=True, location='json') + parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme']) + updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) return updated_account @@ -157,15 +154,14 @@ class AccountTimezoneApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('timezone', type=str, - required=True, location='json') + parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() # Validate timezone string, e.g. 
America/New_York, Asia/Shanghai - if args['timezone'] not in pytz.all_timezones: + if args["timezone"] not in pytz.all_timezones: raise ValueError("Invalid timezone string.") - updated_account = AccountService.update_account(current_user, timezone=args['timezone']) + updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) return updated_account @@ -177,20 +173,16 @@ class AccountPasswordApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('password', type=str, - required=False, location='json') - parser.add_argument('new_password', type=str, - required=True, location='json') - parser.add_argument('repeat_new_password', type=str, - required=True, location='json') + parser.add_argument("password", type=str, required=False, location="json") + parser.add_argument("new_password", type=str, required=True, location="json") + parser.add_argument("repeat_new_password", type=str, required=True, location="json") args = parser.parse_args() - if args['new_password'] != args['repeat_new_password']: + if args["new_password"] != args["repeat_new_password"]: raise RepeatPasswordNotMatchError() try: - AccountService.update_account_password( - current_user, args['password'], args['new_password']) + AccountService.update_account_password(current_user, args["password"], args["new_password"]) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() @@ -199,14 +191,14 @@ def post(self): class AccountIntegrateApi(Resource): integrate_fields = { - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'link': fields.String + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "link": fields.String, } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), + "data": fields.List(fields.Nested(integrate_fields)), } @setup_required @@ -216,10 +208,9 @@ class AccountIntegrateApi(Resource): def get(self): account = current_user - account_integrates = db.session.query(AccountIntegrate).filter( - AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" providers = ["github", "google"] @@ -227,36 +218,38 @@ def get(self): for provider in providers: existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None) if existing_integrate: - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'link': None - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "link": None, + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'is_bound': False, - 'link': f'{base_url}{oauth_base_path}/{provider}' - }) - - return {'data': integrate_data} - + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "is_bound": False, + "link": f"{base_url}{oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data} # Register API resources -api.add_resource(AccountInitApi, '/account/init') -api.add_resource(AccountProfileApi, '/account/profile') -api.add_resource(AccountNameApi, 
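# AccountTimezoneApi above accepts any IANA zone name and rejects everything
# else by membership in pytz.all_timezones; the rule in isolation:
import pytz

for candidate in ("America/New_York", "Asia/Shanghai", "Mars/Olympus_Mons"):
    verdict = "valid" if candidate in pytz.all_timezones else "rejected"
    print(candidate, verdict)  # only the first two are valid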
'/account/name') -api.add_resource(AccountAvatarApi, '/account/avatar') -api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language') -api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme') -api.add_resource(AccountTimezoneApi, '/account/timezone') -api.add_resource(AccountPasswordApi, '/account/password') -api.add_resource(AccountIntegrateApi, '/account/integrates') +api.add_resource(AccountInitApi, "/account/init") +api.add_resource(AccountProfileApi, "/account/profile") +api.add_resource(AccountNameApi, "/account/name") +api.add_resource(AccountAvatarApi, "/account/avatar") +api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") +api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") +api.add_resource(AccountTimezoneApi, "/account/timezone") +api.add_resource(AccountPasswordApi, "/account/password") +api.add_resource(AccountIntegrateApi, "/account/integrates") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 99f55835bc6b97..9e13c7b9241ff1 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -2,36 +2,36 @@ class RepeatPasswordNotMatchError(BaseHTTPException): - error_code = 'repeat_password_not_match' + error_code = "repeat_password_not_match" description = "New password and repeat password does not match." code = 400 class CurrentPasswordIncorrectError(BaseHTTPException): - error_code = 'current_password_incorrect' + error_code = "current_password_incorrect" description = "Current password is incorrect." code = 400 class ProviderRequestFailedError(BaseHTTPException): - error_code = 'provider_request_failed' + error_code = "provider_request_failed" description = None code = 400 class InvalidInvitationCodeError(BaseHTTPException): - error_code = 'invalid_invitation_code' + error_code = "invalid_invitation_code" description = "Invalid invitation code." code = 400 class AccountAlreadyInitedError(BaseHTTPException): - error_code = 'account_already_inited' + error_code = "account_already_inited" description = "The account has been initialized. Please refresh the page." code = 400 class AccountNotInitializedError(BaseHTTPException): - error_code = 'account_not_initialized' + error_code = "account_not_initialized" description = "The account has not been initialized yet. Please proceed with the initialization process first." 
code = 400 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 50514e39f6aad1..d2b2092b75a9ff 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -2,8 +2,7 @@ from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_user, login_required @@ -22,10 +21,16 @@ def post(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing credentials @@ -38,18 +43,18 @@ def post(self, provider: str): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response @@ -65,10 +70,16 @@ def post(self, provider: str, config_id: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing config credentials @@ -81,26 +92,30 @@ def post(self, provider: str, config_id: str): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'], + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], config_id=config_id, ) 
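# Both load-balancing validators above constrain "model_type" with
# choices=[mt.value for mt in ModelType], so a bad value is rejected while
# parsing instead of deep inside the service layer. The same pattern with a
# toy enum standing in for ModelType (which lives in core.model_runtime):
from enum import Enum

from flask import Flask
from flask_restful import reqparse

class ModelKind(Enum):
    LLM = "llm"
    TEXT_EMBEDDING = "text-embedding"

parser = reqparse.RequestParser()
parser.add_argument(
    "model_type",
    type=str,
    required=True,
    nullable=False,
    choices=[mk.value for mk in ModelKind],
    location="json",
)

app = Flask(__name__)
with app.test_request_context(json={"model_type": "llm"}):
    print(parser.parse_args())  # {'model_type': 'llm'}
# A body like {"model_type": "bogus"} would make parse_args() abort with HTTP 400.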
except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response # Load Balancing Config -api.add_resource(LoadBalancingCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate') - -api.add_resource(LoadBalancingConfigCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate') +api.add_resource( + LoadBalancingCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", +) + +api.add_resource( + LoadBalancingConfigCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", +) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 34e9da384106a7..8f694c65e0ddfd 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -4,8 +4,11 @@ import services from configs import dify_config from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields from libs.login import login_required @@ -23,7 +26,7 @@ class MemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 class MemberInviteEmailApi(Resource): @@ -32,48 +35,46 @@ class MemberInviteEmailApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('members') + @cloud_edition_billing_resource_check("members") def post(self): parser = reqparse.RequestParser() - parser.add_argument('emails', type=str, required=True, location='json', action='append') - parser.add_argument('role', type=str, required=True, default='admin', location='json') - parser.add_argument('language', type=str, required=False, location='json') + parser.add_argument("emails", type=str, required=True, location="json", action="append") + parser.add_argument("role", type=str, required=True, default="admin", location="json") + parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - invitee_emails = args['emails'] - invitee_role = args['role'] - interface_language = args['language'] + invitee_emails = args["emails"] + invitee_role = args["role"] + interface_language = args["language"] if not TenantAccountRole.is_non_owner_role(invitee_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 inviter = current_user invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL for invitee_email in invitee_emails: try: - token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter) - 
invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}' - }) + token = RegisterService.invite_new_member( + inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter + ) + invitation_results.append( + { + "status": "success", + "email": invitee_email, + "url": f"{console_web_url}/activate?email={invitee_email}&token={token}", + } + ) except AccountAlreadyInTenantError: - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/signin' - }) + invitation_results.append( + {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} + ) break except Exception as e: - invitation_results.append({ - 'status': 'failed', - 'email': invitee_email, - 'message': str(e) - }) + invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) return { - 'result': 'success', - 'invitation_results': invitation_results, + "result": "success", + "invitation_results": invitation_results, }, 201 @@ -91,15 +92,15 @@ def delete(self, member_id): try: TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) except services.errors.account.CannotOperateSelfError as e: - return {'code': 'cannot-operate-self', 'message': str(e)}, 400 + return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: - return {'code': 'forbidden', 'message': str(e)}, 403 + return {"code": "forbidden", "message": str(e)}, 403 except services.errors.account.MemberNotInTenantError as e: - return {'code': 'member-not-found', 'message': str(e)}, 404 + return {"code": "member-not-found", "message": str(e)}, 404 except Exception as e: raise ValueError(str(e)) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class MemberUpdateRoleApi(Resource): @@ -110,12 +111,12 @@ class MemberUpdateRoleApi(Resource): @account_initialization_required def put(self, member_id): parser = reqparse.RequestParser() - parser.add_argument('role', type=str, required=True, location='json') + parser.add_argument("role", type=str, required=True, location="json") args = parser.parse_args() - new_role = args['role'] + new_role = args["role"] if not TenantAccountRole.is_valid_role(new_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 member = db.session.get(Account, str(member_id)) if not member: @@ -128,7 +129,7 @@ def put(self, member_id): # todo: 403 - return {'result': 'success'} + return {"result": "success"} class DatasetOperatorMemberListApi(Resource): @@ -140,11 +141,11 @@ class DatasetOperatorMemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 -api.add_resource(MemberListApi, '/workspaces/current/members') -api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email') -api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/') -api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members//update-role') -api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators') +api.add_resource(MemberListApi, "/workspaces/current/members") 
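# MemberInviteEmailApi above reports a per-address outcome instead of failing
# the whole batch, so one bad address cannot mask the invites that succeeded.
# The 201 response body has this shape (values here are illustrative):
example_invite_response = {
    "result": "success",
    "invitation_results": [
        {
            "status": "success",
            "email": "alice@example.com",
            "url": "https://console.example.com/activate?email=alice@example.com&token=...",
        },
        {"status": "failed", "email": "bob@example.com", "message": "..."},
    ],
}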
+api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") +api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") +api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") +api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index c888159f839719..0e54126063be75 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -6,8 +6,7 @@ from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder @@ -17,7 +16,6 @@ class ModelProviderListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -25,21 +23,23 @@ def get(self): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=False, nullable=True, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=False, + nullable=True, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() - provider_list = model_provider_service.get_provider_list( - tenant_id=tenant_id, - model_type=args.get('model_type') - ) + provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) return jsonable_encoder({"data": provider_list}) class ModelProviderCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -47,25 +47,18 @@ def get(self, provider: str): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - credentials = model_provider_service.get_provider_credentials( - tenant_id=tenant_id, - provider=provider - ) + credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) - return { - "credentials": credentials - } + return {"credentials": credentials} class ModelProviderValidateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self, provider: str): - parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id @@ -77,24 +70,21 @@ def post(self, provider: str): try: model_provider_service.provider_credentials_validate( - tenant_id=tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class 
ModelProviderApi(Resource): - @setup_required @login_required @account_initialization_required @@ -103,21 +93,19 @@ def post(self, provider: str): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() try: model_provider_service.save_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 @setup_required @login_required @@ -127,12 +115,9 @@ def delete(self, provider: str): raise Forbidden() model_provider_service = ModelProviderService() - model_provider_service.remove_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderIconApi(Resource): @@ -140,22 +125,18 @@ class ModelProviderIconApi(Resource): Get model provider icon """ - @setup_required - @login_required - @account_initialization_required def get(self, provider: str, icon_type: str, lang: str): model_provider_service = ModelProviderService() icon, mimetype = model_provider_service.get_model_provider_icon( provider=provider, icon_type=icon_type, - lang=lang + lang=lang, ) return send_file(io.BytesIO(icon), mimetype=mimetype) class PreferredProviderTypeUpdateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,18 +147,22 @@ def post(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, - choices=['system', 'custom'], location='json') + parser.add_argument( + "preferred_provider_type", + type=str, + required=True, + nullable=False, + choices=["system", "custom"], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.switch_preferred_provider( - tenant_id=tenant_id, - provider=provider, - preferred_provider_type=args['preferred_provider_type'] + tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderPaymentCheckoutUrlApi(Resource): @@ -185,13 +170,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - if provider != 'anthropic': - raise ValueError(f'provider name {provider} is invalid') + if provider != "anthropic": + raise ValueError(f"provider name {provider} is invalid") BillingService.is_tenant_owner_or_admin(current_user) - data = BillingService.get_model_provider_payment_link(provider_name=provider, - tenant_id=current_user.current_tenant_id, - account_id=current_user.id, - prefilled_email=current_user.email) + data = BillingService.get_model_provider_payment_link( + provider_name=provider, + tenant_id=current_user.current_tenant_id, + account_id=current_user.id, + 
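# ModelProviderIconApi above now serves provider icons without the setup and
# login decorators, streaming raw bytes through Flask's send_file. The core
# pattern reduced to a runnable toy, with static bytes standing in for
# ModelProviderService.get_model_provider_icon:
import io

from flask import Flask, send_file

app = Flask(__name__)

ICON_BYTES = b"\x89PNG\r\n\x1a\n"  # placeholder payload, not a real image

@app.route("/icon")
def icon():
    # With no filename to infer from, send_file needs an explicit mimetype.
    return send_file(io.BytesIO(ICON_BYTES), mimetype="image/png")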
prefilled_email=current_user.email, + ) return data @@ -201,10 +188,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource): @account_initialization_required def post(self, provider: str): model_provider_service = ModelProviderService() - result = model_provider_service.free_quota_submit( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) return result @@ -215,32 +199,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=False, nullable=True, location='args') + parser.add_argument("token", type=str, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() result = model_provider_service.free_quota_qualification_verify( - tenant_id=current_user.current_tenant_id, - provider=provider, - token=args['token'] + tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] ) return result -api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') - -api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers//credentials') -api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers//credentials/validate') -api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/') -api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers//' - '/') - -api.add_resource(PreferredProviderTypeUpdateApi, - '/workspaces/current/model-providers//preferred-provider-type') -api.add_resource(ModelProviderPaymentCheckoutUrlApi, - '/workspaces/current/model-providers//checkout-url') -api.add_resource(ModelProviderFreeQuotaSubmitApi, - '/workspaces/current/model-providers//free-quota-submit') -api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi, - '/workspaces/current/model-providers//free-quota-qualification-verify') +api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") + +api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") +api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") +api.add_resource( + ModelProviderIconApi, "/workspaces/current/model-providers///" +) + +api.add_resource( + PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" +) +api.add_resource( + ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url" +) +api.add_resource( + ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers//free-quota-submit" +) +api.add_resource( + ModelProviderFreeQuotaQualificationVerifyApi, + "/workspaces/current/model-providers//free-quota-qualification-verify", +) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 69f2253e97ea40..57443cc3b350d0 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -5,8 +5,7 @@ from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import 
account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder @@ -16,27 +15,29 @@ class DefaultModelApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( - tenant_id=tenant_id, - model_type=args['model_type'] + tenant_id=tenant_id, model_type=args["model_type"] ) - return jsonable_encoder({ - "data": default_model_entity - }) + return jsonable_encoder({"data": default_model_entity}) @setup_required @login_required @@ -44,40 +45,40 @@ def get(self): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') + parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - model_settings = args['model_settings'] + model_settings = args["model_settings"] for model_setting in model_settings: - if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]: - raise ValueError('invalid model type') + if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]: + raise ValueError("invalid model type") - if 'provider' not in model_setting: + if "provider" not in model_setting: continue - if 'model' not in model_setting: - raise ValueError('invalid model') + if "model" not in model_setting: + raise ValueError("invalid model") try: model_provider_service.update_default_model_of_model_type( tenant_id=tenant_id, - model_type=model_setting['model_type'], - provider=model_setting['provider'], - model=model_setting['model'] + model_type=model_setting["model_type"], + provider=model_setting["provider"], + model=model_setting["model"], ) - except Exception: - logging.warning(f"{model_setting['model_type']} save error") + except Exception as ex: + logging.exception(f"{model_setting['model_type']} save error: {ex}") + raise ex - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -85,14 +86,9 @@ def get(self, provider): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_provider( - tenant_id=tenant_id, - provider=provider - ) + models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) - return jsonable_encoder({ - "data": models - }) + return jsonable_encoder({"data": models}) @setup_required @login_required @@ -104,61 +100,66 @@ def post(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - 
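# DefaultModelApi.post above now logs the full traceback and re-raises rather
# than downgrading a failed default-model update to a warning that callers
# never see. The pattern in isolation:
import logging

def save_default_model(setting: dict) -> None:
    raise RuntimeError("provider unavailable")  # stand-in for the service call

try:
    save_default_model({"model_type": "llm"})
except Exception:
    # logging.exception records the active traceback at ERROR level; a bare
    # raise (equivalent to the patch's `raise ex` here) keeps the original
    # exception and traceback for the caller.
    logging.exception("llm save error")
    raise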
parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json') - parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json') - parser.add_argument('config_from', type=str, required=False, nullable=True, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") + parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") args = parser.parse_args() model_load_balancing_service = ModelLoadBalancingService() - if ('load_balancing' in args and args['load_balancing'] and - 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']): - if 'configs' not in args['load_balancing']: - raise ValueError('invalid load balancing configs') + if ( + "load_balancing" in args + and args["load_balancing"] + and "enabled" in args["load_balancing"] + and args["load_balancing"]["enabled"] + ): + if "configs" not in args["load_balancing"]: + raise ValueError("invalid load balancing configs") # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - configs=args['load_balancing']['configs'] + model=args["model"], + model_type=args["model_type"], + configs=args["load_balancing"]["configs"], ) # enable load balancing model_load_balancing_service.enable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) else: # disable load balancing model_load_balancing_service.disable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - if args.get('config_from', '') != 'predefined-model': + if args.get("config_from", "") != "predefined-model": model_provider_service = ModelProviderService() try: model_provider_service.save_model_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: + logging.exception(f"save model credentials error: {ex}") raise ValueError(str(ex)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 @setup_required @login_required @@ -170,24 +171,26 @@ def delete(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + 
parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.remove_model_credentials( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderModelCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -195,38 +198,34 @@ def get(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() credentials = model_provider_service.get_model_credentials( - tenant_id=tenant_id, - provider=provider, - model_type=args['model_type'], - model=args['model'] + tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] ) model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) return { "credentials": credentials, - "load_balancing": { - "enabled": is_load_balancing_enabled, - "configs": load_balancing_configs - } + "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, } class ModelProviderModelEnableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -234,24 +233,26 @@ def patch(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.enable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelDisableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -259,24 +260,26 @@ def patch(self, provider: str): tenant_id = current_user.current_tenant_id 
parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.disable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelValidateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -284,10 +287,16 @@ def post(self, provider: str): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -299,48 +308,42 @@ def post(self, provider: str): model_provider_service.model_credentials_validate( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderModelParameterRuleApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( - tenant_id=tenant_id, - provider=provider, - model=args['model'] + tenant_id=tenant_id, provider=provider, model=args["model"] ) - return jsonable_encoder({ - "data": parameter_rules - }) + return jsonable_encoder({"data": parameter_rules}) class ModelProviderAvailableModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -348,27 +351,31 @@ def get(self, model_type): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_model_type( 
- tenant_id=tenant_id, - model_type=model_type - ) - - return jsonable_encoder({ - "data": models - }) - - -api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models') -api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable', - endpoint='model-provider-model-enable') -api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable', - endpoint='model-provider-model-disable') -api.add_resource(ModelProviderModelCredentialApi, - '/workspaces/current/model-providers/<string:provider>/models/credentials') -api.add_resource(ModelProviderModelValidateApi, - '/workspaces/current/model-providers/<string:provider>/models/credentials/validate') - -api.add_resource(ModelProviderModelParameterRuleApi, - '/workspaces/current/model-providers/<string:provider>/models/parameter-rules') -api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>') -api.add_resource(DefaultModelApi, '/workspaces/current/default-model') + models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) + + return jsonable_encoder({"data": models}) + + +api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models") +api.add_resource( + ModelProviderModelEnableApi, + "/workspaces/current/model-providers/<string:provider>/models/enable", + endpoint="model-provider-model-enable", +) +api.add_resource( + ModelProviderModelDisableApi, + "/workspaces/current/model-providers/<string:provider>/models/disable", + endpoint="model-provider-model-disable", +) +api.add_resource( + ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials" +) +api.add_resource( + ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate" +) + +api.add_resource( + ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules" +) +api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>") +api.add_resource(DefaultModelApi, "/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index bafeabb08ae2c7..daadb85d84e2fa 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -7,8 +7,7 @@ from configs import dify_config from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import login_required @@ -28,10 +27,18 @@ def get(self): tenant_id = current_user.current_tenant_id req = reqparse.RequestParser() - req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args') + req.add_argument( + "type", + type=str, + choices=["builtin", "model", "api", "workflow"], + required=False, + nullable=True, + location="args", + ) args = req.parse_args() - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + class ToolBuiltinProviderListToolsApi(Resource): @setup_required @@ -41,11 +48,14 @@ def get(self, provider): user_id = current_user.id tenant_id =
current_user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools( - user_id, - tenant_id, - provider, - )) + return jsonable_encoder( + BuiltinToolManageService.list_builtin_tool_provider_tools( + user_id, + tenant_id, + provider, + ) + ) + class ToolBuiltinProviderDeleteApi(Resource): @setup_required @@ -54,7 +64,7 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id @@ -63,7 +73,8 @@ def post(self, provider): tenant_id, provider, ) - + + class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @@ -71,12 +82,12 @@ class ToolBuiltinProviderUpdateApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -84,9 +95,10 @@ def post(self, provider): user_id, tenant_id, provider, - args['credentials'], + args["credentials"], ) - + + class ToolBuiltinProviderGetCredentialsApi(Resource): @setup_required @login_required @@ -101,6 +113,7 @@ def get(self, provider): provider, ) + class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): @@ -108,6 +121,7 @@ def get(self, provider): icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) + class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -115,35 +129,36 @@ class ToolApiProviderAddApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[]) - parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) + parser.add_argument("custom_disclaimer", type=str, 
required=False, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args['provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args.get('privacy_policy', ''), - args.get('custom_disclaimer', ''), - args.get('labels', []), + args["provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args.get("privacy_policy", ""), + args.get("custom_disclaimer", ""), + args.get("labels", []), ) + class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @@ -151,16 +166,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='args') + parser.add_argument("url", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider_remote_schema( current_user.id, current_user.current_tenant_id, - args['url'], + args["url"], ) - + + class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @@ -171,15 +187,18 @@ def get(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools( - user_id, - tenant_id, - args['provider'], - )) + return jsonable_encoder( + ApiToolManageService.list_api_tool_provider_tools( + user_id, + tenant_id, + args["provider"], + ) + ) + class ToolApiProviderUpdateApi(Resource): @setup_required @@ -188,37 +207,38 @@ class ToolApiProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") + 
parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args['provider'], - args['original_provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args['privacy_policy'], - args['custom_disclaimer'], - args.get('labels', []), + args["provider"], + args["original_provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args["privacy_policy"], + args["custom_disclaimer"], + args.get("labels", []), ) + class ToolApiProviderDeleteApi(Resource): @setup_required @login_required @@ -226,22 +246,23 @@ class ToolApiProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolApiProviderGetApi(Resource): @setup_required @login_required @@ -252,16 +273,17 @@ def get(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @@ -269,6 +291,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): def get(self, provider): return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) + class ToolApiProviderSchemaApi(Resource): @setup_required @login_required @@ -276,14 +299,15 @@ class ToolApiProviderSchemaApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.parser_api_schema( - schema=args['schema'], + schema=args["schema"], ) + class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @@ -291,25 +315,26 @@ class ToolApiProviderPreviousTestApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider_name", type=str, required=False, nullable=False, 
location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, - args['provider_name'] if args['provider_name'] else '', - args['tool_name'], - args['credentials'], - args['parameters'], - args['schema_type'], - args['schema'], + args["provider_name"] or "", + args["tool_name"], + args["credentials"], + args["parameters"], + args["schema_type"], + args["schema"], ) + class ToolWorkflowProviderCreateApi(Resource): @setup_required @login_required @@ -317,35 +342,35 @@ class ToolWorkflowProviderCreateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") args = reqparser.parse_args() return WorkflowToolManageService.create_workflow_tool( - user_id, - tenant_id, - args['workflow_app_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + user_id=user_id, + tenant_id=tenant_id, + workflow_app_id=args["workflow_app_id"], + name=args["name"], + label=args["label"], + icon=args["icon"], + description=args["description"], + parameters=args["parameters"], + privacy_policy=args["privacy_policy"], ) + class ToolWorkflowProviderUpdateApi(Resource): @setup_required @login_required @@ -353,38 +378,39 @@ class ToolWorkflowProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id 
tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + args = reqparser.parse_args() - if not args['workflow_tool_id']: - raise ValueError('incorrect workflow_tool_id') - + if not args["workflow_tool_id"]: + raise ValueError("incorrect workflow_tool_id") + return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_tool_id"], + args["name"], + args["label"], + args["icon"], + args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderDeleteApi(Resource): @setup_required @login_required @@ -392,21 +418,22 @@ class ToolWorkflowProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") args = reqparser.parse_args() return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - + + class ToolWorkflowProviderGetApi(Resource): @setup_required @login_required @@ -416,28 +443,29 @@ def get(self): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args') - parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, 
location="args") + parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() - if args.get('workflow_tool_id'): + if args.get("workflow_tool_id"): tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - elif args.get('workflow_app_id'): + elif args.get("workflow_app_id"): tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args['workflow_app_id'], + args["workflow_app_id"], ) else: - raise ValueError('incorrect workflow_tool_id or workflow_app_id') + raise ValueError("incorrect workflow_tool_id or workflow_app_id") return jsonable_encoder(tool) - + + class ToolWorkflowProviderListToolApi(Resource): @setup_required @login_required @@ -447,15 +475,18 @@ def get(self): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools( - user_id, - tenant_id, - args['workflow_tool_id'], - )) + return jsonable_encoder( + WorkflowToolManageService.list_single_workflow_tools( + user_id, + tenant_id, + args["workflow_tool_id"], + ) + ) + class ToolBuiltinListApi(Resource): @setup_required @@ -465,11 +496,17 @@ def get(self): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolApiListApi(Resource): @setup_required @login_required @@ -478,11 +515,17 @@ def get(self): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in ApiToolManageService.list_api_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolWorkflowListApi(Resource): @setup_required @login_required @@ -491,11 +534,17 @@ def get(self): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in WorkflowToolManageService.list_tenant_workflow_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolLabelsApi(Resource): @setup_required @login_required @@ -503,36 +552,41 @@ class ToolLabelsApi(Resource): def get(self): return jsonable_encoder(ToolLabelsService.list_tool_labels()) + # tool provider -api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') +api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") # builtin tool provider -api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') -api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') -api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin//update') 
-api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials') -api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema') -api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon') +api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools") +api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete") +api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update") +api.add_resource( + ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials" +) +api.add_resource( + ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema" +) +api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon") # api tool provider -api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') -api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') -api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') -api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') -api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') -api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') -api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') -api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') +api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") +api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote") +api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools") +api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update") +api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete") +api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get") +api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema") +api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre") # workflow tool provider -api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create') -api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update') -api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete') -api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get') -api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools') +api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create") +api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update") +api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete") +api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get") +api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
-api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin') -api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') -api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow') +api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") +api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") +api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") -api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels') \ No newline at end of file +api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7a11a45ae81900..76d76f6b58fc3c 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -6,6 +6,7 @@ from werkzeug.exceptions import Unauthorized import services +from controllers.common.errors import FilenameNotExistsError from controllers.console import api from controllers.console.admin import admin_required from controllers.console.datasets.error import ( @@ -15,8 +16,11 @@ UnsupportedFileTypeError, ) from controllers.console.error import AccountNotLinkTenantError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required @@ -26,39 +30,34 @@ from services.workspace_service import WorkspaceService provider_fields = { - 'provider_name': fields.String, - 'provider_type': fields.String, - 'is_valid': fields.Boolean, - 'token_is_set': fields.Boolean, + "provider_name": fields.String, + "provider_type": fields.String, + "is_valid": fields.Boolean, + "token_is_set": fields.Boolean, } tenant_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'role': fields.String, - 'in_trial': fields.Boolean, - 'trial_end_reason': fields.String, - 'custom_config': fields.Raw(attribute='custom_config'), + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "role": fields.String, + "in_trial": fields.Boolean, + "trial_end_reason": fields.String, + "custom_config": fields.Raw(attribute="custom_config"), } tenants_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'current': fields.Boolean + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "current": fields.Boolean, } -workspace_fields = { - 'id': fields.String, - 'name': fields.String, - 'status': fields.String, - 'created_at': TimestampField -} +workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} class TenantListApi(Resource): @@ -71,7 +70,7 @@ def get(self): for tenant in tenants: if tenant.id == current_user.current_tenant_id: tenant.current = True # Set current=True for current tenant - return {'workspaces': marshal(tenants, tenants_fields)}, 200 + return {"workspaces": marshal(tenants, tenants_fields)}, 200 class 
WorkspaceListApi(Resource): @@ -79,31 +78,37 @@ class WorkspaceListApi(Resource): @admin_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\ - .paginate(page=args['page'], per_page=args['limit']) + tenants = ( + db.session.query(Tenant) + .order_by(Tenant.created_at.desc()) + .paginate(page=args["page"], per_page=args["limit"]) + ) has_more = False - if len(tenants.items) == args['limit']: + if len(tenants.items) == args["limit"]: current_page_first_tenant = tenants[-1] - rest_count = db.session.query(Tenant).filter( - Tenant.created_at < current_page_first_tenant.created_at, - Tenant.id != current_page_first_tenant.id - ).count() + rest_count = ( + db.session.query(Tenant) + .filter( + Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id + ) + .count() + ) if rest_count > 0: has_more = True total = db.session.query(Tenant).count() return { - 'data': marshal(tenants.items, workspace_fields), - 'has_more': has_more, - 'limit': args['limit'], - 'page': args['page'], - 'total': total - }, 200 + "data": marshal(tenants.items, workspace_fields), + "has_more": has_more, + "limit": args["limit"], + "page": args["page"], + "total": total, + }, 200 class TenantApi(Resource): @@ -112,8 +117,8 @@ class TenantApi(Resource): @account_initialization_required @marshal_with(tenant_fields) def get(self): - if request.path == '/info': - logging.warning('Deprecated URL /info was used.') + if request.path == "/info": + logging.warning("Deprecated URL /info was used.") tenant = current_user.current_tenant @@ -125,7 +130,7 @@ def get(self): tenant = tenants[0] # else, raise Unauthorized else: - raise Unauthorized('workspace is archived') + raise Unauthorized("workspace is archived") return WorkspaceService.get_tenant_info(tenant), 200 @@ -136,79 +141,89 @@ class SwitchWorkspaceApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('tenant_id', type=str, required=True, location='json') + parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() # check if tenant_id is valid, 403 if not try: - TenantService.switch_tenant(current_user, args['tenant_id']) + TenantService.switch_tenant(current_user, args["tenant_id"]) except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant + new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + + return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - class CustomConfigWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): parser = 
reqparse.RequestParser() - parser.add_argument('remove_webapp_brand', type=bool, location='json') - parser.add_argument('replace_webapp_logo', type=str, location='json') + parser.add_argument("remove_webapp_brand", type=bool, location="json") + parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { - 'remove_webapp_brand': args['remove_webapp_brand'], - 'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') , + "remove_webapp_brand": args["remove_webapp_brand"], + "replace_webapp_logo": args["replace_webapp_logo"] + if args["replace_webapp_logo"] is not None + else tenant.custom_config_dict.get("replace_webapp_logo"), } tenant.custom_config_dict = custom_config_dict db.session.commit() - return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} - + return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} + class WebappLogoWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() - extension = file.filename.split('.')[-1] - if extension.lower() not in ['svg', 'png']: + if not file.filename: + raise FilenameNotExistsError + + extension = file.filename.split(".")[-1] + if extension.lower() not in {"svg", "png"}: raise UnsupportedFileTypeError() try: - upload_file = FileService.upload_file(file, current_user, True) + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - - return { 'id': upload_file.id }, 201 + + return {"id": upload_file.id}, 201 -api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants -api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants -api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info -api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated -api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant -api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config') -api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload') +api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants +api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants +api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info +api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated +api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant +api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") 
+api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 3baf69acfd2d5f..9f294cb93c9bc0 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,4 +1,5 @@ import json +import os from functools import wraps from flask import abort, request @@ -6,9 +7,12 @@ from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError +from models.model import DifySetup from services.feature_service import FeatureService from services.operation_service import OperationService +from .error import NotInitValidateError, NotSetupError + def account_initialization_required(view): @wraps(view) @@ -16,7 +20,7 @@ def decorated(*args, **kwargs): # check account initialization account = current_user - if account.status == 'uninitialized': + if account.status == "uninitialized": raise AccountNotInitializedError() return view(*args, **kwargs) @@ -27,7 +31,7 @@ def decorated(*args, **kwargs): def only_edition_cloud(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'CLOUD': + if dify_config.EDITION != "CLOUD": abort(404) return view(*args, **kwargs) @@ -38,7 +42,7 @@ def decorated(*args, **kwargs): def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": abort(404) return view(*args, **kwargs) @@ -46,8 +50,7 @@ def decorated(*args, **kwargs): return decorated -def cloud_edition_billing_resource_check(resource: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): @@ -58,23 +61,24 @@ def decorated(*args, **kwargs): vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota annotation_quota_limit = features.annotation_quota_limit - if resource == 'members' and 0 < members.limit <= members.size: - abort(403, error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: - abort(403, error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: - abort(403, error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: - # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets - source = request.args.get('source') - if source == 'datasets': - abort(403, error_msg) + if resource == "members" and 0 < members.limit <= members.size: + abort(403, "The number of members has reached the limit of your subscription.") + elif resource == "apps" and 0 < apps.limit <= apps.size: + abort(403, "The number of apps has reached the limit of your subscription.") + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: + abort(403, "The capacity of the vector space has reached the limit of your subscription.") + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + # The file upload API is used in multiple places, + # so we need to check whether the request comes from datasets + source = request.args.get("source") + if source == "datasets": + abort(403, "The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) - elif resource == 'workspace_custom' and not
features.can_replace_logo: - abort(403, error_msg) - elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: - abort(403, error_msg) + elif resource == "workspace_custom" and not features.can_replace_logo: + abort(403, "The workspace custom feature has reached the limit of your subscription.") + elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: + abort(403, "The annotation quota has reached the limit of your subscription.") else: return view(*args, **kwargs) @@ -85,16 +89,18 @@ def decorated(*args, **kwargs): return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check(resource: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': - abort(403, error_msg) + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": + abort( + 403, + "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", + ) else: return view(*args, **kwargs) @@ -112,7 +118,7 @@ def decorated(*args, **kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - utm_info = request.cookies.get('utm_info') + utm_info = request.cookies.get("utm_info") if utm_info: utm_info = json.loads(utm_info) @@ -122,3 +128,17 @@ def decorated(*args, **kwargs): return view(*args, **kwargs) return decorated + + +def setup_required(view): + @wraps(view) + def decorated(*args, **kwargs): + # check setup + if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first(): + raise NotInitValidateError() + elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first(): + raise NotSetupError() + + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 8d38ab9866a023..97d5c3f88fb522 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -2,7 +2,7 @@ from libs.external_api import ExternalApi -bp = Blueprint('files', __name__) +bp = Blueprint("files", __name__) api = ExternalApi(bp) diff --git a/api/controllers/files/error.py b/api/controllers/files/error.py new file mode 100644 index 00000000000000..a7ce4cd6f793e5 --- /dev/null +++ b/api/controllers/files/error.py @@ -0,0 +1,7 @@ +from libs.exception import BaseHTTPException + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = "unsupported_file_type" + description = "File type not allowed." 
+ code = 415 diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 247b5d45e190e5..6b3ac93cdf3d8f 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,47 +1,90 @@ from flask import Response, request -from flask_restful import Resource +from flask_restful import Resource, reqparse from werkzeug.exceptions import NotFound import services from controllers.files import api -from libs.exception import BaseHTTPException +from controllers.files.error import UnsupportedFileTypeError from services.account_service import TenantService from services.file_service import FileService class ImagePreviewApi(Resource): + """ + Deprecated + """ + def get(self, file_id): file_id = str(file_id) - timestamp = request.args.get('timestamp') - nonce = request.args.get('nonce') - sign = request.args.get('sign') + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") + sign = request.args.get("sign") if not timestamp or not nonce or not sign: - return {'content': 'Invalid request.'}, 400 + return {"content": "Invalid request."}, 400 try: generator, mimetype = FileService.get_image_preview( - file_id, - timestamp, - nonce, - sign + file_id=file_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, ) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) - + + +class FilePreviewApi(Resource): + def get(self, file_id): + file_id = str(file_id) + + parser = reqparse.RequestParser() + parser.add_argument("timestamp", type=str, required=True, location="args") + parser.add_argument("nonce", type=str, required=True, location="args") + parser.add_argument("sign", type=str, required=True, location="args") + parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") + + args = parser.parse_args() + + if not args["timestamp"] or not args["nonce"] or not args["sign"]: + return {"content": "Invalid request."}, 400 + + try: + generator, upload_file = FileService.get_file_generator_by_file_id( + file_id=file_id, + timestamp=args["timestamp"], + nonce=args["nonce"], + sign=args["sign"], + ) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + response = Response( + generator, + mimetype=upload_file.mime_type, + direct_passthrough=True, + headers={}, + ) + if upload_file.size > 0: + response.headers["Content-Length"] = str(upload_file.size) + if args["as_attachment"]: + response.headers["Content-Disposition"] = f"attachment; filename={upload_file.name}" + + return response + class WorkspaceWebappLogoApi(Resource): def get(self, workspace_id): workspace_id = str(workspace_id) custom_config = TenantService.get_custom_config(workspace_id) - webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None + webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None if not webapp_logo_file_id: - raise NotFound('webapp logo is not found') + raise NotFound("webapp logo is not found") try: generator, mimetype = FileService.get_public_image_preview( @@ -53,11 +96,6 @@ def get(self, workspace_id): return Response(generator, mimetype=mimetype) -api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview') -api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo') - - -class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' - description = "File
type not allowed." - code = 415 +api.add_resource(ImagePreviewApi, "/files//image-preview") +api.add_resource(FilePreviewApi, "/files//file-preview") +api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces//webapp-logo") diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 5a07ad2ea51800..a298701a2f8b11 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -3,8 +3,8 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.files import api +from controllers.files.error import UnsupportedFileTypeError from core.tools.tool_file_manager import ToolFileManager -from libs.exception import BaseHTTPException class ToolFilePreviewApi(Resource): @@ -13,36 +13,43 @@ def get(self, file_id, extension): parser = reqparse.RequestParser() - parser.add_argument('timestamp', type=str, required=True, location='args') - parser.add_argument('nonce', type=str, required=True, location='args') - parser.add_argument('sign', type=str, required=True, location='args') + parser.add_argument("timestamp", type=str, required=True, location="args") + parser.add_argument("nonce", type=str, required=True, location="args") + parser.add_argument("sign", type=str, required=True, location="args") + parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") args = parser.parse_args() - if not ToolFileManager.verify_file(file_id=file_id, - timestamp=args['timestamp'], - nonce=args['nonce'], - sign=args['sign'], + if not ToolFileManager.verify_file( + file_id=file_id, + timestamp=args["timestamp"], + nonce=args["nonce"], + sign=args["sign"], ): - raise Forbidden('Invalid request.') - + raise Forbidden("Invalid request.") + try: - result = ToolFileManager.get_file_generator_by_tool_file_id( + stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id( file_id, ) - if not result: - raise NotFound('file is not found') - - generator, mimetype = result + if not stream or not tool_file: + raise NotFound("file is not found") except Exception: raise UnsupportedFileTypeError() - return Response(generator, mimetype=mimetype) + response = Response( + stream, + mimetype=tool_file.mimetype, + direct_passthrough=True, + headers={}, + ) + if tool_file.size > 0: + response.headers["Content-Length"] = str(tool_file.size) + if args["as_attachment"]: + response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}" + + return response -api.add_resource(ToolFilePreviewApi, '/files/tools/.') -class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' - description = "File type not allowed." 
- code = 415 +api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>") diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index ad49a649caab66..9f124736a966ea 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -2,8 +2,7 @@ from libs.external_api import ExternalApi -bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') +bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) from .workspace import workspace - diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 06610d89330837..99d32af593991f 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,6 +1,6 @@ from flask_restful import Resource, reqparse -from controllers.console.setup import setup_required +from controllers.console.wraps import setup_required from controllers.inner_api import api from controllers.inner_api.wraps import inner_api_only from events.tenant_event import tenant_was_created @@ -9,29 +9,24 @@ class EnterpriseWorkspace(Resource): - @setup_required @inner_api_only def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('owner_email', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("owner_email", type=str, required=True, location="json") args = parser.parse_args() - account = Account.query.filter_by(email=args['owner_email']).first() + account = Account.query.filter_by(email=args["owner_email"]).first() if account is None: - return { - 'message': 'owner account not found.' - }, 404 + return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args['name']) - TenantService.create_tenant_member(tenant, account, role='owner') + tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True) + TenantService.create_tenant_member(tenant, account, role="owner") tenant_was_created.send(tenant) - return { - 'message': 'enterprise workspace created.'
- } + return {"message": "enterprise workspace created."} -api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') +api.add_resource(EnterpriseWorkspace, "/enterprise/workspace") diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 5c37f5276f5562..51ffe683ff40ad 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -17,7 +17,7 @@ def decorated(*args, **kwargs): abort(404) # get header 'X-Inner-Api-Key' - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: abort(401) @@ -33,29 +33,29 @@ def decorated(*args, **kwargs): return view(*args, **kwargs) # get header 'X-Inner-Api-Key' - authorization = request.headers.get('Authorization') + authorization = request.headers.get("Authorization") if not authorization: return view(*args, **kwargs) - parts = authorization.split(':') + parts = authorization.split(":") if len(parts) != 2: return view(*args, **kwargs) user_id, token = parts - if ' ' in user_id: - user_id = user_id.split(' ')[1] + if " " in user_id: + user_id = user_id.split(" ")[1] - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") - data_to_sign = f'DIFY {user_id}' + data_to_sign = f"DIFY {user_id}" - signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) - signature = b64encode(signature.digest()).decode('utf-8') + signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) + signature = b64encode(signature.digest()).decode("utf-8") if signature != token: return view(*args, **kwargs) - kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() return view(*args, **kwargs) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 082660a8915aa0..d6ab96c329f335 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -2,10 +2,9 @@ from libs.external_api import ExternalApi -bp = Blueprint('service_api', __name__, url_prefix='/v1') +bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) - from . 
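The `inner_api/wraps.py` hunk above recovers the user from an `Authorization` header of the form `DIFY <user_id>:<token>`, where the token is the base64-encoded HMAC-SHA1 of `"DIFY {user_id}"` keyed with the inner API key. A sketch of a client producing headers this decorator accepts (the concrete values are illustrative):

```python
import hmac
from base64 import b64encode
from hashlib import sha1

def build_inner_api_headers(user_id: str, inner_api_key: str) -> dict[str, str]:
    # Token derivation mirrors the decorator above: HMAC-SHA1 over
    # "DIFY {user_id}", keyed with the inner API key, then base64-encoded.
    data_to_sign = f"DIFY {user_id}"
    digest = hmac.new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1).digest()
    token = b64encode(digest).decode("utf-8")
    return {
        "X-Inner-Api-Key": inner_api_key,
        # The decorator splits on ":" and drops the "DIFY " prefix to recover user_id.
        "Authorization": f"DIFY {user_id}:{token}",
    }
```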
import index from .app import app, audio, completion, conversation, file, message, workflow -from .dataset import dataset, document, segment +from .dataset import dataset, document, hit_testing, segment diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 3b3cf1b026d568..88b13faa52a69c 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,7 +1,7 @@ +from flask_restful import Resource, marshal_with -from flask_restful import Resource, fields, marshal_with - -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as controller_helpers from controllers.service_api import api from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token @@ -12,40 +12,11 @@ class AppParameterApi(Resource): """Resource for app variables.""" - variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) - } - - system_parameters_fields = { - 'image_file_size_limit': fields.String - } - - parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) - } - @validate_app_token - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() @@ -54,33 +25,16 @@ def get(self, app_model: App): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) - return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 
'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } - } + return controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class AppMetaApi(Resource): @@ -89,16 +43,14 @@ def get(self, app_model: App): """Get app meta""" return AppService().get_app_meta(app_model) + class AppInfoApi(Resource): @validate_app_token def get(self, app_model: App): - """Get app infomation""" - return { - 'name':app_model.name, - 'description':app_model.description - } + """Get app information""" + return {"name": app_model.name, "description": app_model.description} -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMetaApi, '/meta') -api.add_resource(AppInfoApi, '/info') +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMetaApi, "/meta") +api.add_resource(AppInfoApi, "/info") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 3c009af343582d..5db41636471220 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -33,14 +33,10 @@ class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -74,30 +70,28 @@ class TextApi(Resource): def post(self, app_model: App, end_user: EndUser): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") or text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, 
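Returning to the `AppParameterApi` hunk above: the inline response dict is consolidated into `controllers.common.helpers.get_parameters_from_feature_dict`. A rough sketch of that helper, reconstructed from the defaults the removed inline code used; the real implementation in `controllers/common/helpers.py` may differ, for instance in which system parameters it reports:

```python
from configs import dify_config

def get_parameters_from_feature_dict(*, features_dict: dict, user_input_form: list) -> dict:
    # Same defaults the old inline code used; the remaining feature keys
    # (text_to_speech, annotation_reply, more_like_this, ...) follow the same pattern.
    return {
        "opening_statement": features_dict.get("opening_statement"),
        "suggested_questions": features_dict.get("suggested_questions", []),
        "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
        "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
        "user_input_form": user_input_form,
        "file_upload": features_dict.get(
            "file_upload",
            {"image": {"enabled": False, "number_limits": 3, "detail": "high",
                       "transfer_methods": ["remote_url", "local_file"]}},
        ),
        "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
    }
```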
end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +121,5 @@ def post(self, app_model: App, end_user: EndUser): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 2511f46bafacc6..8d8e356c4cb940 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -33,21 +33,21 @@ class CompletionApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" - args['auto_generate_name'] = False + args["auto_generate_name"] = False try: response = AppGenerateService.generate( @@ -84,41 +84,37 @@ def post(self, app_model: App, end_user: EndUser): class CompletionStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') - parser.add_argument('auto_generate_name', type=bool, required=False, 
default=True, location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") + parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -148,15 +144,15 @@ class ChatStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 44bda8e771c97c..c62fd77d367aa6 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -7,33 +7,45 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom -from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields +from fields.conversation_fields import ( + conversation_delete_fields, + conversation_infinite_scroll_pagination_fields, + simple_conversation_fields, +) from libs.helper import uuid_value from models.model import App, AppMode, EndUser from services.conversation_service import ConversationService class ConversationApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, 
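The chat endpoints registered above take their arguments from the JSON body, including the end-user id, since `validate_app_token` fetches the user via `WhereisUserArg.JSON`. A sketch of a blocking call; the base URL, the bearer-token scheme, and the `user` field name follow the service API's usual convention and are assumptions about a concrete deployment:

```python
import requests

API_BASE = "https://api.dify.ai/v1"  # hypothetical deployment URL
APP_TOKEN = "app-..."                # hypothetical app token

resp = requests.post(
    f"{API_BASE}/chat-messages",
    headers={"Authorization": f"Bearer {APP_TOKEN}"},
    json={
        "inputs": {},
        "query": "Hello!",
        "response_mode": "blocking",  # or "streaming" for chunked output
        "user": "end-user-123",       # end-user id, read from the JSON body here
    },
    timeout=60,
)
print(resp.json())
```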
location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() try: return ConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], - invoke_from=InvokeFrom.SERVICE_API + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args["sort_by"], ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -41,10 +53,10 @@ def get(self, app_model: App, end_user: EndUser): class ConversationDetailApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - @marshal_with(simple_conversation_fields) + @marshal_with(conversation_delete_fields) def delete(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -53,37 +65,30 @@ def delete(self, app_model: App, end_user: EndUser, c_id): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ConversationRenameApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='conversation_name') -api.add_resource(ConversationApi, '/conversations') -api.add_resource(ConversationDetailApi, '/conversations/', endpoint='conversation_detail') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") +api.add_resource(ConversationApi, "/conversations") +api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") diff --git 
a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ac9edb1b4f6cbe..ca91da80c19f8e 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -2,104 +2,108 @@ class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." 
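All of these error classes follow one `BaseHTTPException` convention: a machine-readable `error_code`, a human-readable `description` (some templates carry a `{message}` placeholder, and callers can pass a description at raise time, as `FileTooLargeError(file_too_large_error.description)` does later in this diff), and the HTTP status `code`. A hypothetical error in the same style:

```python
from libs.exception import BaseHTTPException

class ExampleQuotaExceededError(BaseHTTPException):
    # Hypothetical class, shown only to illustrate the convention above.
    error_code = "example_quota_exceeded"
    description = "Example quota exceeded. {message}"
    code = 429
```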
code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 5dbc1b1d1bbe0d..b0fd8e65ef97df 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -2,6 +2,7 @@ from flask_restful import Resource, marshal_with import services +from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api from controllers.service_api.app.error import ( FileTooLargeError, @@ -16,15 +17,13 @@ class FileApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) def post(self, app_model: App, end_user: EndUser): - - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if not file.mimetype: @@ -33,8 +32,16 @@ def post(self, app_model: App, end_user: EndUser): if len(request.files) > 1: raise TooManyFilesError() + if not file.filename: + raise FilenameNotExistsError + try: - upload_file = FileService.upload_file(file, end_user) + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=end_user, + ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: @@ -43,4 +50,4 @@ def post(self, app_model: App, end_user: EndUser): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 875870e667c8d9..ada40ec9cb26bd 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -10,6 +10,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import message_file_fields +from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value from models.model import App, AppMode, EndUser from services.errors.message import SuggestedQuestionsAfterAnswerDisabledError @@ -17,79 +18,79 @@ class MessageListApi(Resource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + 
"message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'message_files': fields.List(fields.String, attribute='files') + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields)), } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "parent_message_id": fields.String, + "inputs": FilesContainedField, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields)), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", 
type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -102,15 +103,15 @@ def post(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageSuggestedApi(Resource): @@ -118,15 +119,12 @@ class MessageSuggestedApi(Resource): def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.SERVICE_API + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -136,9 +134,9 @@ def get(self, app_model: App, end_user: EndUser, message_id): logging.exception("internal server error.") raise InternalServerError() - return {'result': 'success', 'data': questions} + return {"result": "success", "data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageSuggestedApi, '/messages//suggested') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageSuggestedApi, "/messages//suggested") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 9446f9d5886fa2..96d1337632826a 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -1,6 +1,7 @@ import logging from flask_restful import Resource, fields, marshal_with, reqparse +from flask_restful.inputs import int_range from werkzeug.exceptions import InternalServerError from controllers.service_api import api @@ -22,27 +23,30 @@ ) from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db +from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs import helper from models.model import App, AppMode, EndUser from models.workflow import WorkflowRun from services.app_generate_service import AppGenerateService +from services.workflow_app_service import 
WorkflowAppService logger = logging.getLogger(__name__) workflow_run_fields = { - 'id': fields.String, - 'workflow_id': fields.String, - 'status': fields.String, - 'inputs': fields.Raw, - 'outputs': fields.Raw, - 'error': fields.String, - 'total_steps': fields.Integer, - 'total_tokens': fields.Integer, - 'created_at': fields.DateTime, - 'finished_at': fields.DateTime, - 'elapsed_time': fields.Float, + "id": fields.String, + "workflow_id": fields.String, + "status": fields.String, + "inputs": fields.Raw, + "outputs": fields.Raw, + "error": fields.String, + "total_steps": fields.Integer, + "total_tokens": fields.Integer, + "created_at": fields.DateTime, + "finished_at": fields.DateTime, + "elapsed_time": fields.Float, } + class WorkflowRunDetailApi(Resource): @validate_app_token @marshal_with(workflow_run_fields) @@ -56,6 +60,8 @@ def get(self, app_model: App, workflow_id: str): workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first() return workflow_run + + class WorkflowRunApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): @@ -67,20 +73,16 @@ def post(self, app_model: App, end_user: EndUser): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") args = parser.parse_args() - streaming = args.get('response_mode') == 'streaming' + streaming = args.get("response_mode") == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -111,11 +113,33 @@ def post(self, app_model: App, end_user: EndUser, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return { - "result": "success" - } + return {"result": "success"} + + +class WorkflowAppLogApi(Resource): + @validate_app_token + @marshal_with(workflow_app_log_pagination_fields) + def get(self, app_model: App): + """ + Get workflow app logs + """ + parser = reqparse.RequestParser() + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") + args = parser.parse_args() + + # get paginate workflow app logs + workflow_app_service = WorkflowAppService() + workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( + app_model=app_model, args=args + ) + + return workflow_app_log_pagination -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowRunDetailApi, '/workflows/run/') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') 
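The new `WorkflowAppLogApi` (registered as `/workflows/logs` just below) accepts optional `keyword`, `status` (one of `succeeded`, `failed`, `stopped`), `page`, and `limit` filters. A sketch of querying it, with URL and token as deployment assumptions:

```python
import requests

resp = requests.get(
    "https://api.dify.ai/v1/workflows/logs",  # hypothetical deployment URL
    headers={"Authorization": "Bearer app-..."},
    params={"status": "failed", "page": 1, "limit": 20},  # "keyword" is also accepted
    timeout=30,
)
logs = resp.json()  # marshaled with workflow_app_log_pagination_fields
```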
+api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowRunDetailApi, "/workflows/run/") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") +api.add_resource(WorkflowAppLogApi, "/workflows/logs") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 8dd16c0787cbca..799fccc228e21d 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -10,13 +10,13 @@ from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields from libs.login import current_user -from models.dataset import Dataset +from models.dataset import Dataset, DatasetPermissionEnum from services.dataset_service import DatasetService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name @@ -26,24 +26,18 @@ class DatasetListApi(DatasetApiResource): def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + # provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") - datasets, total = DatasetService.get_datasets(page, limit, provider, - tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -51,47 +45,90 @@ def get(self, tenant_id): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + item["embedding_available"] = True + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 - def post(self, tenant_id): """Resource for creating datasets.""" parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. 
Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help='Invalid indexing technique.') + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "description", + type=str, + nullable=True, + required=False, + default="", + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, + ) + parser.add_argument( + "external_knowledge_api_id", + type=str, + nullable=True, + required=False, + default="_validate_name", + ) + parser.add_argument( + "provider", + type=str, + nullable=True, + required=False, + default="vendor", + ) + parser.add_argument( + "external_knowledge_id", + type=str, + nullable=True, + required=False, + ) args = parser.parse_args() try: dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], - account=current_user + name=args["name"], + description=args["description"], + indexing_technique=args["indexing_technique"], + account=current_user, + permission=args["permission"], + provider=args["provider"], + external_knowledge_api_id=args["external_knowledge_api_id"], + external_knowledge_id=args["external_knowledge_id"], ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() return marshal(dataset, dataset_detail_fields), 200 + class DatasetApi(DatasetApiResource): """Resource for dataset.""" @@ -103,7 +140,7 @@ def delete(self, _, dataset_id): dataset_id (UUID): The ID of the dataset to be deleted. Returns: - dict: A dictionary with a key 'result' and a value 'success' + dict: A dictionary with a key 'result' and a value 'success' if the dataset was successfully deleted. Omitted in HTTP response. int: HTTP status code 204 indicating that the operation was successful. 
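With the expanded parser above, a knowledge base can now be created through the service API with just a name, while `description`, `indexing_technique`, `permission`, and the external-knowledge fields stay optional. A sketch, with URL and token as deployment assumptions:

```python
import requests

resp = requests.post(
    "https://api.dify.ai/v1/datasets",  # hypothetical deployment URL
    headers={"Authorization": "Bearer dataset-..."},
    json={
        "name": "support-articles",            # 1-40 characters, enforced by _validate_name
        "description": "Internal knowledge base",
        "indexing_technique": "high_quality",  # one of Dataset.INDEXING_TECHNIQUE_LIST
        "permission": "only_me",               # choices come from DatasetPermissionEnum
    },
    timeout=30,
)
print(resp.status_code, resp.json())
```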
@@ -115,11 +152,12 @@ def delete(self, _, dataset_id): try: if DatasetService.delete_dataset(dataset_id_str, current_user): - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') + +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ac1ea820a646ba..5c3fc7b241175a 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,6 +6,7 @@ from werkzeug.exceptions import NotFound import services.dataset_service +from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ( @@ -27,47 +28,45 @@ class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("text", type=str, required=True, nullable=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise 
ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + text = args.get("text") + name = args.get("name") + if text is None or name is None: + raise ValueError("Both 'text' and 'name' must be non-null values.") + + upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -76,60 +75,53 @@ def post(self, tenant_id, dataset_id): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('text', type=str, required=False, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("text", type=str, required=False, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = 
db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - - if args['text']: - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + raise ValueError("Dataset is not exist.") + + if args["text"]: + text = args.get("text") + name = args.get("name") + if text is None or name is None: + raise ValueError("Both text and name must be strings.") + upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -137,65 +129,62 @@ def post(self, tenant_id, dataset_id, document_id): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if not dataset.indexing_technique and not args.get('indexing_technique'): - raise ValueError('indexing_technique is required.') + raise ValueError("Dataset is not exist.") + if not dataset.indexing_technique and not args.get("indexing_technique"): + raise ValueError("indexing_technique is required.") # save file info - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() - upload_file = 
FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + if not file.filename: + raise FilenameNotExistsError + + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -204,63 +193,58 @@ def post(self, tenant_id, dataset_id): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if 'file' in request.files: + raise ValueError("Dataset is not exist.") + if "file" in request.files: # save file info - file = request.files['file'] - + file = request.files["file"] if len(request.files) > 1: raise TooManyFilesError() - upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + if not file.filename: + raise FilenameNotExistsError + + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -268,16 +252,13 @@ def post(self, tenant_id, dataset_id, document_id): dataset=dataset, document_data=args, 
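The create-by-file endpoint above reads document settings from a JSON-encoded `data` form field alongside the multipart `file`. A sketch of one upload, with URL and token as deployment assumptions:

```python
import json
import requests

resp = requests.post(
    "https://api.dify.ai/v1/datasets/DATASET_UUID/document/create_by_file",  # hypothetical URL
    headers={"Authorization": "Bearer dataset-..."},
    files={"file": ("guide.pdf", open("guide.pdf", "rb"), "application/pdf")},
    data={"data": json.dumps({
        "indexing_technique": "high_quality",   # required if the dataset has none yet
        "process_rule": {"mode": "automatic"},
    })},
    timeout=120,
)
print(resp.json())  # {"document": {...}, "batch": "..."}
```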
account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch} return documents_and_batch_fields, 200 @@ -289,13 +270,10 @@ def delete(self, tenant_id, dataset_id, document_id): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset does not exist.") document = DocumentService.get_document(dataset.id, document_id) @@ -311,44 +289,39 @@ def delete(self, tenant_id, dataset_id, document_id): # delete document DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) query = query.order_by(desc(Document.created_at)) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items response = { - 'data': marshal(documents, document_fields), - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": marshal(documents, document_fields), + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response @@ -360,38 +333,52 @@ def get(self, tenant_id, dataset_id, batch): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = 
db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # get documents documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data -api.add_resource(DocumentAddByTextApi, '/datasets//document/create_by_text') -api.add_resource(DocumentAddByFileApi, '/datasets//document/create_by_file') -api.add_resource(DocumentUpdateByTextApi, '/datasets//documents//update_by_text') -api.add_resource(DocumentUpdateByFileApi, '/datasets//documents//update_by_file') -api.add_resource(DocumentDeleteApi, '/datasets//documents/') -api.add_resource(DocumentListApi, '/datasets//documents') -api.add_resource(DocumentIndexingStatusApi, '/datasets//documents//indexing-status') +api.add_resource( + DocumentAddByTextApi, + "/datasets//document/create_by_text", + "/datasets//document/create-by-text", +) +api.add_resource( + DocumentAddByFileApi, + "/datasets//document/create_by_file", + "/datasets//document/create-by-file", +) +api.add_resource( + DocumentUpdateByTextApi, + "/datasets//documents//update_by_text", + "/datasets//documents//update-by-text", +) +api.add_resource( + DocumentUpdateByFileApi, + "/datasets//documents//update_by_file", + "/datasets//documents//update-by-file", +) +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentListApi, "/datasets//documents") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index e77693b6c9495c..5ff5e08c7245e1 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -2,78 +2,78 @@ class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. 
{message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." 
code = 409 diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py new file mode 100644 index 00000000000000..465f71bf038eac --- /dev/null +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -0,0 +1,17 @@ +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase +from controllers.service_api import api +from controllers.service_api.wraps import DatasetApiResource + + +class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): + def post(self, tenant_id, dataset_id): + dataset_id_str = str(dataset_id) + + dataset = self.get_and_validate_dataset(dataset_id_str) + args = self.parse_args() + self.hit_testing_args_check(args) + + return self.perform_hit_testing(dataset, args) + + +api.add_resource(HitTestingApi, "/datasets//hit-testing", "/datasets//retrieve") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 0fa2aa65b26bdc..e68f6b4dc40a36 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -21,115 +21,106 @@ class SegmentApi(DatasetApiResource): """Resource for segments.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") + if document.indexing_status != "completed": + raise NotFound("Document is not completed.") + if not document.enabled: + raise NotFound("Document is disabled.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args parser = reqparse.RequestParser() - parser.add_argument('segments', type=list, required=False, nullable=True, location='json') + parser.add_argument("segments", type=list, required=False, nullable=True, location="json") args = parser.parse_args() - if args['segments'] is not None: - for args_item in args['segments']: + if args["segments"] is not None: + for args_item in args["segments"]: SegmentService.segment_create_args_validate(args_item, document) - segments = SegmentService.multi_create_segment(args['segments'], document, dataset) - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form - }, 200 + segments = SegmentService.multi_create_segment(args["segments"], document, dataset) + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 else: - return {"error": "Segemtns is required"}, 400 + return {"error": "Segments is required"}, 400 def get(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." 
+ ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) parser = reqparse.RequestParser() - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - status_list = args['status'] - keyword = args['keyword'] + status_list = args["status"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) total = query.count() segments = query.order_by(DocumentSegment.position).all() - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'total': total - }, 200 + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form, "total": total}, 200 class DatasetSegmentApi(DatasetApiResource): @@ -137,48 +128,41 @@ def delete(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if 
dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -186,35 +170,34 @@ def post(self, tenant_id, dataset_id, document_id, segment_id): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # validate args parser = reqparse.RequestParser() - parser.add_argument('segment', type=dict, required=False, nullable=True, location='json') + parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") args = parser.parse_args() - SegmentService.segment_create_args_validate(args['segment'], document) - segment = SegmentService.update_segment(args['segment'], segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + SegmentService.segment_create_args_validate(args["segment"], document) + segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 -api.add_resource(SegmentApi, '/datasets//documents//segments') -api.add_resource(DatasetSegmentApi, '/datasets//documents//segments/') +api.add_resource(SegmentApi, "/datasets//documents//segments") +api.add_resource( + DatasetSegmentApi, "/datasets//documents//segments/" +) diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index c910063ebd83d1..d24c4597e210fb 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -13,4 +13,4 @@ def get(self): } -api.add_resource(IndexApi, '/') +api.add_resource(IndexApi, "/") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 819512edf059bb..b935b23ed645c6 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -21,9 +21,10 @@ class WhereisUserArg(Enum): """ Enum for whereis_user_arg. 
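SegmentApi above validates each item in the "segments" list before handing it to SegmentService.multi_create_segment, and now also requires the target document to have indexing_status "completed" and to be enabled. A usage sketch against the segments route registered above, with a hypothetical base URL and key; the per-segment fields shown (content, answer, keywords) are assumptions about what segment_create_args_validate accepts:

    import requests  # assumed HTTP client

    API_BASE = "https://api.example.com/v1"  # hypothetical
    API_KEY = "dataset-..."  # hypothetical

    def add_segments(dataset_id: str, document_id: str) -> dict:
        payload = {
            "segments": [
                # Field names are illustrative; each item must pass
                # SegmentService.segment_create_args_validate.
                {"content": "Example chunk text.", "answer": "", "keywords": ["example"]},
            ]
        }
        resp = requests.post(
            f"{API_BASE}/datasets/{dataset_id}/documents/{document_id}/segments",
            headers={"Authorization": f"Bearer {API_KEY}"},
            json=payload,
        )
        resp.raise_for_status()
        return resp.json()  # {"data": [...], "doc_form": ...} per SegmentApi.post
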
""" - QUERY = 'query' - JSON = 'json' - FORM = 'form' + + QUERY = "query" + JSON = "json" + FORM = "form" class FetchUserArg(BaseModel): @@ -35,13 +36,13 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - api_token = validate_and_get_api_token('app') + api_token = validate_and_get_api_token("app") app_model = db.session.query(App).filter(App.id == api_token.app_id).first() if not app_model: raise Forbidden("The app no longer exists.") - if app_model.status != 'normal': + if app_model.status != "normal": raise Forbidden("The app's status is abnormal.") if not app_model.enable_api: @@ -51,15 +52,15 @@ def decorated_view(*args, **kwargs): if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model if fetch_user_arg: if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get('user') + user_id = request.args.get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get('user') + user_id = request.get_json().get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get('user') + user_id = request.form.get("user") else: # use default-user user_id = None @@ -70,9 +71,10 @@ def decorated_view(*args, **kwargs): if user_id: user_id = str(user_id) - kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) + kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) return view_func(*args, **kwargs) + return decorated_view if view is None: @@ -81,9 +83,7 @@ def decorated_view(*args, **kwargs): return decorator(view) -def cloud_edition_billing_resource_check(resource: str, - api_token_type: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check(resource: str, api_token_type: str): def interceptor(view): def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) @@ -95,34 +95,36 @@ def decorated(*args, **kwargs): vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota - if resource == 'members' and 0 < members.limit <= members.size: - raise Forbidden(error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: - raise Forbidden(error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: - raise Forbidden(error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: - raise Forbidden(error_msg) + if resource == "members" and 0 < members.limit <= members.size: + raise Forbidden("The number of members has reached the limit of your subscription.") + elif resource == "apps" and 0 < apps.limit <= apps.size: + raise Forbidden("The number of apps has reached the limit of your subscription.") + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: + raise Forbidden("The capacity of the vector space has reached the limit of your subscription.") + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + raise Forbidden("The number of documents has reached the limit of your subscription.") else: return view(*args, **kwargs) return view(*args, **kwargs) + return decorated + return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - api_token_type: 
str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) features = FeatureService.get_features(api_token.tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': - raise Forbidden(error_msg) + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": + raise Forbidden( + "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan." + ) else: return view(*args, **kwargs) @@ -132,17 +134,20 @@ def decorated(*args, **kwargs): return interceptor + def validate_dataset_token(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - api_token = validate_and_get_api_token('dataset') - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == api_token.tenant_id) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.role.in_(['owner'])) \ - .filter(Tenant.status == TenantStatus.NORMAL) \ - .one_or_none() # TODO: only owner information is required, so only one is returned. + api_token = validate_and_get_api_token("dataset") + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == api_token.tenant_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role.in_(["owner"])) + .filter(Tenant.status == TenantStatus.NORMAL) + .one_or_none() + ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() @@ -156,6 +161,7 @@ def decorated(*args, **kwargs): else: raise Unauthorized("Tenant does not exist.") return view(api_token.tenant_id, *args, **kwargs) + return decorated if view: @@ -170,20 +176,24 @@ def validate_and_get_api_token(scope=None): """ Validate and get API token. """ - auth_header = request.headers.get('Authorization') - if auth_header is None or ' ' not in auth_header: + auth_header = request.headers.get("Authorization") + if auth_header is None or " " not in auth_header: raise Unauthorized("Authorization header must be provided and start with 'Bearer'") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': + if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - api_token = db.session.query(ApiToken).filter( - ApiToken.token == auth_token, - ApiToken.type == scope, - ).first() + api_token = ( + db.session.query(ApiToken) + .filter( + ApiToken.token == auth_token, + ApiToken.type == scope, + ) + .first() + ) if not api_token: raise Unauthorized("Access token is invalid") @@ -199,23 +209,26 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] Create or update session terminal based on user ID. 
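The header handling in validate_and_get_api_token above is self-contained enough to sketch in isolation. A dependency-free mirror of just the parsing step, with ValueError standing in for werkzeug's Unauthorized:

    def parse_bearer_token(auth_header: str | None) -> str:
        # Mirrors validate_and_get_api_token above: require a two-part header
        # and a case-insensitive "Bearer" scheme, then return the raw token.
        if auth_header is None or " " not in auth_header:
            raise ValueError("Authorization header must be provided and start with 'Bearer'")
        auth_scheme, auth_token = auth_header.split(None, 1)
        if auth_scheme.lower() != "bearer":
            raise ValueError("Authorization scheme must be 'Bearer'")
        return auth_token

    assert parse_bearer_token("Bearer abc123") == "abc123"
    assert parse_bearer_token("bearer abc123") == "abc123"  # scheme match is case-insensitive
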
""" if not user_id: - user_id = 'DEFAULT-USER' + user_id = "DEFAULT-USER" - end_user = db.session.query(EndUser) \ + end_user = ( + db.session.query(EndUser) .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.app_id == app_model.id, - EndUser.session_id == user_id, - EndUser.type == 'service_api' - ).first() + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.session_id == user_id, + EndUser.type == "service_api", + ) + .first() + ) if end_user is None: end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='service_api', - is_anonymous=True if user_id == 'DEFAULT-USER' else False, - session_id=user_id + type="service_api", + is_anonymous=True if user_id == "DEFAULT-USER" else False, + session_id=user_id, ) db.session.add(end_user) db.session.commit() diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index aa19bdc0349fdc..50a04a625468e4 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -2,8 +2,17 @@ from libs.external_api import ExternalApi -bp = Blueprint('web', __name__, url_prefix='/api') +from .files import FileApi +from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi + +bp = Blueprint("web", __name__, url_prefix="/api") api = ExternalApi(bp) +# Files +api.add_resource(FileApi, "/files/upload") + +# Remote files +api.add_resource(RemoteFileInfoApi, "/remote-files/") +api.add_resource(RemoteFileUploadApi, "/remote-files/upload") -from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow +from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index f4db82552c983a..cc8255ccf4e748 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,6 +1,7 @@ -from flask_restful import fields, marshal_with +from flask_restful import marshal_with -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as controller_helpers from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource @@ -10,39 +11,11 @@ class AppParameterApi(WebApiResource): """Resource for app variables.""" - variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) - } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } - - parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) - } - - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise 
AppUnavailableError() @@ -51,33 +24,16 @@ def get(self, app_model: App, end_user): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) - return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } - } + return controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class AppMeta(WebApiResource): @@ -86,5 +42,5 @@ def get(self, app_model: App, end_user): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMeta, '/meta') \ No newline at end of file +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMeta, "/meta") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 0e905f905a7a2f..23550efe2e2768 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -31,14 +31,10 @@ class AudioApi(WebApiResource): def post(self, app_model: App, end_user): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -70,34 +66,32 @@ def post(self, app_model: App, end_user): class TextApi(WebApiResource): def post(self, app_model: App, end_user): from flask_restful import reqparse + try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - 
and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") or text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get( - 'voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +121,5 @@ def post(self, app_model: App, end_user): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 948d5fabb5328d..45b890dfc4899d 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -15,6 +15,7 @@ ProviderNotInitializeError, ProviderQuotaExceededError, ) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.wraps import WebApiResource from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -24,34 +25,30 @@ from libs.helper import uuid_value from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.llm import InvokeRateLimitError # define completion api for user class CompletionApi(WebApiResource): - def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + 
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -79,40 +76,37 @@ def post(self, app_model, end_user): class CompletionStopApi(WebApiResource): def post(self, app_model, end_user, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(WebApiResource): def post(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -129,6 +123,8 @@ def post(self, app_model, end_user): raise ProviderQuotaExceededError() except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() + except InvokeRateLimitError as ex: + raise InvokeRateLimitHttpError(ex.description) except InvokeError as e: raise CompletionRequestError(e.description) except ValueError as e: @@ -141,15 +137,15 @@ def post(self, app_model, end_user): class ChatStopApi(WebApiResource): def post(self, app_model, end_user, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, 
'/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index b83ea3a52596b1..c3b0cd4f44b2ac 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -15,31 +15,39 @@ class ConversationListApi(WebApiResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.WEB_APP, pinned=pinned, + sort_by=args["sort_by"], ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -48,7 +56,7 @@ def get(self, app_model, end_user): class ConversationApi(WebApiResource): def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -62,37 +70,29 @@ def delete(self, app_model, end_user, c_id): class ConversationRenameApi(WebApiResource): - @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - 
args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(WebApiResource): - def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -108,7 +108,7 @@ def patch(self, app_model, end_user, c_id): class ConversationUnPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -117,8 +117,8 @@ def patch(self, app_model, end_user, c_id): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='web_conversation_name') -api.add_resource(ConversationListApi, '/conversations') -api.add_resource(ConversationApi, '/conversations/') -api.add_resource(ConversationPinApi, '/conversations//pin') -api.add_resource(ConversationUnPinApi, '/conversations//unpin') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") +api.add_resource(ConversationListApi, "/conversations") +api.add_resource(ConversationApi, "/conversations/") +api.add_resource(ConversationPinApi, "/conversations//pin") +api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index bc87f510512a83..9fe5d08d54a12d 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -2,122 +2,134 @@ class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your Workflow app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." 
+ ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class WebSSOAuthRequiredError(BaseHTTPException): - error_code = 'web_sso_auth_required' + error_code = "web_sso_auth_required" description = "Web SSO authentication required." 
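ConversationListApi above now accepts a sort_by argument that the parser restricts to created_at, -created_at, updated_at, and -updated_at, defaulting to -updated_at. A sketch of a web-app call exercising it, assuming a hypothetical host and assuming the passport access_token is sent as a Bearer token:

    import requests  # assumed HTTP client

    WEB_API_BASE = "https://app.example.com/api"  # hypothetical web-app base path
    ACCESS_TOKEN = "..."  # token issued by the /passport endpoint

    def list_conversations(sort_by: str = "-updated_at") -> dict:
        # Any value outside the four parser choices is rejected by reqparse.
        resp = requests.get(
            f"{WEB_API_BASE}/conversations",
            params={"limit": 20, "sort_by": sort_by},
            headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
        )
        resp.raise_for_status()
        return resp.json()
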
code = 401 + + +class InvokeRateLimitError(BaseHTTPException): + """Raised when the Invoke returns rate limit error.""" + + error_code = "rate_limit_error" + description = "Rate Limit Error" + code = 429 diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 69b38faaf655e8..0563ed22382e6b 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -9,4 +9,4 @@ def get(self): return FeatureService.get_system_features().model_dump() -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py deleted file mode 100644 index ca83f6037a3cfa..00000000000000 --- a/api/controllers/web/file.py +++ /dev/null @@ -1,35 +0,0 @@ -from flask import request -from flask_restful import marshal_with - -import services -from controllers.web import api -from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError -from controllers.web.wraps import WebApiResource -from fields.file_fields import file_fields -from services.file_service import FileService - - -class FileApi(WebApiResource): - - @marshal_with(file_fields) - def post(self, app_model, end_user): - # get file from request - file = request.files['file'] - - # check file - if 'file' not in request.files: - raise NoFileUploadedError() - - if len(request.files) > 1: - raise TooManyFilesError() - try: - upload_file = FileService.upload_file(file, end_user) - except services.errors.file.FileTooLargeError as file_too_large_error: - raise FileTooLargeError(file_too_large_error.description) - except services.errors.file.UnsupportedFileTypeError: - raise UnsupportedFileTypeError() - - return upload_file, 201 - - -api.add_resource(FileApi, '/files/upload') diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py new file mode 100644 index 00000000000000..a282fc63a8b056 --- /dev/null +++ b/api/controllers/web/files.py @@ -0,0 +1,43 @@ +from flask import request +from flask_restful import marshal_with + +import services +from controllers.common.errors import FilenameNotExistsError +from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError +from controllers.web.wraps import WebApiResource +from fields.file_fields import file_fields +from services.file_service import FileService + + +class FileApi(WebApiResource): + @marshal_with(file_fields) + def post(self, app_model, end_user): + file = request.files["file"] + source = request.form.get("source") + + if "file" not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + + if not file.filename: + raise FilenameNotExistsError + + if source not in ("datasets", None): + source = None + + try: + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=end_user, + source=source, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return upload_file, 201 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 865d2270ad8d10..98891f5d00d7e0 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -22,6 +22,7 @@ from core.model_runtime.errors.invoke import InvokeError from 
fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields +from fields.raws import FilesContainedField from libs import helper from libs.helper import TimestampField, uuid_value from models.model import AppMode @@ -33,65 +34,65 @@ class MessageListApi(WebApiResource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "parent_message_id": fields.String, + "inputs": FilesContainedField, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields)), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, 
location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc" + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -103,29 +104,31 @@ def post(self, app_model, end_user, message_id): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(WebApiResource): def get(self, app_model, end_user, message_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -133,7 +136,7 @@ def get(self, app_model, end_user, message_id): user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) @@ -159,17 +162,14 @@ def get(self, app_model, end_user, message_id): class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotCompletionAppError() message_id = str(message_id) try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.WEB_APP + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP ) except MessageNotExistsError: raise NotFound("Message not found") @@ -189,10 +189,10 @@ def get(self, app_model, end_user, message_id): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks') -api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this') -api.add_resource(MessageSuggestedQuestionApi,
'/messages/<uuid:message_id>/suggested-questions') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks") +api.add_resource(MessageMoreLikeThisApi, "/messages/<uuid:message_id>/more-like-this") +api.add_resource(MessageSuggestedQuestionApi, "/messages/<uuid:message_id>/suggested-questions") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index ccc8683a799dc5..a01ffd861230a5 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -9,37 +9,37 @@ from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService class PassportResource(Resource): """Base resource for passport.""" - def get(self): + def get(self): system_features = FeatureService.get_system_features() - if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() - - app_code = request.headers.get('X-App-Code') + app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized('X-App-Code header is missing.') + raise Unauthorized("X-App-Code header is missing.") + + if system_features.sso_enforced_for_web: + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() # get site from db and check if it is normal - site = db.session.query(Site).filter( - Site.code == app_code, - Site.status == 'normal' - ).first() + site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() if not site: raise NotFound() # get app from db and check if it is normal and enable_site app_model = db.session.query(App).filter(App.id == site.app_id).first() - if not app_model or app_model.status != 'normal' or not app_model.enable_site: + if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='browser', + type="browser", is_anonymous=True, session_id=generate_session_id(), ) @@ -49,20 +49,20 @@ def get(self): payload = { "iss": site.app_id, - 'sub': 'Web API Passport', - 'app_id': site.app_id, - 'app_code': app_code, - 'end_user_id': end_user.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": app_code, + "end_user_id": end_user.id, } tk = PassportService().issue(payload) return { - 'access_token': tk, + "access_token": tk, } -api.add_resource(PassportResource, '/passport') +api.add_resource(PassportResource, "/passport") def generate_session_id(): @@ -71,7 +71,6 @@ def generate_session_id(): """ while True: session_id = str(uuid.uuid4()) - existing_count = db.session.query(EndUser) \ - .filter(EndUser.session_id == session_id).count() + existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() if existing_count == 0: return session_id diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py new file mode 100644 index 00000000000000..d6b8eb2855725c --- /dev/null +++ b/api/controllers/web/remote_files.py @@ -0,0 +1,75 @@ +import urllib.parse + +import httpx +from flask_restful import marshal_with, reqparse + +import services +from controllers.common import helpers +from controllers.web.wraps import WebApiResource +from core.file import helpers as file_helpers +from core.helper import ssrf_proxy +from fields.file_fields import
file_fields_with_signed_url, remote_file_info_fields +from services.file_service import FileService + +from .error import FileTooLargeError, UnsupportedFileTypeError + + +class RemoteFileInfoApi(WebApiResource): + @marshal_with(remote_file_info_fields) + def get(self, app_model, end_user, url): + decoded_url = urllib.parse.unquote(url) + resp = ssrf_proxy.head(decoded_url) + if resp.status_code != httpx.codes.OK: + # fall back to GET method + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return { + "file_type": resp.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(resp.headers.get("Content-Length", -1)), + } + + +class RemoteFileUploadApi(WebApiResource): + @marshal_with(file_fields_with_signed_url) + def post(self, app_model, end_user): # Add app_model and end_user parameters + parser = reqparse.RequestParser() + parser.add_argument("url", type=str, required=True, help="URL is required") + args = parser.parse_args() + + url = args["url"] + + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3) + resp.raise_for_status() + + file_info = helpers.guess_file_info_from_response(resp) + + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + raise FileTooLargeError + + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + + try: + upload_file = FileService.upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=end_user, + source_url=url, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at, + }, 201 diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index e17869ffdbf8e8..b0492e6b6f0d31 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -10,67 +10,65 @@ from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields)), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": 
fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, end_user, args['message_id']) + SavedMessageService.save(app_model, end_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(WebApiResource): def delete(self, app_model, end_user, message_id): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, end_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/saved-messages') -api.add_resource(SavedMessageApi, '/saved-messages/<uuid:message_id>') +api.add_resource(SavedMessageListApi, "/saved-messages") +api.add_resource(SavedMessageApi, "/saved-messages/<uuid:message_id>") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 99ec86e935e333..0564b15ea39855 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from werkzeug.exceptions import Forbidden @@ -6,6 +5,7 @@ from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db +from libs.helper import AppIconUrlField from models.account import TenantStatus from models.model import Site from services.feature_service import FeatureService @@ -15,39 +15,42 @@ class AppSiteApi(WebApiResource): """Resource for app sites.""" model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'pre_prompt': fields.String, + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": 
fields.Raw(attribute="user_input_form_list"), + "pre_prompt": fields.String, } site_fields = { - 'title': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'icon': fields.String, - 'icon_background': fields.String, - 'description': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'default_language': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "title": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "default_language": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, } app_fields = { - 'app_id': fields.String, - 'end_user_id': fields.String, - 'enable_site': fields.Boolean, - 'site': fields.Nested(site_fields), - 'model_config': fields.Nested(model_config_fields, allow_null=True), - 'plan': fields.String, - 'can_replace_logo': fields.Boolean, - 'custom_config': fields.Raw(attribute='custom_config'), + "app_id": fields.String, + "end_user_id": fields.String, + "enable_site": fields.Boolean, + "site": fields.Nested(site_fields), + "model_config": fields.Nested(model_config_fields, allow_null=True), + "plan": fields.String, + "can_replace_logo": fields.Boolean, + "custom_config": fields.Raw(attribute="custom_config"), } @marshal_with(app_fields) @@ -67,7 +70,7 @@ def get(self, app_model, end_user): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, '/site') +api.add_resource(AppSiteApi, "/site") class AppSiteInfo: @@ -85,9 +88,13 @@ def __init__(self, tenant, app, site, end_user, can_replace_logo): if can_replace_logo: base_url = dify_config.FILES_URL - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) self.custom_config = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': replace_webapp_logo, + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 77c468e4179334..55b0c3e2ab34c5 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -33,17 +33,13 @@ def post(self, app_model: App, end_user: EndUser): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - 
user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=True + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=True ) return helper.compact_generate_response(response) @@ -73,10 +69,8 @@ def post(self, app_model: App, end_user: EndUser, task_id: str): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index f5ab49d7e17290..c327c3df18526c 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -8,6 +8,7 @@ from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService @@ -18,7 +19,9 @@ def decorated(*args, **kwargs): app_model, end_user = decode_jwt_token() return view(app_model, end_user, *args, **kwargs) + return decorated + if view: return decorator(view) return decorator @@ -26,56 +29,63 @@ def decorated(*args, **kwargs): def decode_jwt_token(): system_features = FeatureService.get_system_features() - + app_code = request.headers.get("X-App-Code") try: - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer <api-key>' format.") decoded = PassportService().verify(tk) - app_code = decoded.get('app_code') - app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() + app_code = decoded.get("app_code") + app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() site = db.session.query(Site).filter(Site.code == app_code).first() if not app_model: raise NotFound() if not app_code or not site: - raise BadRequest('Site URL is no longer valid.') + raise BadRequest("Site URL is no longer valid.") if app_model.enable_site is False: - raise BadRequest('Site is disabled.') - end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() + raise BadRequest("Site is disabled.") + end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features) + _validate_web_sso_token(decoded, system_features, app_code) return app_model, end_user except Unauthorized as e: if system_features.sso_enforced_for_web: - raise WebSSOAuthRequiredError() + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + raise WebSSOAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features): +def _validate_web_sso_token(decoded, system_features, app_code): + app_web_sso_enabled = False + # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login if system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if not source or source != 'sso': - raise WebSSOAuthRequiredError() - - # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web: - source = decoded.get('token_source') - if source and source == 'sso': - raise Unauthorized('sso token expired.') + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) + if app_web_sso_enabled: + source = decoded.get("token_source") + if not source or source != "sso": + raise WebSSOAuthRequiredError() + + # Check if SSO is not enforced for web, and if the token source is SSO, + # raise an error and redirect to normal passport login + if not system_features.sso_enforced_for_web or not app_web_sso_enabled: + source = decoded.get("token_source") + if source and source == "sso": + raise Unauthorized("sso token expired.") class WebApiResource(Resource): diff --git a/api/core/__init__.py b/api/core/__init__.py index 8c986fc8bd8afa..6eaea7b1c8419f 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +1 @@ -import core.moderation.base \ No newline at end of file +import core.moderation.base diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index e9fa1f01618e01..860ec5de0c8ece 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -1,6 +1,7 @@ import json import logging import uuid +from collections.abc import Mapping, Sequence from datetime import datetime, timezone from typing import Optional, Union, cast @@ -15,22 +16,25 @@ ) from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.file.message_file_parser import MessageFileParser +from core.file import file_manager from
core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + LLMUsage, PromptMessage, + PromptMessageContent, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.tools.entities.tool_entities import ( ToolParameter, ToolRuntimeVariablePool, @@ -38,46 +42,32 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool from core.tools.tool_manager import ToolManager -from core.tools.utils.tool_parameter_converter import ToolParameterConverter from extensions.ext_database import db -from models.model import Conversation, Message, MessageAgentThought +from factories import file_factory +from models.model import Conversation, Message, MessageAgentThought, MessageFile from models.tools import ToolConversationVariables logger = logging.getLogger(__name__) + class BaseAgentRunner(AppRunner): - def __init__(self, tenant_id: str, - application_generate_entity: AgentChatAppGenerateEntity, - conversation: Conversation, - app_config: AgentChatAppConfig, - model_config: ModelConfigWithCredentialsEntity, - config: AgentEntity, - queue_manager: AppQueueManager, - message: Message, - user_id: str, - memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[list[PromptMessage]] = None, - variables_pool: Optional[ToolRuntimeVariablePool] = None, - db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance = None - ) -> None: - """ - Agent runner - :param tenant_id: tenant id - :param application_generate_entity: application generate entity - :param conversation: conversation - :param app_config: app generate entity - :param model_config: model config - :param config: dataset config - :param queue_manager: queue manager - :param message: message - :param user_id: user id - :param memory: memory - :param prompt_messages: prompt messages - :param variables_pool: variables pool - :param db_variables: db variables - :param model_instance: model instance - """ + def __init__( + self, + tenant_id: str, + application_generate_entity: AgentChatAppGenerateEntity, + conversation: Conversation, + app_config: AgentChatAppConfig, + model_config: ModelConfigWithCredentialsEntity, + config: AgentEntity, + queue_manager: AppQueueManager, + message: Message, + user_id: str, + memory: Optional[TokenBufferMemory] = None, + prompt_messages: Optional[list[PromptMessage]] = None, + variables_pool: Optional[ToolRuntimeVariablePool] = None, + db_variables: Optional[ToolConversationVariables] = None, + model_instance: ModelInstance | None = None, + ) -> None: self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity self.conversation = conversation @@ -88,9 +78,7 @@ def __init__(self, tenant_id: str, self.message = message self.user_id = user_id self.memory = memory - self.history_prompt_messages = self.organize_agent_history( - 
prompt_messages=prompt_messages or [] - ) + self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) self.variables_pool = variables_pool self.db_variables_pool = db_variables self.model_instance = model_instance @@ -111,12 +99,16 @@ def __init__(self, tenant_id: str, retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, return_resource=app_config.additional_features.show_retrieve_source, invoke_from=application_generate_entity.invoke_from, - hit_callback=hit_callback + hit_callback=hit_callback, ) # get how many agent thoughts have been created - self.agent_thought_count = db.session.query(MessageAgentThought).filter( - MessageAgentThought.message_id == self.message.id, - ).count() + self.agent_thought_count = ( + db.session.query(MessageAgentThought) + .filter( + MessageAgentThought.message_id == self.message.id, + ) + .count() + ) db.session.close() # check if model supports stream tool call @@ -135,25 +127,26 @@ def __init__(self, tenant_id: str, self.query = None self._current_thoughts: list[PromptMessage] = [] - def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \ - -> AgentChatAppGenerateEntity: + def _repack_app_generate_entity( + self, app_generate_entity: AgentChatAppGenerateEntity + ) -> AgentChatAppGenerateEntity: """ Repack app generate entity """ if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: - app_generate_entity.app_config.prompt_template.simple_prompt_template = '' + app_generate_entity.app_config.prompt_template.simple_prompt_template = "" return app_generate_entity - + def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: """ - convert tool to prompt message tool + convert tool to prompt message tool """ tool_entity = ToolManager.get_agent_tool_runtime( tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, - invoke_from=self.application_generate_entity.invoke_from + invoke_from=self.application_generate_entity.invoke_from, ) tool_entity.load_variables(self.variables_pool) @@ -164,7 +157,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P "type": "object", "properties": {}, "required": [], - } + }, ) parameters = tool_entity.get_all_runtime_parameters() @@ -172,24 +165,30 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) + parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] - message_tool.parameters['properties'][parameter.name] = { + message_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if len(enum) > 0: - message_tool.parameters['properties'][parameter.name]['enum'] = enum + message_tool.parameters["properties"][parameter.name]["enum"] = enum if parameter.required: - message_tool.parameters['required'].append(parameter.name) + message_tool.parameters["required"].append(parameter.name) return message_tool, tool_entity - + def 
_convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: """ convert dataset retriever tool to prompt message tool @@ -201,24 +200,24 @@ def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRe "type": "object", "properties": {}, "required": [], - } + }, ) for parameter in tool.get_runtime_parameters(): - parameter_type = 'string' - - prompt_tool.parameters['properties'][parameter.name] = { + parameter_type = "string" + + prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) return prompt_tool - - def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]: + + def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: """ Init tools """ @@ -257,55 +256,61 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) + parameter_type = parameter.type.as_normal_type() + if parameter.type in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + }: + continue enum = [] if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] - - prompt_tool.parameters['properties'][parameter.name] = { + + prompt_tool.parameters["properties"][parameter.name] = { "type": parameter_type, - "description": parameter.llm_description or '', + "description": parameter.llm_description or "", } if len(enum) > 0: - prompt_tool.parameters['properties'][parameter.name]['enum'] = enum + prompt_tool.parameters["properties"][parameter.name]["enum"] = enum if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.name not in prompt_tool.parameters["required"]: + prompt_tool.parameters["required"].append(parameter.name) return prompt_tool - - def create_agent_thought(self, message_id: str, message: str, - tool_name: str, tool_input: str, messages_ids: list[str] - ) -> MessageAgentThought: + + def create_agent_thought( + self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] + ) -> MessageAgentThought: """ Create agent thought """ thought = MessageAgentThought( message_id=message_id, message_chain_id=None, - thought='', + thought="", tool=tool_name, - tool_labels_str='{}', - tool_meta_str='{}', + tool_labels_str="{}", + tool_meta_str="{}", tool_input=tool_input, message=message, message_token=0, message_unit_price=0, message_price_unit=0, - message_files=json.dumps(messages_ids) if messages_ids else '', - answer='', - observation='', + message_files=json.dumps(messages_ids) if messages_ids else "", + answer="", + observation="", answer_token=0, answer_unit_price=0, answer_price_unit=0, tokens=0, total_price=0, position=self.agent_thought_count + 1, - currency='USD', + currency="USD", latency=0, - created_by_role='account', + created_by_role="account", created_by=self.user_id, 
) @@ -318,22 +323,22 @@ def create_agent_thought(self, message_id: str, message: str, return thought - def save_agent_thought(self, - agent_thought: MessageAgentThought, - tool_name: str, - tool_input: Union[str, dict], - thought: str, - observation: Union[str, dict], - tool_invoke_meta: Union[str, dict], - answer: str, - messages_ids: list[str], - llm_usage: LLMUsage = None) -> MessageAgentThought: + def save_agent_thought( + self, + agent_thought: MessageAgentThought, + tool_name: str, + tool_input: Union[str, dict], + thought: str, + observation: Union[str, dict], + tool_invoke_meta: Union[str, dict], + answer: str, + messages_ids: list[str], + llm_usage: LLMUsage = None, + ) -> MessageAgentThought: """ Save agent thought """ - agent_thought = db.session.query(MessageAgentThought).filter( - MessageAgentThought.id == agent_thought.id - ).first() + agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() if thought is not None: agent_thought.thought = thought @@ -356,7 +361,7 @@ def save_agent_thought(self, observation = json.dumps(observation, ensure_ascii=False) except Exception as e: observation = json.dumps(observation) - + agent_thought.observation = observation if answer is not None: @@ -364,7 +369,7 @@ def save_agent_thought(self, if messages_ids is not None and len(messages_ids) > 0: agent_thought.message_files = json.dumps(messages_ids) - + if llm_usage: agent_thought.message_token = llm_usage.prompt_tokens agent_thought.message_price_unit = llm_usage.prompt_price_unit @@ -377,7 +382,7 @@ def save_agent_thought(self, # check if tool labels is not empty labels = agent_thought.tool_labels or {} - tools = agent_thought.tool.split(';') if agent_thought.tool else [] + tools = agent_thought.tool.split(";") if agent_thought.tool else [] for tool in tools: if not tool: continue @@ -386,7 +391,7 @@ def save_agent_thought(self, if tool_label: labels[tool] = tool_label.to_dict() else: - labels[tool] = {'en_US': tool, 'zh_Hans': tool} + labels[tool] = {"en_US": tool, "zh_Hans": tool} agent_thought.tool_labels_str = json.dumps(labels) @@ -401,14 +406,18 @@ def save_agent_thought(self, db.session.commit() db.session.close() - + def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): """ convert tool variables to db variables """ - db_variables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == self.message.conversation_id, - ).first() + db_variables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == self.message.conversation_id, + ) + .first() + ) db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) @@ -425,9 +434,16 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P if isinstance(prompt_message, SystemPromptMessage): result.append(prompt_message) - messages: list[Message] = db.session.query(Message).filter( - Message.conversation_id == self.message.conversation_id, - ).order_by(Message.created_at.asc()).all() + messages: list[Message] = ( + db.session.query(Message) + .filter( + Message.conversation_id == self.message.conversation_id, + ) + .order_by(Message.created_at.desc()) + .all() + ) + + messages = list(reversed(extract_thread_messages(messages))) for message in messages: if message.id == self.message.id: @@ -439,42 +455,48 @@ def 
organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P for agent_thought in agent_thoughts: tools = agent_thought.tool if tools: - tools = tools.split(';') + tools = tools.split(";") tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_call_response: list[ToolPromptMessage] = [] try: tool_inputs = json.loads(agent_thought.tool_input) except Exception as e: - tool_inputs = { tool: {} for tool in tools } + tool_inputs = {tool: {} for tool in tools} try: tool_responses = json.loads(agent_thought.observation) except Exception as e: - tool_responses = { tool: agent_thought.observation for tool in tools } + tool_responses = dict.fromkeys(tools, agent_thought.observation) for tool in tools: # generate a uuid for tool call tool_call_id = str(uuid.uuid4()) - tool_calls.append(AssistantPromptMessage.ToolCall( - id=tool_call_id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( + tool_calls.append( + AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool, + arguments=json.dumps(tool_inputs.get(tool, {})), + ), + ) + ) + tool_call_response.append( + ToolPromptMessage( + content=tool_responses.get(tool, agent_thought.observation), name=tool, - arguments=json.dumps(tool_inputs.get(tool, {})), + tool_call_id=tool_call_id, ) - )) - tool_call_response.append(ToolPromptMessage( - content=tool_responses.get(tool, agent_thought.observation), - name=tool, - tool_call_id=tool_call_id, - )) - - result.extend([ - AssistantPromptMessage( - content=agent_thought.thought, - tool_calls=tool_calls, - ), - *tool_call_response - ]) + ) + + result.extend( + [ + AssistantPromptMessage( + content=agent_thought.thought, + tool_calls=tool_calls, + ), + *tool_call_response, + ] + ) if not tools: result.append(AssistantPromptMessage(content=agent_thought.thought)) else: @@ -486,30 +508,28 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P return result def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: - message_file_parser = MessageFileParser( - tenant_id=self.tenant_id, - app_id=self.app_config.app_id, - ) - - files = message.message_files - if files: - file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() + if not files: + return UserPromptMessage(content=message.query) + file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + if not file_extra_config: + return UserPromptMessage(content=message.query) - if file_extra_config: - file_objs = message_file_parser.transform_message_files( - files, - file_extra_config - ) - else: - file_objs = [] + image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - if not file_objs: - return UserPromptMessage(content=message.query) - else: - prompt_message_contents = [TextPromptMessageContent(data=message.query)] - for file_obj in file_objs: - prompt_message_contents.append(file_obj.prompt_message_content) - - return UserPromptMessage(content=prompt_message_contents) - else: + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=self.tenant_id, config=file_extra_config + ) + if not file_objs: return UserPromptMessage(content=message.query) + prompt_message_contents: 
list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + for file in file_objs: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + return UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 89c948d2e29f5f..d98ba5a3fad846 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -25,17 +25,19 @@ class CotAgentRunner(BaseAgentRunner, ABC): _is_first_iteration = True - _ignore_observation_providers = ['wenxin'] + _ignore_observation_providers = ["wenxin"] _historic_prompt_messages: list[PromptMessage] = None _agent_scratchpad: list[AgentScratchpadUnit] = None _instruction: str = None _query: str = None _prompt_messages_tools: list[PromptMessage] = None - def run(self, message: Message, - query: str, - inputs: dict[str, str], - ) -> Union[Generator, LLMResult]: + def run( + self, + message: Message, + query: str, + inputs: dict[str, str], + ) -> Union[Generator, LLMResult]: """ Run Cot agent application """ @@ -46,17 +48,16 @@ def run(self, message: Message, trace_manager = app_generate_entity.trace_manager # check model mode - if 'Observation' not in app_generate_entity.model_conf.stop: + if "Observation" not in app_generate_entity.model_conf.stop: if app_generate_entity.model_conf.provider not in self._ignore_observation_providers: - app_generate_entity.model_conf.stop.append('Observation') + app_generate_entity.model_conf.stop.append("Observation") app_config = self.app_config # init instruction inputs = inputs or {} instruction = app_config.prompt_template.simple_prompt_template - self._instruction = self._fill_in_inputs_from_external_data_tools( - instruction, inputs) + self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 @@ -65,16 +66,14 @@ def run(self, message: Message, tool_instances, self._prompt_messages_tools = self._init_prompt_tools() function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens llm_usage.prompt_price += usage.prompt_price @@ -94,17 +93,13 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) if iteration_step > 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # recalc llm max tokens prompt_messages = self._organize_prompt_messages() @@ -125,21 +120,20 @@ def increase_usage(final_llm_usage_dict: 
dict[str, LLMUsage], usage: LLMUsage): raise ValueError("failed to invoke llm") usage_dict = {} - react_chunks = CotAgentOutputParser.handle_react_stream_output( - chunks, usage_dict) + react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict) scratchpad = AgentScratchpadUnit( - agent_response='', - thought='', - action_str='', - observation='', + agent_response="", + thought="", + action_str="", + observation="", action=None, ) # publish agent thought if it's first iteration if iteration_step == 1: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) for chunk in react_chunks: if isinstance(chunk, AgentScratchpadUnit.Action): @@ -154,61 +148,51 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): yield LLMResultChunk( model=self.model_config.model, prompt_messages=prompt_messages, - system_fingerprint='', - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=chunk - ), - usage=None - ) + system_fingerprint="", + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None), ) - scratchpad.thought = scratchpad.thought.strip( - ) or 'I am thinking about how to help you' + scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you" self._agent_scratchpad.append(scratchpad) # get llm usage - if 'usage' in usage_dict: - increase_usage(llm_usage, usage_dict['usage']) + if "usage" in usage_dict: + increase_usage(llm_usage, usage_dict["usage"]) else: - usage_dict['usage'] = LLMUsage.empty_usage() + usage_dict["usage"] = LLMUsage.empty_usage() self.save_agent_thought( agent_thought=agent_thought, - tool_name=scratchpad.action.action_name if scratchpad.action else '', - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input - } if scratchpad.action else {}, + tool_name=scratchpad.action.action_name if scratchpad.action else "", + tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_invoke_meta={}, thought=scratchpad.thought, - observation='', + observation="", answer=scratchpad.agent_response, messages_ids=[], - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) if not scratchpad.is_final(): - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) if not scratchpad.action: # failed to extract action, return final answer directly - final_answer = '' + final_answer = "" else: if scratchpad.action.action_name.lower() == "final answer": # action is final answer, return final answer directly try: if isinstance(scratchpad.action.action_input, dict): - final_answer = json.dumps( - scratchpad.action.action_input) + final_answer = json.dumps(scratchpad.action.action_input) elif isinstance(scratchpad.action.action_input, str): final_answer = scratchpad.action.action_input else: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" except json.JSONDecodeError: - final_answer = f'{scratchpad.action.action_input}' + final_answer = f"{scratchpad.action.action_input}" else: function_call_state = True # action is tool call, invoke tool @@ 
-224,21 +208,18 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.save_agent_thought( agent_thought=agent_thought, tool_name=scratchpad.action.action_name, - tool_input={ - scratchpad.action.action_name: scratchpad.action.action_input}, + tool_input={scratchpad.action.action_name: scratchpad.action.action_input}, thought=scratchpad.thought, - observation={ - scratchpad.action.action_name: tool_invoke_response}, - tool_invoke_meta={ - scratchpad.action.action_name: tool_invoke_meta.to_dict()}, + observation={scratchpad.action.action_name: tool_invoke_response}, + tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()}, answer=scratchpad.agent_response, messages_ids=message_file_ids, - llm_usage=usage_dict['usage'] + llm_usage=usage_dict["usage"], ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) # update prompt tool message for prompt_tool in self._prompt_messages_tools: @@ -250,44 +231,45 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): model=model_instance.model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=final_answer - ), - usage=llm_usage['usage'] + index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] ), - system_fingerprint='' + system_fingerprint="", ) # save agent thought self.save_agent_thought( agent_thought=agent_thought, - tool_name='', + tool_name="", tool_input={}, tool_invoke_meta={}, thought=final_answer, observation={}, answer=final_answer, - messages_ids=[] + messages_ids=[], ) self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] or LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) - - def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, - tool_instances: dict[str, Tool], - message_file_ids: list[str], - trace_manager: Optional[TraceQueueManager] = None - ) -> tuple[str, ToolInvokeMeta]: + PublishFrom.APPLICATION_MANAGER, + ) + + def _handle_invoke_action( + self, + action: AgentScratchpadUnit.Action, + tool_instances: dict[str, Tool], + message_file_ids: list[str], + trace_manager: Optional[TraceQueueManager] = None, + ) -> tuple[str, ToolInvokeMeta]: """ handle invoke action :param action: action @@ -326,13 +308,12 @@ def _handle_invoke_action(self, action: AgentScratchpadUnit.Action, # publish files for message_file_id, save_as in message_files: if save_as: - self.variables_pool.set_file( - tool_name=tool_call_name, value=message_file_id, name=save_as) + self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), 
PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) @@ -342,10 +323,7 @@ def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action: """ convert dict to action """ - return AgentScratchpadUnit.Action( - action_name=action['action'], - action_input=action['action_input'] - ) + return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"]) def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: """ @@ -353,7 +331,7 @@ def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dic """ for key, value in inputs.items(): try: - instruction = instruction.replace(f'{{{{{key}}}}}', str(value)) + instruction = instruction.replace(f"{{{{{key}}}}}", str(value)) except Exception as e: continue @@ -370,14 +348,14 @@ def _init_react_state(self, query) -> None: @abstractmethod def _organize_prompt_messages(self) -> list[PromptMessage]: """ - organize prompt messages + organize prompt messages """ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: """ - format assistant message + format assistant message """ - message = '' + message = "" for scratchpad in agent_scratchpad: if scratchpad.is_final(): message += f"Final Answer: {scratchpad.agent_response}" @@ -390,9 +368,11 @@ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) return message - def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_historic_prompt_messages( + self, current_session_messages: Optional[list[PromptMessage]] = None + ) -> list[PromptMessage]: """ - organize historic prompt messages + organize historic prompt messages """ result: list[PromptMessage] = [] scratchpads: list[AgentScratchpadUnit] = [] @@ -403,8 +383,8 @@ def _organize_historic_prompt_messages(self, current_session_messages: list[Prom if not current_scratchpad: current_scratchpad = AgentScratchpadUnit( agent_response=message.content, - thought=message.content or 'I am thinking about how to help you', - action_str='', + thought=message.content or "I am thinking about how to help you", + action_str="", action=None, observation=None, ) @@ -413,12 +393,9 @@ def _organize_historic_prompt_messages(self, current_session_messages: list[Prom try: current_scratchpad.action = AgentScratchpadUnit.Action( action_name=message.tool_calls[0].function.name, - action_input=json.loads( - message.tool_calls[0].function.arguments) - ) - current_scratchpad.action_str = json.dumps( - current_scratchpad.action.to_dict() + action_input=json.loads(message.tool_calls[0].function.arguments), ) + current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) except: pass elif isinstance(message, ToolPromptMessage): @@ -426,23 +403,19 @@ def _organize_historic_prompt_messages(self, current_session_messages: list[Prom current_scratchpad.observation = message.content elif isinstance(message, UserPromptMessage): if scratchpads: - result.append(AssistantPromptMessage( - content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) scratchpads = [] current_scratchpad = None result.append(message) if scratchpads: - result.append(AssistantPromptMessage( - 
content=self._format_assistant_message(scratchpads) - )) + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) historic_prompts = AgentHistoryPromptTransform( model_config=self.model_config, prompt_messages=current_session_messages or [], history_messages=result, - memory=self.memory + memory=self.memory, ).get_prompt() return historic_prompts diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 8debbe5c5ddf0f..d8d047fe91cdbd 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,13 +1,16 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import ( +from core.file import file_manager +from core.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.utils.encoders import jsonable_encoder @@ -19,21 +22,39 @@ def _organize_system_prompt(self) -> SystemPromptMessage: prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt - system_prompt = first_prompt \ - .replace("{{instruction}}", self._instruction) \ - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ - .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) return SystemPromptMessage(content=system_prompt) - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ if self.files: - prompt_message_contents = [TextPromptMessageContent(data=query)] - for file_obj in self.files: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=query)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: @@ -43,7 +64,7 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = No def _organize_prompt_messages(self) -> list[PromptMessage]: """ - Organize + Organize """ # organize system prompt system_message = self._organize_system_prompt() @@ -53,7 +74,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: if not agent_scratchpad: assistant_messages = [] else: - assistant_message = AssistantPromptMessage(content='') + assistant_message = AssistantPromptMessage(content="") for unit in agent_scratchpad: if 
unit.is_final(): assistant_message.content += f"Final Answer: {unit.agent_response}" @@ -71,18 +92,15 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: if assistant_messages: # organize historic prompt messages - historic_messages = self._organize_historic_prompt_messages([ - system_message, - *query_messages, - *assistant_messages, - UserPromptMessage(content='continue') - ]) + historic_messages = self._organize_historic_prompt_messages( + [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")] + ) messages = [ system_message, *historic_messages, *query_messages, *assistant_messages, - UserPromptMessage(content='continue') + UserPromptMessage(content="continue"), ] else: # organize historic prompt messages diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 9e6eb54f4fe513..0563090537e62c 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,4 +1,5 @@ import json +from typing import Optional from core.agent.cot_agent_runner import CotAgentRunner from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage @@ -13,13 +14,15 @@ def _organize_instruction_prompt(self) -> str: prompt_entity = self.app_config.agent.prompt first_prompt = prompt_entity.first_prompt - system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \ - .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \ - .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools])) - + system_prompt = ( + first_prompt.replace("{{instruction}}", self._instruction) + .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) + .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools])) + ) + return system_prompt - def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str: + def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str: """ Organize historic prompt """ @@ -46,7 +49,7 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: # organize current assistant messages agent_scratchpad = self._agent_scratchpad - assistant_prompt = '' + assistant_prompt = "" for unit in agent_scratchpad: if unit.is_final(): assistant_prompt += f"Final Answer: {unit.agent_response}" @@ -61,9 +64,10 @@ def _organize_prompt_messages(self) -> list[PromptMessage]: query_prompt = f"Question: {self._query}" # join all messages - prompt = system_prompt \ - .replace("{{historic_messages}}", historic_prompt) \ - .replace("{{agent_scratchpad}}", assistant_prompt) \ + prompt = ( + system_prompt.replace("{{historic_messages}}", historic_prompt) + .replace("{{agent_scratchpad}}", assistant_prompt) .replace("{{query}}", query_prompt) + ) - return [UserPromptMessage(content=prompt)] \ No newline at end of file + return [UserPromptMessage(content=prompt)] diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 5274224de5772c..119a88fc7becbf 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel): """ Agent Tool Entity. """ + provider_type: Literal["builtin", "api", "workflow"] provider_id: str tool_name: str @@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel): """ Agent Prompt Entity. 
""" + first_prompt: str next_iteration: str @@ -31,6 +33,7 @@ class Action(BaseModel): """ Action Entity. """ + action_name: str action_input: Union[dict, str] @@ -39,8 +42,8 @@ def to_dict(self) -> dict: Convert to dictionary. """ return { - 'action': self.action_name, - 'action_input': self.action_input, + "action": self.action_name, + "action_input": self.action_input, } agent_response: Optional[str] = None @@ -54,10 +57,10 @@ def is_final(self) -> bool: Check if the scratchpad unit is final. """ return self.action is None or ( - 'final' in self.action.action_name.lower() and - 'answer' in self.action.action_name.lower() + "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower() ) + class AgentEntity(BaseModel): """ Agent Entity. @@ -67,8 +70,9 @@ class Strategy(Enum): """ Agent Strategy. """ - CHAIN_OF_THOUGHT = 'chain-of-thought' - FUNCTION_CALLING = 'function-calling' + + CHAIN_OF_THOUGHT = "chain-of-thought" + FUNCTION_CALLING = "function-calling" provider: str model: str diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 3ee6e47742a18f..cd546dee124147 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -2,21 +2,27 @@ import logging from collections.abc import Generator from copy import deepcopy -from typing import Any, Union +from typing import Any, Optional, Union from core.agent.base_agent_runner import BaseAgentRunner from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.file import file_manager +from core.model_runtime.entities import ( AssistantPromptMessage, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, PromptMessage, + PromptMessageContent, PromptMessageContentType, SystemPromptMessage, TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine @@ -24,11 +30,9 @@ logger = logging.getLogger(__name__) -class FunctionCallAgentRunner(BaseAgentRunner): - def run(self, - message: Message, query: str, **kwargs: Any - ) -> Generator[LLMResultChunk, None, None]: +class FunctionCallAgentRunner(BaseAgentRunner): + def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]: """ Run FunctionCall agent application """ @@ -45,19 +49,17 @@ def run(self, # continue to run until there is not any tool call function_call_state = True - llm_usage = { - 'usage': None - } - final_answer = '' + llm_usage = {"usage": None} + final_answer = "" # get tracing instance trace_manager = app_generate_entity.trace_manager - + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): - if not final_llm_usage_dict['usage']: - final_llm_usage_dict['usage'] = usage + if not final_llm_usage_dict["usage"]: + final_llm_usage_dict["usage"] = usage else: - llm_usage = final_llm_usage_dict['usage'] + llm_usage = final_llm_usage_dict["usage"] llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.completion_tokens += usage.completion_tokens 
llm_usage.prompt_price += usage.prompt_price @@ -75,11 +77,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): message_file_ids = [] agent_thought = self.create_agent_thought( - message_id=message.id, - message='', - tool_name='', - tool_input='', - messages_ids=message_file_ids + message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids ) # recalc llm max tokens @@ -99,11 +97,11 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_calls: list[tuple[str, str, dict[str, Any]]] = [] # save full response - response = '' + response = "" # save tool call names and inputs - tool_call_names = '' - tool_call_inputs = '' + tool_call_names = "" + tool_call_inputs = "" current_llm_usage = None @@ -111,24 +109,22 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): is_first_chunk = True for chunk in chunks: if is_first_chunk: - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) is_first_chunk = False # check if there is any tool call if self.check_tool_calls(chunk): function_call_state = True tool_calls.extend(self.extract_tool_calls(chunk)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if chunk.delta.message and chunk.delta.message.content: if isinstance(chunk.delta.message.content, list): @@ -148,16 +144,14 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if self.check_blocking_tool_calls(result): function_call_state = True tool_calls.extend(self.extract_blocking_tool_calls(result)) - tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls]) try: - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }, ensure_ascii=False) + tool_call_inputs = json.dumps( + {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False + ) except json.JSONDecodeError as e: # ensure ascii to avoid encoding error - tool_call_inputs = json.dumps({ - tool_call[1]: tool_call[2] for tool_call in tool_calls - }) + tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}) if result.usage: increase_usage(llm_usage, result.usage) @@ -171,12 +165,12 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): response += result.message.content if not result.message.content: - result.message.content = '' + result.message.content = "" + + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - yield 
LLMResultChunk( model=model_instance.model, prompt_messages=result.prompt_messages, @@ -185,32 +179,29 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): index=0, message=result.message, usage=result.usage, - ) + ), ) - assistant_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) if tool_calls: - assistant_message.tool_calls=[ + assistant_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=tool_call[0], - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call[1], - arguments=json.dumps(tool_call[2], ensure_ascii=False) - ) - ) for tool_call in tool_calls + name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False) + ), + ) + for tool_call in tool_calls ] else: assistant_message.content = response - + self._current_thoughts.append(assistant_message) # save thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=tool_call_names, tool_input=tool_call_inputs, thought=response, @@ -218,13 +209,13 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): observation=None, answer=response, messages_ids=[], - llm_usage=current_llm_usage + llm_usage=current_llm_usage, ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) - - final_answer += response + '\n' + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER + ) + + final_answer += response + "\n" # call tools tool_responses = [] @@ -235,7 +226,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": f"there is not a tool named {tool_call_name}", - "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict() + "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), } else: # invoke tool @@ -255,50 +246,49 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) # publish message file - self.queue_manager.publish(QueueMessageFileEvent( - message_file_id=message_file_id - ), PublishFrom.APPLICATION_MANAGER) + self.queue_manager.publish( + QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER + ) # add message file ids message_file_ids.append(message_file_id) - + tool_response = { "tool_call_id": tool_call_id, "tool_call_name": tool_call_name, "tool_response": tool_invoke_response, - "meta": tool_invoke_meta.to_dict() + "meta": tool_invoke_meta.to_dict(), } - + tool_responses.append(tool_response) - if tool_response['tool_response'] is not None: + if tool_response["tool_response"] is not None: self._current_thoughts.append( ToolPromptMessage( - content=tool_response['tool_response'], + content=tool_response["tool_response"], tool_call_id=tool_call_id, name=tool_call_name, ) - ) + ) if len(tool_responses) > 0: # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=None, tool_input=None, - thought=None, + thought=None, tool_invoke_meta={ - tool_response['tool_call_name']: tool_response['meta'] - for tool_response in tool_responses + 
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses }, observation={ - tool_response['tool_call_name']: tool_response['tool_response'] + tool_response["tool_call_name"]: tool_response["tool_response"] for tool_response in tool_responses }, answer=None, - messages_ids=message_file_ids + messages_ids=message_file_ids, + ) + self.queue_manager.publish( + QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER ) - self.queue_manager.publish(QueueAgentThoughtEvent( - agent_thought_id=agent_thought.id - ), PublishFrom.APPLICATION_MANAGER) # update prompt tool for prompt_tool in prompt_messages_tools: @@ -308,15 +298,18 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): self.update_db_variables(self.variables_pool, self.db_variables_pool) # publish end event - self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult( - model=model_instance.model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=final_answer + self.queue_manager.publish( + QueueMessageEndEvent( + llm_result=LLMResult( + model=model_instance.model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage(content=final_answer), + usage=llm_usage["usage"] or LLMUsage.empty_usage(), + system_fingerprint="", + ) ), - usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(), - system_fingerprint='' - )), PublishFrom.APPLICATION_MANAGER) + PublishFrom.APPLICATION_MANAGER, + ) def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: """ @@ -325,7 +318,7 @@ def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: if llm_result_chunk.delta.message.tool_calls: return True return False - + def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: """ Check if there is any blocking tool call in llm result @@ -334,7 +327,9 @@ def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: return True return False - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: + def extract_tool_calls( + self, llm_result_chunk: LLMResultChunk + ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract tool calls from llm result chunk @@ -344,17 +339,19 @@ def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, li tool_calls = [] for prompt_message in llm_result_chunk.delta.message.tool_calls: args = {} - if prompt_message.function.arguments != '': + if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append(( - prompt_message.id, - prompt_message.function.name, - args, - )) + tool_calls.append( + ( + prompt_message.id, + prompt_message.function.name, + args, + ) + ) return tool_calls - + def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract blocking tool calls from llm result @@ -365,18 +362,22 @@ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list tool_calls = [] for prompt_message in llm_result.message.tool_calls: args = {} - if prompt_message.function.arguments != '': + if prompt_message.function.arguments != "": args = json.loads(prompt_message.function.arguments) - tool_calls.append(( - prompt_message.id, - prompt_message.function.name, - args, - )) + tool_calls.append( + ( + prompt_message.id, + prompt_message.function.name, + args, + ) + ) return tool_calls - def 
_init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _init_system_message( + self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None + ) -> list[PromptMessage]: """ Initialize system message """ @@ -384,27 +385,44 @@ def _init_system_message(self, prompt_template: str, prompt_messages: list[Promp return [ SystemPromptMessage(content=prompt_template), ] - + if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) return prompt_messages - def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: + def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ Organize user query """ if self.files: - prompt_message_contents = [TextPromptMessageContent(data=query)] - for file_obj in self.files: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=query)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query)) return prompt_messages - + def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ As for now, gpt supports both fc and vision at the first iteration. 
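For context on the hunk below: _clear_user_prompt_image_messages collapses multimodal user content back to plain text after the first iteration, since (as the docstring above notes) vision input is only honored on the first function-calling round. A minimal sketch of that collapse logic, using simplified stand-in types rather than Dify's actual PromptMessage/PromptMessageContentType classes:

from dataclasses import dataclass
from typing import Union

@dataclass
class Content:
    type: str  # "text", "image", or "file" -- stand-in for PromptMessageContentType
    data: str = ""

def collapse_to_text(content: Union[str, list[Content]]) -> str:
    # Plain-string content is already text; lists of parts are flattened,
    # with non-text parts replaced by the same placeholders the hunk uses.
    if isinstance(content, str):
        return content
    return "\n".join(
        part.data if part.type == "text" else "[image]" if part.type == "image" else "[file]"
        for part in content
    )

# collapse_to_text([Content("text", "describe this"), Content("image")])
# -> "describe this\n[image]"

The effect is that once the tool-call loop has started, only a textual marker for each image or file is kept in the prompt history.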
@@ -415,17 +433,21 @@ def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage] for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - content.data if content.type == PromptMessageContentType.TEXT else - '[image]' if content.type == PromptMessageContentType.IMAGE else - '[file]' - for content in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + content.data + if content.type == PromptMessageContentType.TEXT + else "[image]" + if content.type == PromptMessageContentType.IMAGE + else "[file]" + for content in prompt_message.content + ] + ) return prompt_messages def _organize_prompt_messages(self): - prompt_template = self.app_config.prompt_template.simple_prompt_template or '' + prompt_template = self.app_config.prompt_template.simple_prompt_template or "" self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) query_prompt_messages = self._organize_user_query(self.query, []) @@ -433,14 +455,10 @@ def _organize_prompt_messages(self): model_config=self.model_config, prompt_messages=[*query_prompt_messages, *self._current_thoughts], history_messages=self.history_prompt_messages, - memory=self.memory + memory=self.memory, ).get_prompt() - prompt_messages = [ - *self.history_prompt_messages, - *query_prompt_messages, - *self._current_thoughts - ] + prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] if len(self._current_thoughts) != 0: # clear messages after the first iteration prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index c53fa5000e9958..085bac8601b2da 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -9,11 +9,12 @@ class CotAgentOutputParser: @classmethod - def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \ - Generator[Union[str, AgentScratchpadUnit.Action], None, None]: + def handle_react_stream_output( + cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict + ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: def parse_action(json_str): try: - action = json.loads(json_str) + action = json.loads(json_str, strict=False) action_name = None action_input = None @@ -22,7 +23,7 @@ def parse_action(json_str): action = action[0] for key, value in action.items(): - if 'input' in key.lower(): + if "input" in key.lower(): action_input = value else: action_name = value @@ -33,37 +34,39 @@ def parse_action(json_str): action_input=action_input, ) else: - return json_str or '' + return json_str or "" except: - return json_str or '' - + return json_str or "" + def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]: - code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL) + code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL) if not code_blocks: return for block in code_blocks: - json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE) + json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE) yield parse_action(json_text) - - code_block_cache = '' + + code_block_cache = "" code_block_delimiter_count = 0 in_code_block = False - json_cache = '' + 
json_cache = "" json_quote_count = 0 in_json = False got_json = False - action_cache = '' - action_str = 'action:' + action_cache = "" + action_str = "action:" action_idx = 0 - thought_cache = '' - thought_str = 'thought:' + thought_cache = "" + thought_str = "thought:" thought_idx = 0 + last_character = "" + for response in llm_response: if response.delta.usage: - usage_dict['usage'] = response.delta.usage + usage_dict["usage"] = response.delta.usage response = response.delta.message.content if not isinstance(response, str): continue @@ -72,91 +75,105 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, index = 0 while index < len(response): steps = 1 - delta = response[index:index+steps] - last_character = response[index-1] if index > 0 else '' + delta = response[index : index + steps] + yield_delta = False - if delta == '`': + if delta == "`": + last_character = delta code_block_cache += delta code_block_delimiter_count += 1 else: if not in_code_block: if code_block_delimiter_count > 0: + last_character = delta yield code_block_cache - code_block_cache = '' + code_block_cache = "" else: + last_character = delta code_block_cache += delta code_block_delimiter_count = 0 if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: - if last_character not in ['\n', ' ', '']: + if last_character not in {"\n", " ", ""}: + yield_delta = True + else: + last_character = delta + action_cache += delta + action_idx += 1 + if action_idx == len(action_str): + action_cache = "" + action_idx = 0 index += steps - yield delta continue - - action_cache += delta - action_idx += 1 - if action_idx == len(action_str): - action_cache = '' - action_idx = 0 - index += steps - continue elif delta.lower() == action_str[action_idx] and action_idx > 0: + last_character = delta action_cache += delta action_idx += 1 if action_idx == len(action_str): - action_cache = '' + action_cache = "" action_idx = 0 index += steps continue else: if action_cache: + last_character = delta yield action_cache - action_cache = '' + action_cache = "" action_idx = 0 - + if delta.lower() == thought_str[thought_idx] and thought_idx == 0: - if last_character not in ['\n', ' ', '']: + if last_character not in {"\n", " ", ""}: + yield_delta = True + else: + last_character = delta + thought_cache += delta + thought_idx += 1 + if thought_idx == len(thought_str): + thought_cache = "" + thought_idx = 0 index += steps - yield delta continue - - thought_cache += delta - thought_idx += 1 - if thought_idx == len(thought_str): - thought_cache = '' - thought_idx = 0 - index += steps - continue elif delta.lower() == thought_str[thought_idx] and thought_idx > 0: + last_character = delta thought_cache += delta thought_idx += 1 if thought_idx == len(thought_str): - thought_cache = '' + thought_cache = "" thought_idx = 0 index += steps continue else: if thought_cache: + last_character = delta yield thought_cache - thought_cache = '' + thought_cache = "" thought_idx = 0 + if yield_delta: + index += steps + last_character = delta + yield delta + continue + if code_block_delimiter_count == 3: if in_code_block: + last_character = delta yield from extra_json_from_code_block(code_block_cache) - code_block_cache = '' - + code_block_cache = "" + in_code_block = not in_code_block code_block_delimiter_count = 0 if not in_code_block: # handle single json - if delta == '{': + if delta == "{": json_quote_count += 1 in_json = True + last_character = delta json_cache += delta - elif delta == '}': + elif 
delta == "}": + last_character = delta json_cache += delta if json_quote_count > 0: json_quote_count -= 1 @@ -167,17 +184,20 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, continue else: if in_json: + last_character = delta json_cache += delta if got_json: got_json = False + last_character = delta yield parse_action(json_cache) - json_cache = '' + json_cache = "" json_quote_count = 0 in_json = False - + if not in_code_block and not in_json: - yield delta.replace('`', '') + last_character = delta + yield delta.replace("`", "") index += steps @@ -186,4 +206,3 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, if json_cache: yield parse_action(json_cache) - diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py index b0cf1a77fb1772..ef64fd29fc3a76 100644 --- a/api/core/agent/prompt/template.py +++ b/api/core/agent/prompt/template.py @@ -41,7 +41,8 @@ {{historic_messages}} Question: {{query}} {{agent_scratchpad}} -Thought:""" +Thought:""" # noqa: E501 + ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} Thought:""" @@ -86,19 +87,20 @@ ``` Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. -""" +""" # noqa: E501 + ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = "" REACT_PROMPT_TEMPLATES = { - 'english': { - 'chat': { - 'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES + "english": { + "chat": { + "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES, + }, + "completion": { + "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, + "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES, }, - 'completion': { - 'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES, - 'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES - } } -} \ No newline at end of file +} diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index 3dea305e984143..24d80f9cdd77f7 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -26,34 +26,24 @@ def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> config_dict = dict(config_dict.items()) additional_features = AppAdditionalFeatures() - additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert( - config=config_dict - ) + additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) additional_features.file_upload = FileUploadConfigManager.convert( - config=config_dict, - is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] + config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT} ) - additional_features.opening_statement, additional_features.suggested_questions = \ - OpeningStatementConfigManager.convert( - config=config_dict - ) + additional_features.opening_statement, additional_features.suggested_questions = ( + OpeningStatementConfigManager.convert(config=config_dict) + ) additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert( config=config_dict ) - additional_features.more_like_this = MoreLikeThisConfigManager.convert( - 
config=config_dict - ) + additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict) - additional_features.speech_to_text = SpeechToTextConfigManager.convert( - config=config_dict - ) + additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict) - additional_features.text_to_speech = TextToSpeechConfigManager.convert( - config=config_dict - ) + additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict) return additional_features diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 1ca8b1e3b8fed2..037037e6ca1cf0 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -7,25 +7,24 @@ class SensitiveWordAvoidanceConfigManager: @classmethod def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]: - sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance') + sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None - if sensitive_word_avoidance_dict.get('enabled'): + if sensitive_word_avoidance_dict.get("enabled"): return SensitiveWordAvoidanceEntity( - type=sensitive_word_avoidance_dict.get('type'), - config=sensitive_word_avoidance_dict.get('config'), + type=sensitive_word_avoidance_dict.get("type"), + config=sensitive_word_avoidance_dict.get("config"), ) else: return None @classmethod - def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \ - -> tuple[dict, list[str]]: + def validate_and_set_defaults( + cls, tenant_id, config: dict, only_structure_validate: bool = False + ) -> tuple[dict, list[str]]: if not config.get("sensitive_word_avoidance"): - config["sensitive_word_avoidance"] = { - "enabled": False - } + config["sensitive_word_avoidance"] = {"enabled": False} if not isinstance(config["sensitive_word_avoidance"], dict): raise ValueError("sensitive_word_avoidance must be of dict type") @@ -41,10 +40,6 @@ def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_valid typ = config["sensitive_word_avoidance"]["type"] sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] - ModerationFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config - ) + ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config) return config, ["sensitive_word_avoidance"] diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index dc65d4439b6e03..f503543d7bd0f5 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -12,67 +12,70 @@ def convert(cls, config: dict) -> Optional[AgentEntity]: :param config: model config args """ - if 'agent_mode' in config and config['agent_mode'] \ - and 'enabled' in config['agent_mode']: + if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]: + agent_dict = config.get("agent_mode", {}) + agent_strategy = agent_dict.get("strategy", "cot") - agent_dict = config.get('agent_mode', {}) - agent_strategy = agent_dict.get('strategy', 'cot') - - if agent_strategy == 'function_call': + if agent_strategy == "function_call": strategy = 
AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == 'cot' or agent_strategy == 'react': + elif agent_strategy in {"cot", "react"}: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT else: # old configs, try to detect default strategy - if config['model']['provider'] == 'openai': + if config["model"]["provider"] == "openai": strategy = AgentEntity.Strategy.FUNCTION_CALLING else: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT agent_tools = [] - for tool in agent_dict.get('tools', []): + for tool in agent_dict.get("tools", []): keys = tool.keys() if len(keys) >= 4: if "enabled" not in tool or not tool["enabled"]: continue agent_tool_properties = { - 'provider_type': tool['provider_type'], - 'provider_id': tool['provider_id'], - 'tool_name': tool['tool_name'], - 'tool_parameters': tool.get('tool_parameters', {}) + "provider_type": tool["provider_type"], + "provider_id": tool["provider_id"], + "tool_name": tool["tool_name"], + "tool_parameters": tool.get("tool_parameters", {}), } agent_tools.append(AgentToolEntity(**agent_tool_properties)) - if 'strategy' in config['agent_mode'] and \ - config['agent_mode']['strategy'] not in ['react_router', 'router']: - agent_prompt = agent_dict.get('prompt', None) or {} + if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { + "react_router", + "router", + }: + agent_prompt = agent_dict.get("prompt", None) or {} # check model mode - model_mode = config.get('model', {}).get('mode', 'completion') - if model_mode == 'completion': + model_mode = config.get("model", {}).get("mode", "completion") + if model_mode == "completion": agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['completion']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['completion'][ - 'agent_scratchpad']), + first_prompt=agent_prompt.get( + "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"] + ), + next_iteration=agent_prompt.get( + "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"] + ), ) else: agent_prompt_entity = AgentPromptEntity( - first_prompt=agent_prompt.get('first_prompt', - REACT_PROMPT_TEMPLATES['english']['chat']['prompt']), - next_iteration=agent_prompt.get('next_iteration', - REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']), + first_prompt=agent_prompt.get( + "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"] + ), + next_iteration=agent_prompt.get( + "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"] + ), ) return AgentEntity( - provider=config['model']['provider'], - model=config['model']['name'], + provider=config["model"]["provider"], + model=config["model"]["name"], strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get('max_iteration', 5) + max_iteration=agent_dict.get("max_iteration", 5), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index f4e6675bd44435..a22395b8e39a03 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -15,39 +15,38 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: :param config: model config args """ dataset_ids = [] - if 'datasets' in config.get('dataset_configs', {}): - datasets = config.get('dataset_configs', 
{}).get('datasets', { - 'strategy': 'router', - 'datasets': [] - }) + if "datasets" in config.get("dataset_configs", {}): + datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []}) - for dataset in datasets.get('datasets', []): + for dataset in datasets.get("datasets", []): keys = list(dataset.keys()) - if len(keys) == 0 or keys[0] != 'dataset': + if len(keys) == 0 or keys[0] != "dataset": continue - dataset = dataset['dataset'] + dataset = dataset["dataset"] - if 'enabled' not in dataset or not dataset['enabled']: + if "enabled" not in dataset or not dataset["enabled"]: continue - dataset_id = dataset.get('id', None) + dataset_id = dataset.get("id", None) if dataset_id: dataset_ids.append(dataset_id) - if 'agent_mode' in config and config['agent_mode'] \ - and 'enabled' in config['agent_mode'] \ - and config['agent_mode']['enabled']: + if ( + "agent_mode" in config + and config["agent_mode"] + and "enabled" in config["agent_mode"] + and config["agent_mode"]["enabled"] + ): + agent_dict = config.get("agent_mode", {}) - agent_dict = config.get('agent_mode', {}) - - for tool in agent_dict.get('tools', []): + for tool in agent_dict.get("tools", []): keys = tool.keys() if len(keys) == 1: # old standard key = list(tool.keys())[0] - if key != 'dataset': + if key != "dataset": continue tool_item = tool[key] @@ -55,30 +54,28 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: if "enabled" not in tool_item or not tool_item["enabled"]: continue - dataset_id = tool_item['id'] + dataset_id = tool_item["id"] dataset_ids.append(dataset_id) if len(dataset_ids) == 0: return None # dataset configs - if 'dataset_configs' in config and config.get('dataset_configs'): - dataset_configs = config.get('dataset_configs') + if "dataset_configs" in config and config.get("dataset_configs"): + dataset_configs = config.get("dataset_configs") else: - dataset_configs = { - 'retrieval_model': 'multiple' - } - query_variable = config.get('dataset_query_variable') + dataset_configs = {"retrieval_model": "multiple"} + query_variable = config.get("dataset_query_variable") - if dataset_configs['retrieval_model'] == 'single': + if dataset_configs["retrieval_model"] == "single": return DatasetEntity( dataset_ids=dataset_ids, retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] - ) - ) + dataset_configs["retrieval_model"] + ), + ), ) else: return DatasetEntity( @@ -86,15 +83,15 @@ def convert(cls, config: dict) -> Optional[DatasetEntity]: retrieve_config=DatasetRetrieveConfigEntity( query_variable=query_variable, retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( - dataset_configs['retrieval_model'] + dataset_configs["retrieval_model"] ), - top_k=dataset_configs.get('top_k', 4), - score_threshold=dataset_configs.get('score_threshold'), - reranking_model=dataset_configs.get('reranking_model'), - weights=dataset_configs.get('weights'), - reranking_enabled=dataset_configs.get('reranking_enabled', True), - rerank_mode=dataset_configs["reranking_mode"], - ) + top_k=dataset_configs.get("top_k", 4), + score_threshold=dataset_configs.get("score_threshold"), + reranking_model=dataset_configs.get("reranking_model"), + weights=dataset_configs.get("weights"), + reranking_enabled=dataset_configs.get("reranking_enabled", True), + rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), + ), ) @classmethod @@ -111,13 +108,10 
10 @@ def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: di # dataset_configs if not config.get("dataset_configs"): - config["dataset_configs"] = {'retrieval_model': 'single'} + config["dataset_configs"] = {"retrieval_model": "single"} if not config["dataset_configs"].get("datasets"): - config["dataset_configs"]["datasets"] = { - "strategy": "router", - "datasets": [] - } + config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") @@ -125,8 +119,9 @@ def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: di if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - need_manual_query_datasets = (config.get("dataset_configs") - and config["dataset_configs"].get("datasets", {}).get("datasets")) + need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( + "datasets", {} + ).get("datasets") if need_manual_query_datasets and app_mode == AppMode.COMPLETION: # Only check when mode is completion @@ -148,10 +143,7 @@ def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mod """ # Extract dataset config for legacy compatibility if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -175,7 +167,7 @@ def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mod config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value has_datasets = False - if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: + if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: for tool in config["agent_mode"]["tools"]: key = list(tool.keys())[0] if key == "dataset": @@ -188,7 +180,7 @@ def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mod if not isinstance(tool_item["enabled"], bool): raise ValueError("enabled in agent_mode.tools must be of boolean type") - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try:
:param app_config: app config @@ -25,9 +23,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, provider_manager = ProviderManager() provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=app_config.tenant_id, - provider=model_config.provider, - model_type=ModelType.LLM + tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) provider_name = provider_model_bundle.configuration.provider.provider @@ -38,8 +34,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, # check model credentials model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model_config.model + model_type=ModelType.LLM, model=model_config.model ) if model_credentials is None: @@ -51,8 +46,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, if not skip_check: # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_config.model, - model_type=ModelType.LLM + model=model_config.model, model_type=ModelType.LLM ) if provider_model is None: @@ -69,24 +63,18 @@ def convert(cls, app_config: EasyUIBasedAppConfig, # model config completion_params = model_config.parameters stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = model_config.mode if not model_mode: - mode_enum = model_type_instance.get_model_mode( - model=model_config.model, - credentials=model_credentials - ) + mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials) model_mode = mode_enum.value - model_schema = model_type_instance.get_model_schema( - model_config.model, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) if not skip_check and not model_schema: raise ValueError(f"Model {model_name} not exist.") diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 730a9527cf7315..b5e4554181c06e 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -13,23 +13,23 @@ def convert(cls, config: dict) -> ModelConfigEntity: :param config: model config args """ # model config - model_config = config.get('model') + model_config = config.get("model") if not model_config: raise ValueError("model is required") - completion_params = model_config.get('completion_params') + completion_params = model_config.get("completion_params") stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode - model_mode = model_config.get('mode') + model_mode = model_config.get("mode") return ModelConfigEntity( - provider=config['model']['provider'], - model=config['model']['name'], + provider=config["model"]["provider"], + model=config["model"]["name"], mode=model_mode, parameters=completion_params, stop=stop, @@ -43,7 +43,7 @@ def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, :param tenant_id: tenant id :param config: app model config args """ - if 'model' not in config: + if "model" not in config: raise ValueError("model is required") if not 
isinstance(config["model"], dict): @@ -52,17 +52,16 @@ def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, # model.provider provider_entities = model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] - if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: + if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") # model.name - if 'name' not in config["model"]: + if "name" not in config["model"]: raise ValueError("model.name is required") provider_manager = ProviderManager() models = provider_manager.get_configurations(tenant_id).get_models( - provider=config["model"]["provider"], - model_type=ModelType.LLM + provider=config["model"]["provider"], model_type=ModelType.LLM ) if not models: @@ -80,12 +79,12 @@ def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, # model.mode if model_mode: - config['model']["mode"] = model_mode + config["model"]["mode"] = model_mode else: - config['model']["mode"] = "completion" + config["model"]["mode"] = "completion" # model.completion_params - if 'completion_params' not in config["model"]: + if "completion_params" not in config["model"]: raise ValueError("model.completion_params is required") config["model"]["completion_params"] = cls.validate_model_completion_params( @@ -101,7 +100,7 @@ def validate_model_completion_params(cls, cp: dict) -> dict: raise ValueError("model.completion_params must be of object type") # stop - if 'stop' not in cp: + if "stop" not in cp: cp["stop"] = [] elif not isinstance(cp["stop"], list): raise ValueError("stop in model.completion_params must be of list type") diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 1f410758aa41da..82a0e56ce840e4 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -14,39 +14,33 @@ def convert(cls, config: dict) -> PromptTemplateEntity: if not config.get("prompt_type"): raise ValueError("prompt_type is required") - prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type']) + prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"]) if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: simple_prompt_template = config.get("pre_prompt", "") - return PromptTemplateEntity( - prompt_type=prompt_type, - simple_prompt_template=simple_prompt_template - ) + return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template) else: advanced_chat_prompt_template = None chat_prompt_config = config.get("chat_prompt_config", {}) if chat_prompt_config: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): - chat_prompt_messages.append({ - "text": message["text"], - "role": PromptMessageRole.value_of(message["role"]) - }) + chat_prompt_messages.append( + {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} + ) - advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( - messages=chat_prompt_messages - ) + advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) advanced_completion_prompt_template = None completion_prompt_config = 
config.get("completion_prompt_config", {}) if completion_prompt_config: completion_prompt_template_params = { - 'prompt': completion_prompt_config['prompt']['text'], + "prompt": completion_prompt_config["prompt"]["text"], } - if 'conversation_histories_role' in completion_prompt_config: - completion_prompt_template_params['role_prefix'] = { - 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], - 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] + if "conversation_histories_role" in completion_prompt_config: + completion_prompt_template_params["role_prefix"] = { + "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], + "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], } advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( @@ -56,7 +50,7 @@ def convert(cls, config: dict) -> PromptTemplateEntity: return PromptTemplateEntity( prompt_type=prompt_type, advanced_chat_prompt_template=advanced_chat_prompt_template, - advanced_completion_prompt_template=advanced_completion_prompt_template + advanced_completion_prompt_template=advanced_completion_prompt_template, ) @classmethod @@ -72,7 +66,7 @@ def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dic config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType] - if config['prompt_type'] not in prompt_type_vals: + if config["prompt_type"] not in prompt_type_vals: raise ValueError(f"prompt_type must be in {prompt_type_vals}") # chat_prompt_config @@ -89,27 +83,28 @@ def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dic if not isinstance(config["completion_prompt_config"], dict): raise ValueError("completion_prompt_config must be of object type") - if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value: - if not config['chat_prompt_config'] and not config['completion_prompt_config']: - raise ValueError("chat_prompt_config or completion_prompt_config is required " - "when prompt_type is advanced") + if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value: + if not config["chat_prompt_config"] and not config["completion_prompt_config"]: + raise ValueError( + "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced" + ) model_mode_vals = [mode.value for mode in ModelMode] - if config['model']["mode"] not in model_mode_vals: + if config["model"]["mode"] not in model_mode_vals: raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced") - if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value: - user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] - assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] + if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value: + user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] + assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] if not user_prefix: - config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' + config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human" if not assistant_prefix: - 
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' + config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant" - if config['model']["mode"] == ModelMode.CHAT.value: - prompt_list = config['chat_prompt_config']['prompt'] + if config["model"]["mode"] == ModelMode.CHAT.value: + prompt_list = config["chat_prompt_config"]["prompt"] if len(prompt_list) > 10: raise ValueError("prompt messages must be less than 10") diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 3eb006b46ec57b..2f2445a33639ed 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,6 +1,6 @@ import re -from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType from core.external_data_tool.factory import ExternalDataToolFactory @@ -13,67 +13,55 @@ def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataV :param config: model config args """ external_data_variables = [] - variables = [] + variable_entities = [] # old external_data_tools - external_data_tools = config.get('external_data_tools', []) + external_data_tools = config.get("external_data_tools", []) for external_data_tool in external_data_tools: - if 'enabled' not in external_data_tool or not external_data_tool['enabled']: + if "enabled" not in external_data_tool or not external_data_tool["enabled"]: continue external_data_variables.append( ExternalDataVariableEntity( - variable=external_data_tool['variable'], - type=external_data_tool['type'], - config=external_data_tool['config'] + variable=external_data_tool["variable"], + type=external_data_tool["type"], + config=external_data_tool["config"], ) ) # variables and external_data_tools - for variable in config.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - if 'config' not in val: + for variables in config.get("user_input_form", []): + variable_type = list(variables.keys())[0] + if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: + variable = variables[variable_type] + if "config" not in variable: continue external_data_variables.append( ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] + variable=variable["variable"], type=variable["type"], config=variable["config"] ) ) - elif typ in [ - VariableEntity.Type.TEXT_INPUT.value, - VariableEntity.Type.PARAGRAPH.value, - VariableEntity.Type.NUMBER.value, - ]: - variables.append( + elif variable_type in { + VariableEntityType.TEXT_INPUT, + VariableEntityType.PARAGRAPH, + VariableEntityType.NUMBER, + VariableEntityType.SELECT, + }: + variable = variables[variable_type] + variable_entities.append( VariableEntity( - type=VariableEntity.Type.value_of(typ), - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - max_length=variable[typ].get('max_length'), - default=variable[typ].get('default'), - ) - ) - elif typ == VariableEntity.Type.SELECT.value: - variables.append( - VariableEntity( - type=VariableEntity.Type.SELECT, - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - 
label=variable[typ].get('label'), - required=variable[typ].get('required', False), - options=variable[typ].get('options'), - default=variable[typ].get('default'), + type=variable_type, + variable=variable.get("variable"), + description=variable.get("description") or "", + label=variable.get("label"), + required=variable.get("required", False), + max_length=variable.get("max_length"), + options=variable.get("options") or [], ) ) - return variables, external_data_variables + return variable_entities, external_data_variables @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: @@ -108,17 +96,17 @@ def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[s variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: + if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] - if 'label' not in form_item: + if "label" not in form_item: raise ValueError("label is required in user_input_form") if not isinstance(form_item["label"], str): raise ValueError("label in user_input_form must be of string type") - if 'variable' not in form_item: + if "variable" not in form_item: raise ValueError("variable is required in user_input_form") if not isinstance(form_item["variable"], str): @@ -126,26 +114,24 @@ def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[s pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") if pattern.match(form_item["variable"]) is None: - raise ValueError("variable in user_input_form must be a string, " - "and cannot start with a number") + raise ValueError("variable in user_input_form must be a string, and cannot start with a number") variables.append(form_item["variable"]) - if 'required' not in form_item or not form_item["required"]: + if "required" not in form_item or not form_item["required"]: form_item["required"] = False if not isinstance(form_item["required"], bool): raise ValueError("required in user_input_form must be of boolean type") if key == "select": - if 'options' not in form_item or not form_item["options"]: + if "options" not in form_item or not form_item["options"]: form_item["options"] = [] if not isinstance(form_item["options"], list): raise ValueError("options in user_input_form must be a list of strings") - if "default" in form_item and form_item['default'] \ - and form_item["default"] not in form_item["options"]: + if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]: raise ValueError("default value in user_input_form must be in the options list") return config, ["user_input_form"] @@ -177,10 +163,6 @@ def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: d typ = tool["type"] config = tool["config"] - ExternalDataToolFactory.validate_config( - name=typ, - tenant_id=tenant_id, - config=config - ) + ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config) - return config, ["external_data_tools"] \ No newline at end of file + return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 05a42a898e4af7..9b72452d7a1a9e 100644 --- a/api/core/app/app_config/entities.py +++ 
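import re

# The variable-name rule the validator above enforces, exercised standalone:
# no leading digit, then CJK characters, ASCII letters, digits, underscores,
# and two emoji ranges, 1 to 100 characters long.
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
assert pattern.match("user_name")
assert pattern.match("用户名")
assert pattern.match("1name") is None  # the lookahead rejects a leading digit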
b/api/core/app/app_config/entities.py @@ -1,17 +1,19 @@ +from collections.abc import Sequence from enum import Enum from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator -from core.file.file_obj import FileExtraConfig +from core.file import FileTransferMethod, FileType, FileUploadConfig from core.model_runtime.entities.message_entities import PromptMessageRole -from models import AppMode +from models.model import AppMode class ModelConfigEntity(BaseModel): """ Model Config Entity. """ + provider: str model: str mode: Optional[str] = None @@ -23,6 +25,7 @@ class AdvancedChatMessageEntity(BaseModel): """ Advanced Chat Message Entity. """ + text: str role: PromptMessageRole @@ -31,6 +34,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel): """ Advanced Chat Prompt Template Entity. """ + messages: list[AdvancedChatMessageEntity] @@ -43,6 +47,7 @@ class RolePrefixEntity(BaseModel): """ Role Prefix Entity. """ + user: str assistant: str @@ -60,11 +65,12 @@ class PromptType(Enum): Prompt Type. 'simple', 'advanced' """ - SIMPLE = 'simple' - ADVANCED = 'advanced' + + SIMPLE = "simple" + ADVANCED = "advanced" @classmethod - def value_of(cls, value: str) -> 'PromptType': + def value_of(cls, value: str): """ Get value of given mode. @@ -74,7 +80,7 @@ def value_of(cls, value: str) -> 'PromptType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt type value {value}') + raise ValueError(f"invalid prompt type value {value}") prompt_type: PromptType simple_prompt_template: Optional[str] = None @@ -82,48 +88,48 @@ def value_of(cls, value: str) -> 'PromptType': advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None +class VariableEntityType(str, Enum): + TEXT_INPUT = "text-input" + SELECT = "select" + PARAGRAPH = "paragraph" + NUMBER = "number" + EXTERNAL_DATA_TOOL = "external_data_tool" + FILE = "file" + FILE_LIST = "file-list" + + class VariableEntity(BaseModel): """ Variable Entity. """ - class Type(Enum): - TEXT_INPUT = 'text-input' - SELECT = 'select' - PARAGRAPH = 'paragraph' - NUMBER = 'number' - - @classmethod - def value_of(cls, value: str) -> 'VariableEntity.Type': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid variable type value {value}') variable: str label: str - description: Optional[str] = None - type: Type + description: str = "" + type: VariableEntityType required: bool = False max_length: Optional[int] = None - options: Optional[list[str]] = None - default: Optional[str] = None - hint: Optional[str] = None + options: Sequence[str] = Field(default_factory=list) + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_file_extensions: Sequence[str] = Field(default_factory=list) + allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + + @field_validator("description", mode="before") + @classmethod + def convert_none_description(cls, v: Any) -> str: + return v or "" - @property - def name(self) -> str: - return self.variable + @field_validator("options", mode="before") + @classmethod + def convert_none_options(cls, v: Any) -> Sequence[str]: + return v or [] class ExternalDataVariableEntity(BaseModel): """ External Data Variable Entity. 
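from collections.abc import Sequence
from typing import Any

from pydantic import BaseModel, Field, field_validator

# A stripped-down sketch of the mode="before" validators above (this is not the
# real VariableEntity): "before" validators run ahead of type coercion, so a None
# loaded from stored JSON is normalized to the default instead of failing the
# str/list type check.
class VariableSketch(BaseModel):
    variable: str
    description: str = ""
    options: Sequence[str] = Field(default_factory=list)

    @field_validator("description", mode="before")
    @classmethod
    def convert_none_description(cls, v: Any) -> str:
        return v or ""

    @field_validator("options", mode="before")
    @classmethod
    def convert_none_options(cls, v: Any) -> Sequence[str]:
        return v or []

print(VariableSketch.model_validate({"variable": "q", "description": None, "options": None}))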
""" + variable: str type: str config: dict[str, Any] = {} @@ -139,11 +145,12 @@ class RetrieveStrategy(Enum): Dataset Retrieve Strategy. 'single' or 'multiple' """ - SINGLE = 'single' - MULTIPLE = 'multiple' + + SINGLE = "single" + MULTIPLE = "multiple" @classmethod - def value_of(cls, value: str) -> 'RetrieveStrategy': + def value_of(cls, value: str): """ Get value of given mode. @@ -153,25 +160,24 @@ def value_of(cls, value: str) -> 'RetrieveStrategy': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid retrieve strategy value {value}') + raise ValueError(f"invalid retrieve strategy value {value}") query_variable: Optional[str] = None # Only when app mode is completion retrieve_strategy: RetrieveStrategy top_k: Optional[int] = None - score_threshold: Optional[float] = .0 - rerank_mode: Optional[str] = 'reranking_model' + score_threshold: Optional[float] = 0.0 + rerank_mode: Optional[str] = "reranking_model" reranking_model: Optional[dict] = None weights: Optional[dict] = None reranking_enabled: Optional[bool] = True - - class DatasetEntity(BaseModel): """ Dataset Config Entity. """ + dataset_ids: list[str] retrieve_config: DatasetRetrieveConfigEntity @@ -180,6 +186,7 @@ class SensitiveWordAvoidanceEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + type: str config: dict[str, Any] = {} @@ -188,6 +195,7 @@ class TextToSpeechEntity(BaseModel): """ Sensitive Word Avoidance Entity. """ + enabled: bool voice: Optional[str] = None language: Optional[str] = None @@ -197,14 +205,13 @@ class TracingConfigEntity(BaseModel): """ Tracing Config Entity. """ + enabled: bool tracing_provider: str - - class AppAdditionalFeatures(BaseModel): - file_upload: Optional[FileExtraConfig] = None + file_upload: Optional[FileUploadConfig] = None opening_statement: Optional[str] = None suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False @@ -214,10 +221,12 @@ class AppAdditionalFeatures(BaseModel): text_to_speech: Optional[TextToSpeechEntity] = None trace_config: Optional[TracingConfigEntity] = None + class AppConfig(BaseModel): """ Application Config Entity. """ + tenant_id: str app_id: str app_mode: AppMode @@ -230,15 +239,17 @@ class EasyUIBasedAppModelConfigFrom(Enum): """ App Model Config From. """ - ARGS = 'args' - APP_LATEST_CONFIG = 'app-latest-config' - CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config' + + ARGS = "args" + APP_LATEST_CONFIG = "app-latest-config" + CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config" class EasyUIBasedAppConfig(AppConfig): """ Easy UI Based App Config Entity. """ + app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str app_model_config_dict: dict @@ -252,4 +263,5 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. 
""" - workflow_id: str \ No newline at end of file + + workflow_id: str diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 3da3c2eddb83f3..2043ea0e41795f 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,69 +1,46 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any -from core.file.file_obj import FileExtraConfig +from core.file import FileUploadConfig class FileUploadConfigManager: @classmethod - def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]: + def convert(cls, config: Mapping[str, Any], is_vision: bool = True): """ Convert model config to model config :param config: model config args :param is_vision: if True, the feature is vision feature """ - file_upload_dict = config.get('file_upload') + file_upload_dict = config.get("file_upload") if file_upload_dict: - if file_upload_dict.get('image'): - if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: - image_config = { - 'number_limits': file_upload_dict['image']['number_limits'], - 'transfer_methods': file_upload_dict['image']['transfer_methods'] + if file_upload_dict.get("enabled"): + transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get( + "allowed_upload_methods", [] + ) + data = { + "image_config": { + "number_limits": file_upload_dict["number_limits"], + "transfer_methods": transform_methods, } + } - if is_vision: - image_config['detail'] = file_upload_dict['image']['detail'] + if is_vision: + data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") - return FileExtraConfig( - image_config=image_config - ) - - return None + return FileUploadConfig.model_validate(data) @classmethod - def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for file upload feature :param config: app model config args - :param is_vision: if True, the feature is vision feature """ if not config.get("file_upload"): config["file_upload"] = {} - - if not isinstance(config["file_upload"], dict): - raise ValueError("file_upload must be of dict type") - - # check image config - if not config["file_upload"].get("image"): - config["file_upload"]["image"] = {"enabled": False} - - if config['file_upload']['image']['enabled']: - number_limits = config['file_upload']['image']['number_limits'] - if number_limits < 1 or number_limits > 6: - raise ValueError("number_limits must be in [1, 6]") - - if is_vision: - detail = config['file_upload']['image']['detail'] - if detail not in ['high', 'low']: - raise ValueError("detail must be in ['high', 'low']") - - transfer_methods = config['file_upload']['image']['transfer_methods'] - if not isinstance(transfer_methods, list): - raise ValueError("transfer_methods must be of list type") - for method in transfer_methods: - if method not in ['remote_url', 'local_file']: - raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") + else: + FileUploadConfig.model_validate(config["file_upload"]) return config, ["file_upload"] diff --git a/api/core/app/app_config/features/more_like_this/manager.py b/api/core/app/app_config/features/more_like_this/manager.py index 2ba99a5c40d5fc..496e1beeecfa0f 100644 --- 
a/api/core/app/app_config/features/more_like_this/manager.py +++ b/api/core/app/app_config/features/more_like_this/manager.py @@ -7,9 +7,9 @@ def convert(cls, config: dict) -> bool: :param config: model config args """ more_like_this = False - more_like_this_dict = config.get('more_like_this') + more_like_this_dict = config.get("more_like_this") if more_like_this_dict: - if more_like_this_dict.get('enabled'): + if more_like_this_dict.get("enabled"): more_like_this = True return more_like_this @@ -22,9 +22,7 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("more_like_this"): - config["more_like_this"] = { - "enabled": False - } + config["more_like_this"] = {"enabled": False} if not isinstance(config["more_like_this"], dict): raise ValueError("more_like_this must be of dict type") diff --git a/api/core/app/app_config/features/opening_statement/manager.py b/api/core/app/app_config/features/opening_statement/manager.py index 0d8a71bfcf4d8b..b4dacbc409044a 100644 --- a/api/core/app/app_config/features/opening_statement/manager.py +++ b/api/core/app/app_config/features/opening_statement/manager.py @@ -1,5 +1,3 @@ - - class OpeningStatementConfigManager: @classmethod def convert(cls, config: dict) -> tuple[str, list]: @@ -9,10 +7,10 @@ def convert(cls, config: dict) -> tuple[str, list]: :param config: model config args """ # opening statement - opening_statement = config.get('opening_statement') + opening_statement = config.get("opening_statement") # suggested questions - suggested_questions_list = config.get('suggested_questions') + suggested_questions_list = config.get("suggested_questions") return opening_statement, suggested_questions_list diff --git a/api/core/app/app_config/features/retrieval_resource/manager.py b/api/core/app/app_config/features/retrieval_resource/manager.py index fca58e12e883da..d098abac2fa2e7 100644 --- a/api/core/app/app_config/features/retrieval_resource/manager.py +++ b/api/core/app/app_config/features/retrieval_resource/manager.py @@ -2,9 +2,9 @@ class RetrievalResourceConfigManager: @classmethod def convert(cls, config: dict) -> bool: show_retrieve_source = False - retriever_resource_dict = config.get('retriever_resource') + retriever_resource_dict = config.get("retriever_resource") if retriever_resource_dict: - if retriever_resource_dict.get('enabled'): + if retriever_resource_dict.get("enabled"): show_retrieve_source = True return show_retrieve_source @@ -17,9 +17,7 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("retriever_resource"): - config["retriever_resource"] = { - "enabled": False - } + config["retriever_resource"] = {"enabled": False} if not isinstance(config["retriever_resource"], dict): raise ValueError("retriever_resource must be of dict type") diff --git a/api/core/app/app_config/features/speech_to_text/manager.py b/api/core/app/app_config/features/speech_to_text/manager.py index 88b4be25d3e216..e10ae03e043b78 100644 --- a/api/core/app/app_config/features/speech_to_text/manager.py +++ b/api/core/app/app_config/features/speech_to_text/manager.py @@ -7,9 +7,9 @@ def convert(cls, config: dict) -> bool: :param config: model config args """ speech_to_text = False - speech_to_text_dict = config.get('speech_to_text') + speech_to_text_dict = config.get("speech_to_text") if speech_to_text_dict: - if speech_to_text_dict.get('enabled'): + if speech_to_text_dict.get("enabled"): 
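# These feature managers all reduce to the same predicate; a condensed sketch:
# a feature counts as enabled only when its config block exists and its
# "enabled" flag is truthy.
def feature_enabled(config: dict, key: str) -> bool:
    feature = config.get(key)
    return bool(feature and feature.get("enabled"))

assert feature_enabled({"speech_to_text": {"enabled": True}}, "speech_to_text")
assert not feature_enabled({"speech_to_text": {}}, "speech_to_text")
assert not feature_enabled({}, "speech_to_text")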
speech_to_text = True return speech_to_text @@ -22,9 +22,7 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("speech_to_text"): - config["speech_to_text"] = { - "enabled": False - } + config["speech_to_text"] = {"enabled": False} if not isinstance(config["speech_to_text"], dict): raise ValueError("speech_to_text must be of dict type") diff --git a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py index c6cab012207b9c..9ac5114d12dd44 100644 --- a/api/core/app/app_config/features/suggested_questions_after_answer/manager.py +++ b/api/core/app/app_config/features/suggested_questions_after_answer/manager.py @@ -7,9 +7,9 @@ def convert(cls, config: dict) -> bool: :param config: model config args """ suggested_questions_after_answer = False - suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer') + suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer") if suggested_questions_after_answer_dict: - if suggested_questions_after_answer_dict.get('enabled'): + if suggested_questions_after_answer_dict.get("enabled"): suggested_questions_after_answer = True return suggested_questions_after_answer @@ -22,15 +22,15 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("suggested_questions_after_answer"): - config["suggested_questions_after_answer"] = { - "enabled": False - } + config["suggested_questions_after_answer"] = {"enabled": False} if not isinstance(config["suggested_questions_after_answer"], dict): raise ValueError("suggested_questions_after_answer must be of dict type") - if "enabled" not in config["suggested_questions_after_answer"] or not \ - config["suggested_questions_after_answer"]["enabled"]: + if ( + "enabled" not in config["suggested_questions_after_answer"] + or not config["suggested_questions_after_answer"]["enabled"] + ): config["suggested_questions_after_answer"]["enabled"] = False if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): diff --git a/api/core/app/app_config/features/text_to_speech/manager.py b/api/core/app/app_config/features/text_to_speech/manager.py index f11e268e7380db..1c7598178527b4 100644 --- a/api/core/app/app_config/features/text_to_speech/manager.py +++ b/api/core/app/app_config/features/text_to_speech/manager.py @@ -10,13 +10,13 @@ def convert(cls, config: dict): :param config: model config args """ text_to_speech = None - text_to_speech_dict = config.get('text_to_speech') + text_to_speech_dict = config.get("text_to_speech") if text_to_speech_dict: - if text_to_speech_dict.get('enabled'): + if text_to_speech_dict.get("enabled"): text_to_speech = TextToSpeechEntity( - enabled=text_to_speech_dict.get('enabled'), - voice=text_to_speech_dict.get('voice'), - language=text_to_speech_dict.get('language'), + enabled=text_to_speech_dict.get("enabled"), + voice=text_to_speech_dict.get("voice"), + language=text_to_speech_dict.get("language"), ) return text_to_speech @@ -29,11 +29,7 @@ def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: :param config: app model config args """ if not config.get("text_to_speech"): - config["text_to_speech"] = { - "enabled": False, - "voice": "", - "language": "" - } + config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""} if not 
isinstance(config["text_to_speech"], dict): raise ValueError("text_to_speech must be of dict type") diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index 4b117d87f8c157..2f1da3808231dd 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -17,6 +17,6 @@ def convert(cls, workflow: Workflow) -> list[VariableEntity]: # variables for variable in user_input_form: - variables.append(VariableEntity(**variable)) + variables.append(VariableEntity.model_validate(variable)) return variables diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index c3d0e8ba037cae..cb606953cd7967 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -1,4 +1,3 @@ - from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig): """ Advanced Chatbot App Config Entity. """ + pass class AdvancedChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - workflow: Workflow) -> AdvancedChatAppConfig: + def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig: features_dict = workflow.features_dict app_mode = AppMode.value_of(app_model.mode) @@ -34,13 +33,9 @@ def get_app_config(cls, app_model: App, app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -57,10 +52,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: related_config_keys = [] # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False - ) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) related_config_keys.extend(current_related_config_keys) # opening_statement @@ -69,7 +61,8 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -86,9 +79,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) 
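from pydantic import BaseModel

# On the switch from VariableEntity(**variable) to VariableEntity.model_validate(variable)
# above: a sketch of the equivalence. model_validate accepts any mapping (not just
# a real dict) and keeps the call on Pydantic's validation API; both paths validate.
class Point(BaseModel):
    x: int
    y: int = 0

assert Point.model_validate({"x": 1}) == Point(**{"x": 1})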
related_config_keys.extend(current_related_config_keys) @@ -98,4 +89,3 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: filtered_config = {key: config.get(key) for key in related_config_keys} return filtered_config - diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 351eb05d8ad41c..0b883450610030 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -1,52 +1,69 @@ import contextvars import logging -import os import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError -from sqlalchemy import select -from sqlalchemy.orm import Session import contexts +from configs import dify_config +from constants import UUID_NIL from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, -) +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable from extensions.ext_database import db +from factories import file_factory from models.account import Account from models.model import App, Conversation, EndUser, Message -from models.workflow import ConversationVariable, Workflow +from models.workflow import Workflow logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): + @overload def generate( - self, app_model: App, + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + + def generate( + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. 
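from collections.abc import Generator
from typing import Literal, Union, overload

# A minimal sketch of the @overload pattern introduced above: the stream flag
# narrows the return type for type checkers, while a single implementation
# serves both call styles.
@overload
def generate(stream: Literal[True] = True) -> Generator[str, None, None]: ...
@overload
def generate(stream: Literal[False] = ...) -> dict: ...
def generate(stream: bool = True) -> Union[dict, Generator[str, None, None]]:
    if stream:
        return (chunk for chunk in ("Hello", " world"))
    return {"answer": "Hello world"}

assert generate(stream=False) == {"answer": "Hello world"}
assert "".join(generate()) == "Hello world"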
@@ -57,65 +74,69 @@ def generate( :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', False) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + conversation_id = args.get("conversation_id") + if conversation_id: + conversation = self._get_conversation_by_user( + app_model=app_model, conversation_id=conversation_id, user=user + ) # parse files - files = args['files'] if args.get('files') else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, ) else: file_objs = [] # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id - trace_manager = TraceQueueManager(app_model.id, user_id) + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) if invoke_from == InvokeFrom.DEBUGGER: # always enable retriever resource in debugger mode app_config.additional_features.show_retrieve_source = True + workflow_run_id = str(uuid.uuid4()) # init application generate entity application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + inputs=conversation.inputs + if conversation + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, + workflow_run_id=workflow_run_id, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -125,16 +146,12 @@ def generate( invoke_from=invoke_from, application_generate_entity=application_generate_entity, conversation=conversation, - stream=stream + stream=stream, ) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account, - args: dict, - stream: bool = True) \ - -> Union[dict, Generator[dict, 
None, None]]: + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -146,42 +163,29 @@ def single_iteration_generate(self, app_model: App, :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') - - if args.get('inputs') is None: - raise ValueError('inputs is required') + raise ValueError("node_id is required") - extras = { - "auto_generate_conversation_name": False - } - - # get conversation - conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - conversation_id=conversation.id if conversation else None, + conversation_id=None, inputs={}, - query='', + query="", files=[], user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={"auto_generate_conversation_name": False}, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -190,33 +194,42 @@ def single_iteration_generate(self, app_model: App, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - conversation=conversation, - stream=stream + conversation=None, + stream=stream, ) - def _generate(self, *, - workflow: Workflow, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - application_generate_entity: AdvancedChatAppGenerateEntity, - conversation: Conversation | None = None, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + def _generate( + self, + *, + workflow: Workflow, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + application_generate_entity: AdvancedChatAppGenerateEntity, + conversation: Optional[Conversation] = None, + stream: bool = True, + ) -> dict[str, Any] | Generator[str, Any, None]: + """ + Generate App response. 
+ + :param workflow: Workflow + :param user: account or end user + :param invoke_from: invoke from source + :param application_generate_entity: application generate entity + :param conversation: conversation + :param stream: is stream + """ is_first_conversation = False if not conversation: is_first_conversation = True # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) if is_first_conversation: # update conversation features conversation.override_model_configs = workflow.features db.session.commit() - # db.session.refresh(conversation) + db.session.refresh(conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -225,73 +238,21 @@ def _generate(self, *, invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) - # Init conversation variables - stmt = select(ConversationVariable).where( - ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id - ) - with Session(db.engine) as session: - conversation_variables = session.scalars(stmt).all() - if not conversation_variables: - # Create conversation variables if they don't exist. - conversation_variables = [ - ConversationVariable.from_variable( - app_id=conversation.app_id, conversation_id=conversation.id, variable=variable - ) - for variable in workflow.conversation_variables - ] - session.add_all(conversation_variables) - # Convert database entities to variables. - conversation_variables = [item.to_variable() for item in conversation_variables] - - session.commit() - - # Increment dialogue count. - conversation.dialogue_count += 1 - - conversation_id = conversation.id - conversation_dialogue_count = conversation.dialogue_count - db.session.commit() - db.session.refresh(conversation) - - inputs = application_generate_entity.inputs - query = application_generate_entity.query - files = application_generate_entity.files - - user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = application_generate_entity.user_id - - # Create a variable pool. 
- system_inputs = { - SystemVariable.QUERY: query, - SystemVariable.FILES: files, - SystemVariable.CONVERSATION_ID: conversation_id, - SystemVariable.USER_ID: user_id, - SystemVariable.DIALOGUE_COUNT: conversation_dialogue_count, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, - ) - contexts.workflow_variable_pool.set(variable_pool) - # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - 'context': contextvars.copy_context(), - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": contextvars.copy_context(), + }, + ) worker_thread.start() @@ -306,16 +267,17 @@ def _generate(self, *, stream=stream, ) - return AdvancedChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: AdvancedChatAppGenerateEntity, - queue_manager: AppQueueManager, - message_id: str, - context: contextvars.Context) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: AdvancedChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + context: contextvars.Context, + ) -> None: """ Generate worker in a new thread. 
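import contextvars
import threading

# The context-propagation idiom the worker-thread kwargs above rely on: copy the
# caller's contextvars and replay them inside the new thread, so values set
# before the hop (tenant id, etc.) remain visible after it.
tenant_id = contextvars.ContextVar("tenant_id")
tenant_id.set("tenant-123")

def worker(context: contextvars.Context) -> None:
    for var, val in context.items():
        var.set(val)
    assert tenant_id.get() == "tenant-123"

thread = threading.Thread(target=worker, kwargs={"context": contextvars.copy_context()})
thread.start()
thread.join()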
:param flask_app: Flask app @@ -329,40 +291,30 @@ def _generate_worker(self, flask_app: Flask, var.set(val) with flask_app.app_context(): try: - runner = AdvancedChatAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - # get message - message = self._get_message(message_id) - - # chatbot app - runner = AdvancedChatAppRunner() - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - message=message - ) - except GenerateTaskStoppedException: + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + # chatbot app + runner = AdvancedChatAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + ) + + runner.run() + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -408,7 +360,7 @@ def _handle_advanced_chat_response( return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 0caff4a2e395d7..18b115dfe40d3c 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -21,14 +21,11 @@ def __init__(self, status: str, audio): self.status = status -def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( - content_text=text_content.strip(), - user="responding_tts", - tenant_id=tenant_id, - voice=voice + content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice ) @@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue): except Exception as e: logging.getLogger(__name__).warning(e) break - audio_queue.put(AudioTrunk("finish", b'')) + audio_queue.put(AudioTrunk("finish", b"")) class AppGeneratorTTSPublisher: - def __init__(self, tenant_id: str, voice: str): self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id - self.msg_text = '' + self.msg_text = "" self._audio_queue = queue.Queue() 
self._msg_queue = queue.Queue() - self.match = re.compile(r'[。.!?]') + self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.TTS + tenant_id=self.tenant_id, model_type=ModelType.TTS ) self.voices = self.model_instance.get_tts_voices() - values = [voice.get('value') for voice in self.voices] + values = [voice.get("value") for voice in self.voices] self.voice = voice if not voice or voice not in values: - self.voice = self.voices[0].get('value') + self.voice = self.voices[0].get("value") self.MAX_SENTENCE = 2 self._last_audio_event = None self._runtime_thread = threading.Thread(target=self._runtime).start() @@ -85,8 +80,9 @@ def _runtime(self): message = self._msg_queue.get() if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: - futures_result = self.executor.submit(_invoiceTTS, self.msg_text, - self.model_instance, self.tenant_id, self.voice) + futures_result = self.executor.submit( + _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) break elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): @@ -94,28 +90,27 @@ def _runtime(self): elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text elif isinstance(message.event, QueueNodeSucceededEvent): - self.msg_text += message.event.outputs.get('output', '') + self.msg_text += message.event.outputs.get("output", "") self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): self.MAX_SENTENCE += 1 - text_content = ''.join(sentence_arr) - futures_result = self.executor.submit(_invoiceTTS, text_content, - self.model_instance, - self.tenant_id, - self.voice) + text_content = "".join(sentence_arr) + futures_result = self.executor.submit( + _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice + ) future_queue.put(futures_result) if text_tmp: self.msg_text = text_tmp else: - self.msg_text = '' + self.msg_text = "" except Exception as e: self.logger.warning(e) break future_queue.put(None) - def checkAndGetAudio(self) -> AudioTrunk | None: + def check_and_get_audio(self) -> AudioTrunk | None: try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 5dc03979cf3b4b..65d744eddff5c8 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -1,161 +1,192 @@ import logging -import os -import time from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig -from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.apps.base_app_runner import AppRunner -from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback -from core.app.entities.app_invoke_entities import ( - AdvancedChatAppGenerateEntity, - InvokeFrom, +from core.app.apps.base_app_queue_manager import AppQueueManager +from 
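import re

# _extract_sentence is not part of this hunk; the sketch below is one plausible
# implementation of the buffering it implies: flush every complete sentence
# (anything ending in a character matched by the publisher's regex) and keep
# the unfinished tail for the next chunk.
match = re.compile(r"[。.!?]")

def extract_sentence(buffer: str) -> tuple[list[str], str]:
    sentences, last = [], 0
    for m in match.finditer(buffer):
        sentences.append(buffer[last:m.end()])
        last = m.end()
    return sentences, buffer[last:]

done, rest = extract_sentence("Hello world. How are")
assert done == ["Hello world."]
assert rest == " How are"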
core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAnnotationReplyEvent, + QueueStopEvent, + QueueTextChunkEvent, ) -from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent -from core.moderation.base import ModerationException -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.nodes.base_node import UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.moderation.base import ModerationError +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from models import App, Message, Workflow +from models.enums import UserFrom +from models.model import App, Conversation, EndUser, Message +from models.workflow import ConversationVariable, WorkflowType logger = logging.getLogger(__name__) -class AdvancedChatAppRunner(AppRunner): +class AdvancedChatAppRunner(WorkflowBasedAppRunner): """ AdvancedChat Application Runner """ - def run( + def __init__( self, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, + conversation: Conversation, message: Message, ) -> None: - """ - Run application - :param application_generate_entity: application generate entity - :param queue_manager: application queue manager - :param conversation: conversation - :param message: message - :return: - """ - app_config = application_generate_entity.app_config + super().__init__(queue_manager) + + self.application_generate_entity = application_generate_entity + self.conversation = conversation + self.message = message + + def run(self) -> None: + app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') - - inputs = application_generate_entity.inputs - query = application_generate_entity.query - - # moderation - if self.handle_input_moderation( - queue_manager=queue_manager, - app_record=app_record, - app_generate_entity=application_generate_entity, - inputs=inputs, - query=query, - message_id=message.id, - ): - return + raise ValueError("Workflow not initialized") - # annotation reply - if self.handle_annotation_reply( - app_record=app_record, - message=message, - query=query, - queue_manager=queue_manager, - app_generate_entity=application_generate_entity, - ): - return - - db.session.close() - - workflow_callbacks: list[WorkflowCallback] = [ - WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) - ] + user_id = None + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + if end_user: + user_id = end_user.session_id + else: + user_id = self.application_generate_entity.user_id - if bool(os.environ.get('DEBUG', 
'False').lower() == 'true'): + workflow_callbacks: list[WorkflowCallback] = [] + if dify_config.DEBUG: workflow_callbacks.append(WorkflowLoggingCallback()) - # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, - ) - - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') + if self.application_generate_entity.single_iteration_run: + # if only single iteration run is requested + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs, + ) + else: + inputs = self.application_generate_entity.inputs + query = self.application_generate_entity.query + files = self.application_generate_entity.files + + # moderation + if self.handle_input_moderation( + app_record=app_record, + app_generate_entity=self.application_generate_entity, + inputs=inputs, + query=query, + message_id=self.message.id, + ): + return + + # annotation reply + if self.handle_annotation_reply( + app_record=app_record, + message=self.message, + query=query, + app_generate_entity=self.application_generate_entity, + ): + return + + # Init conversation variables + stmt = select(ConversationVariable).where( + ConversationVariable.app_id == self.conversation.app_id, + ConversationVariable.conversation_id == self.conversation.id, + ) + with Session(db.engine) as session: + conversation_variables = session.scalars(stmt).all() + if not conversation_variables: + # Create conversation variables if they don't exist. + conversation_variables = [ + ConversationVariable.from_variable( + app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable + ) + for variable in workflow.conversation_variables + ] + session.add_all(conversation_variables) + # Convert database entities to variables. + conversation_variables = [item.to_variable() for item in conversation_variables] + + session.commit() + + # Increment dialogue count. + self.conversation.dialogue_count += 1 + + conversation_dialogue_count = self.conversation.dialogue_count + db.session.commit() + + # Create a variable pool. 
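from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

# A self-contained sketch of the create-if-missing flow above, run against a
# throwaway SQLite table instead of ConversationVariable (the table and field
# names here are invented).
class Base(DeclarativeBase):
    pass

class ConvVar(Base):
    __tablename__ = "conv_var"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(32))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    rows = session.scalars(select(ConvVar)).all()
    if not rows:  # first turn of the conversation: seed the defaults, as add_all() does above
        rows = [ConvVar(name="memory")]
        session.add_all(rows)
    names = [row.name for row in rows]  # materialize while the session is open
    session.commit()

assert names == ["memory"]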
+ system_inputs = { + SystemVariableKey.QUERY: query, + SystemVariableKey.FILES: files, + SystemVariableKey.CONVERSATION_ID: self.conversation.id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, + SystemVariableKey.APP_ID: app_config.app_id, + SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, + SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, + } + + # init variable pool + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=conversation_variables, + ) - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError('Workflow not initialized') + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) - workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] + db.session.close() - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks + # RUN WORKFLOW + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, + variable_pool=variable_pool, ) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() + generator = workflow_entry.run( + callbacks=workflow_callbacks, ) - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) def handle_input_moderation( self, - queue_manager: AppQueueManager, app_record: App, app_generate_entity: AdvancedChatAppGenerateEntity, inputs: Mapping[str, Any], query: str, message_id: str, ) -> bool: - """ - Handle input moderation - :param queue_manager: application queue manager - :param app_record: app record - :param app_generate_entity: application generate entity - :param inputs: inputs - :param query: query - :param message_id: message id - :return: - """ try: # process sensitive_word_avoidance _, inputs, query = self.moderation_for_inputs( @@ -166,34 +197,15 @@ def handle_input_moderation( query=query, message_id=message_id, ) - except ModerationException as e: - self._stream_output( - queue_manager=queue_manager, - text=str(e), - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION, - ) + except ModerationError as e: + self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) return True return False def handle_annotation_reply( - self, - app_record: App, - message: Message, - query: str, - queue_manager: AppQueueManager, - app_generate_entity: AdvancedChatAppGenerateEntity, + self, app_record: App, message: Message, 
query: str, app_generate_entity: AdvancedChatAppGenerateEntity ) -> bool: - """ - Handle annotation reply - :param app_record: app record - :param message: message - :param query: query - :param queue_manager: application queue manager - :param app_generate_entity: application generate entity - """ - # annotation reply annotation_reply = self.query_app_annotations_to_reply( app_record=app_record, message=message, @@ -203,37 +215,19 @@ def handle_annotation_reply( ) if annotation_reply: - queue_manager.publish( - QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER - ) + self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)) - self._stream_output( - queue_manager=queue_manager, - text=annotation_reply.content, - stream=app_generate_entity.stream, - stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY, + self._complete_with_stream_output( + text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY ) return True return False - def _stream_output( - self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy - ) -> None: + def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: """ Direct output - :param queue_manager: application queue manager - :param text: text - :param stream: stream - :return: """ - if stream: - index = 0 - for token in text: - queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER) - index += 1 - time.sleep(0.01) - else: - queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER) + self._publish_event(QueueTextChunkEvent(text=text)) - queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER) + self._publish_event(QueueStopEvent(stopped_by=stopped_by)) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index ef579827b47c7e..5fbd3e9a94906f 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -28,15 +28,15 @@ def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) """ blocking_response = cast(ChatbotAppBlockingResponse, blocking_response) response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -50,13 +50,15 @@ def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) 
return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: """ Convert stream full response. :param stream_response: stream response @@ -67,14 +69,14 @@ def convert_stream_full_response(cls, stream_response: Generator[AppStreamRespon sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -85,7 +87,9 @@ def convert_stream_full_response(cls, stream_response: Generator[AppStreamRespon yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, Any, None]: """ Convert stream simple response. :param stream_response: stream response @@ -96,20 +100,20 @@ def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResp sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index ac51a4e840bd54..1d4c0ea0fa6f4d 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,15 +1,15 @@ import json import logging import time -from collections.abc import Generator -from typing import Any, Optional, Union, cast +from collections.abc import Generator, Mapping +from typing import Any, Optional, Union -import contexts from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, + InvokeFrom, ) from 
core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, @@ -20,8 +20,12 @@ QueueIterationStartEvent, QueueMessageReplaceEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent, @@ -31,31 +35,29 @@ QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, ChatbotAppBlockingResponse, ChatbotAppStreamResponse, - ChatflowStreamGenerateRoute, ErrorStreamResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, StreamResponse, + WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage -from core.file.file_obj import FileVar from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes import NodeType from events.message_event import message_was_created from extensions.ext_database import db +from models import Conversation, EndUser, Message, MessageFile from models.account import Account -from models.model import Conversation, EndUser, Message +from models.enums import CreatedByRole from models.workflow import ( Workflow, WorkflowNodeExecution, @@ -69,22 +71,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ - _task_state: AdvancedChatTaskState + + _task_state: WorkflowTaskState _application_generate_entity: AdvancedChatAppGenerateEntity _workflow: Workflow _user: Union[Account, EndUser] - # Deprecated - _workflow_system_variables: dict[SystemVariable, Any] - _iteration_nested_relations: dict[str, list[str]] + _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] def __init__( - self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool, + self, + application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. 
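
The handle_annotation_reply hunk earlier in this patch replaces each queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) call with self._publish_event(event). The helper itself is not shown anywhere in this diff, so the following is only a sketch of the shape those call sites assume (the mixin name is invented; AppQueueManager, PublishFrom, and AppQueueEvent are the real imports used elsewhere in the patch):

    # Sketch only: the _publish_event helper the refactored call sites assume.
    # It fixes the publish source once, so handlers no longer repeat
    # PublishFrom.APPLICATION_MANAGER at every publish.
    from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
    from core.app.entities.queue_entities import AppQueueEvent

    class PublishEventMixin:
        _queue_manager: AppQueueManager

        def _publish_event(self, event: AppQueueEvent) -> None:
            self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

Under that assumption, the annotation path also loses its simulated streaming: _complete_with_stream_output emits the full text as one QueueTextChunkEvent followed by a QueueStopEvent, instead of the deleted per-character loop with time.sleep(0.01).
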
@@ -106,21 +109,22 @@ def __init__( self._workflow = workflow self._conversation = conversation self._message = message - # Deprecated self._workflow_system_variables = { - SystemVariable.QUERY: message.query, - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.CONVERSATION_ID: conversation.id, - SystemVariable.USER_ID: user_id, + SystemVariableKey.QUERY: message.query, + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.CONVERSATION_ID: conversation.id, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count, + SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, } - self._task_state = AdvancedChatTaskState( - usage=LLMUsage.empty_usage() - ) + self._task_state = WorkflowTaskState() + self._wip_workflow_node_executions = {} - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) - self._stream_generate_routes = self._get_stream_generate_routes() self._conversation_name_generate_thread = None + self._recorded_files: list[Mapping[str, Any]] = [] def process(self): """ @@ -133,13 +137,11 @@ def process(self): # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) + if self._stream: return self._to_stream_response(generator) else: @@ -156,7 +158,7 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] elif isinstance(stream_response, MessageEndStreamResponse): extras = {} if stream_response.metadata: - extras['metadata'] = stream_response.metadata + extras["metadata"] = stream_response.metadata return ChatbotAppBlockingResponse( task_id=stream_response.task_id, @@ -167,15 +169,17 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] message_id=self._message.id, answer=self._task_state.answer, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[ChatbotAppStreamResponse, Any, None]: """ To stream response. 
:return: @@ -185,31 +189,35 @@ def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - - publisher = None + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -220,9 +228,9 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan # timeout while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -234,40 +242,41 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan start_listener_time = time.time() yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) except Exception as e: - logger.error(e) + logger.exception(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + if tts_publisher: + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + self, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. 
:return: """ - for message in self._queue_manager.listen(): - if (message.event - and getattr(message.event, 'metadata', None) - and message.event.metadata.get('is_answer_previous_node', False) - and publisher): - publisher.publish(message=message) - elif (hasattr(message.event, 'execution_metadata') - and message.event.execution_metadata - and message.event.execution_metadata.get('is_answer_previous_node', False) - and publisher): - publisher.publish(message=message) - event = message.event - - if isinstance(event, QueueErrorEvent): + # init fake graph runtime state + graph_runtime_state = None + workflow_run = None + + for queue_message in self._queue_manager.listen(): + event = queue_message.event + + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event, self._message) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state + + # init workflow run + workflow_run = self._handle_workflow_run_start() - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._refetch_message() self._message.workflow_run_id = workflow_run.id db.session.commit() @@ -275,137 +284,247 @@ def _process_stream_response( db.session.close() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception("Workflow run not initialized.") + + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) + + response = self._workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes: - self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id] - # reset current route position to 0 - self._task_state.current_stream_generate_state.current_route_position = 0 + if response: + yield response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() + # Record files if it's an answer node or end node + if event.node_type in [NodeType.ANSWER, NodeType.END]: + self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - # stream outputs when node finished - generator = self._generate_stream_outputs_when_node_finished() - if generator: - yield from generator + if response: + yield 
response + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) - yield self._workflow_node_finish_to_stream_response( + response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) - - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, conversation_id=self._conversation.id, trace_manager=trace_manager + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) - if workflow_run: - yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run - ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") - if workflow_run.status == WorkflowRunStatus.FAILED.value: - err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}')) - yield self._error_to_stream_response(self._handle_error(err_event, self._message)) - break + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") - if isinstance(event, QueueStopEvent): - # Save message - self._save_message() + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") - yield self._message_end_to_stream_response() - break - else: - self._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), - PublishFrom.TASK_PIPELINE + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, 
workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowFailedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED, + error=event.error, + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + yield self._error_to_stream_response(self._handle_error(err_event, self._message)) + break + elif isinstance(event, QueueStopEvent): + if workflow_run and graph_runtime_state: + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.STOPPED, + error=event.get_stop_reason(), + conversation_id=self._conversation.id, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - elif isinstance(event, QueueAdvancedChatMessageEndEvent): - output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) - if output_moderation_answer: - self._task_state.answer = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) # Save message - self._save_message() + self._save_message(graph_runtime_state=graph_runtime_state) yield self._message_end_to_stream_response() + break elif isinstance(event, QueueRetrieverResourcesEvent): self._handle_retriever_resources(event) + + self._refetch_message() + + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + + db.session.commit() + db.session.refresh(self._message) + db.session.close() elif isinstance(event, QueueAnnotationReplyEvent): self._handle_annotation_reply(event) + + self._refetch_message() + + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + + db.session.commit() + db.session.refresh(self._message) + db.session.close() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text if delta_text is None: continue - if 
not self._is_stream_out_support( - event=event - ): - continue - # handle output moderation chunk should_direct_answer = self._handle_output_moderation_chunk(delta_text) if should_direct_answer: continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) + self._task_state.answer += delta_text - yield self._message_to_stream_response(delta_text, self._message.id) + yield self._message_to_stream_response( + answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector + ) elif isinstance(event, QueueMessageReplaceEvent): + # published by moderation yield self._message_replace_to_stream_response(answer=event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + elif isinstance(event, QueueAdvancedChatMessageEndEvent): + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer + yield self._message_replace_to_stream_response(answer=output_moderation_answer) + + # Save message + self._save_message(graph_runtime_state=graph_runtime_state) + + yield self._message_end_to_stream_response() else: continue - if publisher: - publisher.publish(None) + + # publish None when task finished + if tts_publisher: + tts_publisher.publish(None) + if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message(self) -> None: - """ - Save message. - :return: - """ - self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: + self._refetch_message() self._message.answer = self._task_state.answer self._message.provider_response_latency = time.perf_counter() - self._start_at - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None - - if self._task_state.metadata and self._task_state.metadata.get('usage'): - usage = LLMUsage(**self._task_state.metadata['usage']) + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) + message_files = [ + MessageFile( + message_id=self._message.id, + type=file["type"], + transfer_method=file["transfer_method"], + url=file["remote_url"], + belongs_to="assistant", + upload_file_id=file["related_id"], + created_by_role=CreatedByRole.ACCOUNT + if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatedByRole.END_USER, + created_by=self._message.from_account_id or self._message.from_end_user_id or "", + ) + for file in self._recorded_files + ] + db.session.add_all(message_files) + if graph_runtime_state and graph_runtime_state.llm_usage: + usage = graph_runtime_state.llm_usage self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit @@ -415,6 +534,10 @@ def _save_message(self) -> None: self._message.total_price = usage.total_price self._message.currency = usage.currency + self._task_state.metadata["usage"] = jsonable_encoder(usage) + else: + self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) + db.session.commit() message_was_created.send( @@ 
-422,7 +545,7 @@ def _save_message(self) -> None: application_generate_entity=self._application_generate_entity, conversation=self._conversation, is_first_message=self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + extras=self._application_generate_entity.extras, ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -432,331 +555,15 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: """ extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata.copy() + + if "annotation_reply" in extras["metadata"]: + del extras["metadata"]["annotation_reply"] return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras ) - def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]: - """ - Get stream generate routes. - :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - answer_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.ANSWER.value - ] - - # parse stream output node value selectors of answer nodes - stream_generate_routes = {} - for node_config in answer_node_configs: - # get generate route for stream output - answer_node_id = node_config['id'] - generate_route = AnswerNode.extract_generate_route_selectors(node_config) - start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute( - answer_node_id=answer_node_id, - generate_route=generate_route - ) - - return stream_generate_routes - - def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get answer start at node id. 
- :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - # check if it's the first node in the iteration - target_node = next((node for node in nodes if node.get('id') == target_node_id), None) - if not target_node: - return [] - - node_iteration_id = target_node.get('data', {}).get('iteration_id') - # get iteration start node id - for node in nodes: - if node.get('id') == node_iteration_id: - if node.get('data', {}).get('start_node_id') == target_node_id: - return [target_node_id] - - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.ANSWER.value, - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value, - NodeType.ITERATION.value, - NodeType.LOOP.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. 
- :return: - """ - if self._task_state.current_stream_generate_state: - route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position: - ] - - for route_chunk in route_chunks: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - - # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text) - if should_direct_answer: - continue - - self._task_state.answer += route_chunk.text - yield self._message_to_stream_response(route_chunk.text, self._message.id) - else: - break - - self._task_state.current_stream_generate_state.current_route_position += 1 - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route - ): - self._task_state.current_stream_generate_state = None - - def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]: - """ - Generate stream outputs. - :return: - """ - if not self._task_state.current_stream_generate_state: - return - - route_chunks = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position:] - - for route_chunk in route_chunks: - if route_chunk.type == 'text': - route_chunk = cast(TextGenerateRouteChunk, route_chunk) - self._task_state.answer += route_chunk.text - yield self._message_to_stream_response(route_chunk.text, self._message.id) - else: - value = None - route_chunk = cast(VarGenerateRouteChunk, route_chunk) - value_selector = route_chunk.value_selector - if not value_selector: - self._task_state.current_stream_generate_state.current_route_position += 1 - continue - - route_chunk_node_id = value_selector[0] - - if route_chunk_node_id == 'sys': - # system variable - value = contexts.workflow_variable_pool.get().get(value_selector) - if value: - value = value.text - elif route_chunk_node_id in self._iteration_nested_relations: - # it's a iteration variable - if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations: - continue - iteration_state = self._iteration_state.current_iterations[route_chunk_node_id] - iterator = iteration_state.inputs - if not iterator: - continue - iterator_selector = iterator.get('iterator_selector', []) - if value_selector[1] == 'index': - value = iteration_state.current_index - elif value_selector[1] == 'item': - value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len( - iterator_selector - ) else None - else: - # check chunk node id is before current node id or equal to current node id - if route_chunk_node_id not in self._task_state.ran_node_execution_infos: - break - - latest_node_execution_info = self._task_state.latest_node_execution_info - - # get route chunk node execution info - route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id] - if (route_chunk_node_execution_info.node_type == NodeType.LLM - and latest_node_execution_info.node_type == NodeType.LLM): - # only LLM support chunk stream output - self._task_state.current_stream_generate_state.current_route_position += 1 - continue - - # get route chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id - ).first() - - outputs = 
route_chunk_node_execution.outputs_dict - - # get value from outputs - value = None - for key in value_selector[1:]: - if not value: - value = outputs.get(key) if outputs else None - else: - value = value.get(key) - - if value is not None: - text = '' - if isinstance(value, str | int | float): - text = str(value) - elif isinstance(value, FileVar): - # convert file to markdown - text = value.to_markdown() - elif isinstance(value, dict): - # handle files - file_vars = self._fetch_files_from_variable_value(value) - if file_vars: - file_var = file_vars[0] - try: - file_var_obj = FileVar(**file_var) - - # convert file to markdown - text = file_var_obj.to_markdown() - except Exception as e: - logger.error(f'Error creating file var: {e}') - - if not text: - # other types - text = json.dumps(value, ensure_ascii=False) - elif isinstance(value, list): - # handle files - file_vars = self._fetch_files_from_variable_value(value) - for file_var in file_vars: - try: - file_var_obj = FileVar(**file_var) - except Exception as e: - logger.error(f'Error creating file var: {e}') - continue - - # convert file to markdown - text = file_var_obj.to_markdown() + ' ' - - text = text.strip() - - if not text and value: - # other types - text = json.dumps(value, ensure_ascii=False) - - if text: - self._task_state.answer += text - yield self._message_to_stream_response(text, self._message.id) - - self._task_state.current_stream_generate_state.current_route_position += 1 - - # all route chunks are generated - if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route - ): - self._task_state.current_stream_generate_state = None - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return True - - if 'node_id' not in event.metadata: - return True - - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - route_chunk = self._task_state.current_stream_generate_state.generate_route[ - self._task_state.current_stream_generate_state.current_route_position] - - if route_chunk.type != 'var': - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - route_chunk = cast(VarGenerateRouteChunk, route_chunk) - value_selector = route_chunk.value_selector - - # check chunk node id is before current node id or equal to current node id - if value_selector != stream_output_value_selector: - return False - - return True - def _handle_output_moderation_chunk(self, text: str) -> bool: """ Handle output moderation chunk. 
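
The block deleted above is the pipeline-side answer-route machinery (_get_stream_generate_routes, _get_answer_start_at_node_ids, _generate_stream_outputs_when_node_started/_finished, _is_stream_out_support): with the new graph engine, text chunks arrive already resolved and tagged with from_variable_selector, so the pipeline only moderates, accumulates, and re-emits them. What survives is the file bookkeeping: _fetch_files_from_node_outputs feeds self._recorded_files, whose entries _save_message later unpacks into MessageFile rows. Below is a sketch of the mapping shape those two sites imply; only the four keys read in _save_message are guaranteed by this diff, and the example values are invented:

    # One assumed entry of self._recorded_files, inferred from the
    # MessageFile(...) construction in _save_message; example values only.
    recorded_file = {
        "type": "image",                            # -> MessageFile.type
        "transfer_method": "remote_url",            # -> MessageFile.transfer_method
        "remote_url": "https://example.com/a.png",  # -> MessageFile.url
        "related_id": "upload-file-uuid",           # -> MessageFile.upload_file_id
    }
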
@@ -768,17 +575,23 @@ def _handle_output_moderation_chunk(self, text: str) -> bool: # stop subscribe new token when output moderation should direct output self._task_state.answer = self._output_moderation_handler.get_final_output() self._queue_manager.publish( - QueueTextChunkEvent( - text=self._task_state.answer - ), PublishFrom.TASK_PIPELINE + QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: self._output_moderation_handler.append_new_token(text) return False + + def _refetch_message(self) -> None: + """ + Refetch message. + :return: + """ + message = db.session.query(Message).filter(Message.id == self._message.id).first() + if message: + self._message = message diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py deleted file mode 100644 index 8d43155a0886bf..00000000000000 --- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - 
QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager._publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager._publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self._queue_manager.publish( - event, - PublishFrom.APPLICATION_MANAGER - ) diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index f495ebbf35fe40..9040f18bfd71d3 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig): """ Agent Chatbot App Config Entity. 
""" + agent: Optional[AgentEntity] = None class AgentChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> AgentChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config :param app_model: app model @@ -66,22 +70,12 @@ def get_app_config(cls, app_model: App, app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - agent=AgentConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + agent=AgentConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -128,7 +122,8 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -145,13 +140,15 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: # dataset configs # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) @@ -170,10 +167,7 @@ def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> t :param config: app model config args """ if not config.get("agent_mode"): - config["agent_mode"] = { - "enabled": False, - "tools": [] - } + config["agent_mode"] = {"enabled": False, "tools": []} if not isinstance(config["agent_mode"], dict): raise ValueError("agent_mode must be of object type") @@ -187,8 +181,9 @@ def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> t if not config["agent_mode"].get("strategy"): config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value - if config["agent_mode"]["strategy"] not in [member.value for member in - 
list(PlanningStrategy.__members__.values())]: + if config["agent_mode"]["strategy"] not in [ + member.value for member in list(PlanningStrategy.__members__.values()) + ]: raise ValueError("strategy in agent_mode must be in the specified strategy list") if not config["agent_mode"].get("tools"): @@ -210,7 +205,7 @@ def validate_agent_mode_and_set_defaults(cls, tenant_id: str, config: dict) -> t raise ValueError("enabled in agent_mode.tools must be of boolean type") if key == "dataset": - if 'id' not in tool_item: + if "id" not in tool_item: raise ValueError("id is required in dataset") try: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 53780bdfb003b2..d1564a260e2a6d 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -1,39 +1,61 @@ import logging -import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from configs import dify_config +from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser +from factories import file_factory +from models import Account, App, EndUser logger = logging.getLogger(__name__) class AgentChatAppGenerator(MessageBasedAppGenerator): - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[dict, None, None]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: bool = True, + ) -> Union[dict, Generator[dict, None, None]]: """ Generate App response. 
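
The overload pair added to AgentChatAppGenerator.generate above gives callers exact return types: stream=True (the default) resolves to a generator of dicts, stream=False to a plain dict, while the single runtime implementation keeps the broad Union signature. Reduced to a self-contained illustration of the idiom (the body is a stand-in, not the app generator's logic):

    from collections.abc import Generator
    from typing import Literal, Union, overload

    @overload
    def generate(stream: Literal[True] = True) -> Generator[dict, None, None]: ...
    @overload
    def generate(stream: Literal[False] = ...) -> dict: ...
    def generate(stream: bool = True) -> Union[dict, Generator[dict, None, None]]:
        # One runtime body; type checkers pick an overload per call site
        # from the literal value passed for `stream`.
        if stream:
            return ({"event": "message", "seq": i} for i in range(3))
        return {"event": "message"}

Note that agent chat still rejects blocking mode at runtime (the ValueError in the hunk below); the overloads only tighten what type checkers see.
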
@@ -44,59 +66,50 @@ def generate(self, app_model: App, :param stream: is stream """ if not stream: - raise ValueError('Agent Chat App does not support blocking mode') + raise ValueError("Agent Chat App does not support blocking mode") - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = AgentChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, ) else: file_objs = [] @@ -106,35 +119,35 @@ def generate(self, app_model: App, app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id - trace_manager = TraceQueueManager(app_model.id, user_id) + trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) # init application generate entity application_generate_entity = AgentChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else 
self._get_cleaned_inputs(inputs, app_config), + inputs=conversation.inputs + if conversation + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=stream, invoke_from=invoke_from, extras=extras, call_depth=0, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -143,17 +156,20 @@ def generate(self, app_model: App, invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -167,13 +183,11 @@ def generate(self, app_model: App, stream=stream, ) - return AgentChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( - self, flask_app: Flask, + self, + flask_app: Flask, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, @@ -202,18 +216,17 @@ def _generate_worker( conversation=conversation, message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index d1bbf679c567fd..45b1bf00934d35 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -15,7 +15,7 @@ from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.tools.entities.tool_entities 
import ToolRuntimeVariablePool from extensions.ext_database import db from models.model import App, Conversation, Message, MessageAgentThought @@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner): """ def run( - self, application_generate_entity: AgentChatAppGenerateEntity, + self, + application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, @@ -65,7 +66,7 @@ def run( prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -73,13 +74,10 @@ def run( # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -91,7 +89,7 @@ def run( inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -103,15 +101,15 @@ def run( app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -122,13 +120,13 @@ def run( message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -136,7 +134,7 @@ def run( app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -148,7 +146,7 @@ def run( app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # reorganize all inputs and template to prompt messages @@ -161,14 +159,14 @@ def run( inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: @@ -177,9 +175,9 @@ def run( agent_entity = app_config.agent # load tool variables - tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id, - user_id=application_generate_entity.user_id, - tenant_id=app_config.tenant_id) + tool_conversation_variables = self._load_tool_variables( + conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id + ) # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) @@ 
-187,7 +185,7 @@ def run( # init model instance model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) prompt_message, _ = self.organize_prompt_messages( app_record=app_record, @@ -238,7 +236,7 @@ def run( prompt_messages=prompt_message, variables_pool=tool_variables, db_variables=tool_conversation_variables, - model_instance=model_instance + model_instance=model_instance, ) invoke_result = runner.run( @@ -252,17 +250,21 @@ def run( invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream, - agent=True + agent=True, ) def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables: """ load tool variables from database """ - tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter( - ToolConversationVariables.conversation_id == conversation_id, - ToolConversationVariables.tenant_id == tenant_id - ).first() + tool_variables: ToolConversationVariables = ( + db.session.query(ToolConversationVariables) + .filter( + ToolConversationVariables.conversation_id == conversation_id, + ToolConversationVariables.tenant_id == tenant_id, + ) + .first() + ) if tool_variables: # save tool variables to session, so that we can update it later @@ -273,34 +275,40 @@ def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: st conversation_id=conversation_id, user_id=user_id, tenant_id=tenant_id, - variables_str='[]', + variables_str="[]", ) db.session.add(tool_variables) db.session.commit() return tool_variables - - def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool: + + def _convert_db_variables_to_tool_variables( + self, db_variables: ToolConversationVariables + ) -> ToolRuntimeVariablePool: """ convert db variables to tool variables """ - return ToolRuntimeVariablePool(**{ - 'conversation_id': db_variables.conversation_id, - 'user_id': db_variables.user_id, - 'tenant_id': db_variables.tenant_id, - 'pool': db_variables.variables - }) - - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity, - message: Message) -> LLMUsage: + return ToolRuntimeVariablePool( + **{ + "conversation_id": db_variables.conversation_id, + "user_id": db_variables.user_id, + "tenant_id": db_variables.tenant_id, + "pool": db_variables.variables, + } + ) + + def _get_usage_of_all_agent_thoughts( + self, model_config: ModelConfigWithCredentialsEntity, message: Message + ) -> LLMUsage: """ Get usage of all agent thoughts :param model_config: model config :param message: message :return: """ - agent_thoughts = (db.session.query(MessageAgentThought) - .filter(MessageAgentThought.message_id == message.id).all()) + agent_thoughts = ( + db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all() + ) all_message_tokens = 0 all_answer_tokens = 0 @@ -312,8 +320,5 @@ def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredenti model_type_instance = cast(LargeLanguageModel, model_type_instance) return model_type_instance._calc_response_usage( - model_config.model, - model_config.credentials, - all_message_tokens, - all_answer_tokens + model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens ) diff --git 
a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 118d82c495f1fe..629c309c065458 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -23,15 +23,15 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -45,14 +45,15 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -63,14 +64,14 @@ def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStrea sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -81,8 +82,9 @@ def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStrea yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. 
:param stream_response: stream response @@ -93,20 +95,20 @@ def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStr sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 1165314a7f2bd8..62e79ec444a48a 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC): _blocking_response_type: type[AppBlockingResponse] @classmethod - def convert(cls, response: Union[ - AppBlockingResponse, - Generator[AppStreamResponse, Any, None] - ], invoke_from: InvokeFrom): - if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + def convert( + cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom + ) -> dict[str, Any] | Generator[str, Any, None]: + if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: + def _generate_full_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_full_response(response): - if chunk == 'ping': - yield f'event: {chunk}\n\n' + if chunk == "ping": + yield f"event: {chunk}\n\n" else: - yield f'data: {chunk}\n\n' + yield f"data: {chunk}\n\n" return _generate_full_response() else: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_simple_response(response) else: + def _generate_simple_response() -> Generator[str, Any, None]: for chunk in cls.convert_stream_simple_response(response): - if chunk == 'ping': - yield f'event: {chunk}\n\n' + if chunk == "ping": + yield f"event: {chunk}\n\n" else: - yield f'data: {chunk}\n\n' + yield f"data: {chunk}\n\n" return _generate_simple_response() @@ -54,14 +55,16 @@ def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse @classmethod @abstractmethod - def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @abstractmethod - def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: 
Generator[AppStreamResponse, None, None] + ) -> Generator[str, None, None]: raise NotImplementedError @classmethod @@ -72,24 +75,27 @@ def _get_simple_metadata(cls, metadata: dict[str, Any]): :return: """ # show_retrieve_source - if 'retriever_resources' in metadata: - metadata['retriever_resources'] = [] - for resource in metadata['retriever_resources']: - metadata['retriever_resources'].append({ - 'segment_id': resource['segment_id'], - 'position': resource['position'], - 'document_name': resource['document_name'], - 'score': resource['score'], - 'content': resource['content'], - }) + updated_resources = [] + if "retriever_resources" in metadata: + for resource in metadata["retriever_resources"]: + updated_resources.append( + { + "segment_id": resource["segment_id"], + "position": resource["position"], + "document_name": resource["document_name"], + "score": resource["score"], + "content": resource["content"], + } + ) + metadata["retriever_resources"] = updated_resources # show annotation reply - if 'annotation_reply' in metadata: - del metadata['annotation_reply'] + if "annotation_reply" in metadata: + del metadata["annotation_reply"] # show usage - if 'usage' in metadata: - del metadata['usage'] + if "usage" in metadata: + del metadata["usage"] return metadata @@ -101,16 +107,16 @@ def _error_to_stream_response(cls, e: Exception) -> dict: :return: """ error_responses = { - ValueError: {'code': 'invalid_param', 'status': 400}, - ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400}, + ValueError: {"code": "invalid_param", "status": 400}, + ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400}, QuotaExceededError: { - 'code': 'provider_quota_exceeded', - 'message': "Your quota for Dify Hosted Model Provider has been exhausted. " - "Please go to Settings -> Model Provider to complete your own provider credentials.", - 'status': 400 + "code": "provider_quota_exceeded", + "message": "Your quota for Dify Hosted Model Provider has been exhausted. 
" + "Please go to Settings -> Model Provider to complete your own provider credentials.", + "status": 400, }, - ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400}, - InvokeError: {'code': 'completion_request_error', 'status': 400} + ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400}, + InvokeError: {"code": "completion_request_error", "status": 400}, } # Determine the response based on the type of exception @@ -120,13 +126,13 @@ def _error_to_stream_response(cls, e: Exception) -> dict: data = v if data: - data.setdefault('message', getattr(e, 'description', str(e))) + data.setdefault("message", getattr(e, "description", str(e))) else: logging.error(e) data = { - 'code': 'internal_server_error', - 'message': 'Internal Server Error, please contact support.', - 'status': 500 + "code": "internal_server_error", + "message": "Internal Server Error, please contact support.", + "status": 500, } return data diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 6f48aa23637f7a..6e6da954014958 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,56 +1,137 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional -from core.app.app_config.entities import AppConfig, VariableEntity +from core.app.app_config.entities import VariableEntityType +from core.file import File, FileUploadConfig +from factories import file_factory + +if TYPE_CHECKING: + from core.app.app_config.entities import AppConfig, VariableEntity class BaseAppGenerator: - def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]: + def _prepare_user_inputs( + self, + *, + user_inputs: Optional[Mapping[str, Any]], + app_config: "AppConfig", + ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} - filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} - return filtered_inputs - - def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): - user_input_value = inputs.get(var.name) - if var.required and not user_input_value: - raise ValueError(f'{var.name} is required in input form') - if not var.required and not user_input_value: - # TODO: should we return None here if the default value is None? 
- return var.default or '' - if ( - var.type - in ( - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.SELECT, - VariableEntity.Type.PARAGRAPH, + user_inputs = { + var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var) + for var in variables + } + user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} + # Convert files in inputs to File + entity_dictionary = {item.variable: item for item in app_config.variables} + # Convert single file to File + files_inputs = { + k: file_factory.build_from_mapping( + mapping=v, + tenant_id=app_config.tenant_id, + config=FileUploadConfig( + allowed_file_types=entity_dictionary[k].allowed_file_types, + allowed_extensions=entity_dictionary[k].allowed_file_extensions, + allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, + ), + ) + for k, v in user_inputs.items() + if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE + } + # Convert list of files to File + file_list_inputs = { + k: file_factory.build_from_mappings( + mappings=v, + tenant_id=app_config.tenant_id, + config=FileUploadConfig( + allowed_file_types=entity_dictionary[k].allowed_file_types, + allowed_extensions=entity_dictionary[k].allowed_file_extensions, + allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, + ), ) - and user_input_value - and not isinstance(user_input_value, str) + for k, v in user_inputs.items() + if isinstance(v, list) + # Skip lists that are not lists of file mappings + and all(isinstance(item, dict) for item in v) + and entity_dictionary[k].type == VariableEntityType.FILE_LIST + } + # Merge all inputs + user_inputs = {**user_inputs, **files_inputs, **file_list_inputs} + + # Check if all files are converted to File + if any(filter(lambda v: isinstance(v, dict), user_inputs.values())): + raise ValueError("Invalid input type") + if any( + isinstance(item, dict) for v in user_inputs.values() if isinstance(v, list) for item in v ): - raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") - if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): + raise ValueError("Invalid input type") + + return user_inputs + + def _validate_inputs( + self, + *, + variable_entity: "VariableEntity", + value: Any, + ): + if value is None: + if variable_entity.required: + raise ValueError(f"{variable_entity.variable} is required in input form") + return value + + if variable_entity.type in { + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, + } and not isinstance(value, str): + raise ValueError( + f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string" + ) + + if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str): # may raise ValueError if user_input_value is not a valid number try: - if '.' in user_input_value: - return float(user_input_value) + if "." 
in value: + return float(value) else: - return int(user_input_value) + return int(value) except ValueError: - raise ValueError(f"{var.name} in input form must be a valid number") - if var.type == VariableEntity.Type.SELECT: - options = var.options or [] - if user_input_value not in options: - raise ValueError(f'{var.name} in input form must be one of the following: {options}') - elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): - if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') + raise ValueError(f"{variable_entity.variable} in input form must be a valid number") + + match variable_entity.type: + case VariableEntityType.SELECT: + if value not in variable_entity.options: + raise ValueError( + f"{variable_entity.variable} in input form must be one of the following: " + f"{variable_entity.options}" + ) + case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH: + if variable_entity.max_length and len(value) > variable_entity.max_length: + raise ValueError( + f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} " + "characters" + ) + case VariableEntityType.FILE: + if not isinstance(value, dict) and not isinstance(value, File): + raise ValueError(f"{variable_entity.variable} in input form must be a file") + case VariableEntityType.FILE_LIST: + # if number of files exceeds the limit, raise ValueError + if not ( + isinstance(value, list) + and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value)) + ): + raise ValueError(f"{variable_entity.variable} in input form must be a list of files") - return user_input_value + if variable_entity.max_length and len(value) > variable_entity.max_length: + raise ValueError( + f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files" + ) + + return value def _sanitize_value(self, value: Any) -> Any: if isinstance(value, str): - return value.replace('\x00', '') + return value.replace("\x00", "") return value diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index f929a979f129de..4c4d282e99b6ae 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -24,9 +24,7 @@ class PublishFrom(Enum): class AppQueueManager: - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: if not user_id: raise ValueError("user is required") @@ -34,9 +32,10 @@ def __init__(self, task_id: str, self._user_id = user_id self._invoke_from = invoke_from - user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, - f"{user_prefix}-{self._user_id}") + user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" + redis_client.setex( + AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" + ) q = queue.Queue() @@ -66,8 +65,7 @@ def listen(self) -> Generator: # publish two messages to make sure the client can receive the stop signal # and stop listening after the stop signal processed self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), - 
PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE ) if elapsed_time // 10 > last_ping_time: @@ -88,9 +86,7 @@ def publish_error(self, e, pub_from: PublishFrom) -> None: :param pub_from: publish from :return: """ - self.publish(QueueErrorEvent( - error=e - ), pub_from) + self.publish(QueueErrorEvent(error=e), pub_from) def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ @@ -122,8 +118,8 @@ def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> N if result is None: return - user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' - if result.decode('utf-8') != f"{user_prefix}-{user_id}": + user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" + if result.decode("utf-8") != f"{user_prefix}-{user_id}": return stopped_cache_key = cls._generate_stopped_cache_key(task_id) @@ -168,10 +164,12 @@ def _check_for_sqlalchemy_models(self, data: Any): for item in data: self._check_for_sqlalchemy_models(item) else: - if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): - raise TypeError("Critical Error: Passing SQLAlchemy Model instances " - "that cause thread safety issues is not allowed.") + if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"): + raise TypeError( + "Critical Error: Passing SQLAlchemy Model instances " + "that cause thread safety issues is not allowed." + ) -class GenerateTaskStoppedException(Exception): +class GenerateTaskStoppedError(Exception): pass diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 2c5feaaaafb153..609fd03f229da8 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,6 @@ import time -from collections.abc import Generator -from typing import TYPE_CHECKING, Optional, Union +from collections.abc import Generator, Mapping +from typing import TYPE_CHECKING, Any, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -27,16 +27,19 @@ from models.model import App, AppMode, Message, MessageAnnotation if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File class AppRunner: - def get_pre_calculate_rest_tokens(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: dict[str, str], - files: list["FileVar"], - query: Optional[str] = None) -> int: + def get_pre_calculate_rest_tokens( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["File"], + query: Optional[str] = None, + ) -> int: """ Get pre calculate rest tokens :param app_record: app record @@ -49,18 +52,20 @@ def get_pre_calculate_rest_tokens(self, app_record: App, """ # Invoke model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or 
(parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -75,36 +80,39 @@ def get_pre_calculate_rest_tokens(self, app_record: App, prompt_template_entity=prompt_template_entity, inputs=inputs, files=files, - query=query + query=query, ) - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) rest_tokens = model_context_tokens - max_tokens - prompt_tokens if rest_tokens < 0: - raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " - "or shrink the max token, or switch to a llm with a larger token limit size.") + raise InvokeBadRequestError( + "Query or prefix prompt is too long, you can reduce the prefix prompt, " + "or shrink the max token, or switch to a llm with a larger token limit size." + ) return rest_tokens - def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage]): + def recalc_llm_max_tokens( + self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] + ): # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 if model_context_tokens is None: return -1 @@ -112,27 +120,28 @@ def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity, if max_tokens is None: max_tokens = 0 - prompt_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) if prompt_tokens + max_tokens > model_context_tokens: max_tokens = max(model_context_tokens - prompt_tokens, 16) for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): model_config.parameters[parameter_rule.name] = max_tokens - def organize_prompt_messages(self, app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: 
dict[str, str], - files: list["FileVar"], - query: Optional[str] = None, - context: Optional[str] = None, - memory: Optional[TokenBufferMemory] = None) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def organize_prompt_messages( + self, + app_record: App, + model_config: ModelConfigWithCredentialsEntity, + prompt_template_entity: PromptTemplateEntity, + inputs: dict[str, str], + files: list["File"], + query: Optional[str] = None, + context: Optional[str] = None, + memory: Optional[TokenBufferMemory] = None, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: """ Organize prompt messages :param context: @@ -152,60 +161,54 @@ def organize_prompt_messages(self, app_record: App, app_mode=AppMode.value_of(app_record.mode), prompt_template_entity=prompt_template_entity, inputs=inputs, - query=query if query else '', + query=query or "", files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: - memory_config = MemoryConfig( - window=MemoryConfig.WindowConfig( - enabled=False - ) - ) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) model_mode = ModelMode.value_of(model_config.mode) if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template - prompt_template = CompletionModelPromptTemplate( - text=advanced_completion_prompt_template.prompt - ) + prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt) if advanced_completion_prompt_template.role_prefix: memory_config.role_prefix = MemoryConfig.RolePrefix( user=advanced_completion_prompt_template.role_prefix.user, - assistant=advanced_completion_prompt_template.role_prefix.assistant + assistant=advanced_completion_prompt_template.role_prefix.assistant, ) else: prompt_template = [] for message in prompt_template_entity.advanced_chat_prompt_template.messages: - prompt_template.append(ChatModelMessage( - text=message.text, - role=message.role - )) + prompt_template.append(ChatModelMessage(text=message.text, role=message.role)) prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs=inputs, - query=query if query else '', + query=query or "", files=files, context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) stop = model_config.stop return prompt_messages, stop - def direct_output(self, queue_manager: AppQueueManager, - app_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list, - text: str, - stream: bool, - usage: Optional[LLMUsage] = None) -> None: + def direct_output( + self, + queue_manager: AppQueueManager, + app_generate_entity: EasyUIBasedAppGenerateEntity, + prompt_messages: list, + text: str, + stream: bool, + usage: Optional[LLMUsage] = None, + ) -> None: """ Direct output :param queue_manager: application queue manager @@ -222,17 +225,10 @@ def direct_output(self, queue_manager: AppQueueManager, chunk = LLMResultChunk( model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=token) - ) + delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)), ) - queue_manager.publish( - QueueLLMChunkEvent( - chunk=chunk - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER) index += 
1 time.sleep(0.01) @@ -242,15 +238,19 @@ def direct_output(self, queue_manager: AppQueueManager, model=app_generate_entity.model_conf.model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), - usage=usage if usage else LLMUsage.empty_usage() + usage=usage or LLMUsage.empty_usage(), ), - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], - queue_manager: AppQueueManager, - stream: bool, - agent: bool = False) -> None: + def _handle_invoke_result( + self, + invoke_result: Union[LLMResult, Generator], + queue_manager: AppQueueManager, + stream: bool, + agent: bool = False, + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -260,21 +260,13 @@ def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], :return: """ if not stream: - self._handle_invoke_result_direct( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) else: - self._handle_invoke_result_stream( - invoke_result=invoke_result, - queue_manager=queue_manager, - agent=agent - ) + self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent) - def _handle_invoke_result_direct(self, invoke_result: LLMResult, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_direct( + self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result direct :param invoke_result: invoke result @@ -285,12 +277,13 @@ def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager.publish( QueueMessageEndEvent( llm_result=invoke_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, ) - def _handle_invoke_result_stream(self, invoke_result: Generator, - queue_manager: AppQueueManager, - agent: bool) -> None: + def _handle_invoke_result_stream( + self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool + ) -> None: """ Handle invoke result :param invoke_result: invoke result @@ -300,21 +293,13 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, """ model = None prompt_messages = [] - text = '' + text = "" usage = None for result in invoke_result: if not agent: - queue_manager.publish( - QueueLLMChunkEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) else: - queue_manager.publish( - QueueAgentMessageEvent( - chunk=result - ), PublishFrom.APPLICATION_MANAGER - ) + queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) text += result.delta.message.content @@ -324,32 +309,31 @@ def _handle_invoke_result_stream(self, invoke_result: Generator, if not prompt_messages: prompt_messages = result.prompt_messages - if not usage and result.delta.usage: + if result.delta.usage: usage = result.delta.usage if not usage: usage = LLMUsage.empty_usage() llm_result = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) queue_manager.publish( QueueMessageEndEvent( llm_result=llm_result, - ), PublishFrom.APPLICATION_MANAGER + ), + PublishFrom.APPLICATION_MANAGER, 
) def moderation_for_inputs( - self, app_id: str, - tenant_id: str, - app_generate_entity: AppGenerateEntity, - inputs: dict, - query: str, - message_id: str, + self, + app_id: str, + tenant_id: str, + app_generate_entity: AppGenerateEntity, + inputs: Mapping[str, Any], + query: str, + message_id: str, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. @@ -367,14 +351,17 @@ def moderation_for_inputs( tenant_id=tenant_id, app_config=app_generate_entity.app_config, inputs=inputs, - query=query if query else '', + query=query or "", message_id=message_id, - trace_manager=app_generate_entity.trace_manager + trace_manager=app_generate_entity.trace_manager, ) - def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - queue_manager: AppQueueManager, - prompt_messages: list[PromptMessage]) -> bool: + def check_hosting_moderation( + self, + application_generate_entity: EasyUIBasedAppGenerateEntity, + queue_manager: AppQueueManager, + prompt_messages: list[PromptMessage], + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -384,8 +371,7 @@ def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGe """ hosting_moderation_feature = HostingModerationFeature() moderation_result = hosting_moderation_feature.check( - application_generate_entity=application_generate_entity, - prompt_messages=prompt_messages + application_generate_entity=application_generate_entity, prompt_messages=prompt_messages ) if moderation_result: @@ -393,18 +379,20 @@ def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGe queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, - text="I apologize for any confusion, " \ - "but I'm an AI assistant to be helpful, harmless, and honest.", - stream=application_generate_entity.stream + text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.", + stream=application_generate_entity.stream, ) return moderation_result - def fill_in_inputs_from_external_data_tools(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fill_in_inputs_from_external_data_tools( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. 
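In `_handle_invoke_result_stream` above, usage is now taken from the latest chunk that reports it (`if result.delta.usage:`) rather than only the first, with an empty-usage fallback when no chunk carries any. A minimal sketch of that accumulation loop, assuming a simplified `_Chunk` stand-in for the real `LLMResultChunk`:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _Chunk:
        """Simplified stand-in for an LLM result chunk: delta text plus optional usage."""
        content: str
        usage: Optional[dict] = None

    def accumulate(chunks: list[_Chunk]) -> tuple[str, dict]:
        text = ""
        usage: Optional[dict] = None
        for chunk in chunks:
            text += chunk.content
            if chunk.usage:  # the last non-empty usage wins
                usage = chunk.usage
        return text, usage or {}  # empty usage when nothing was reported

    assert accumulate([_Chunk("Hel"), _Chunk("lo", usage={"total_tokens": 2})]) == ("Hello", {"total_tokens": 2})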
@@ -417,18 +405,12 @@ def fill_in_inputs_from_external_data_tools(self, tenant_id: str, """ external_data_fetch_feature = ExternalDataFetch() return external_data_fetch_feature.fetch( - tenant_id=tenant_id, - app_id=app_id, - external_data_tools=external_data_tools, - inputs=inputs, - query=query + tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query ) - def query_app_annotations_to_reply(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query_app_annotations_to_reply( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -440,9 +422,5 @@ def query_app_annotations_to_reply(self, app_record: App, """ annotation_reply_feature = AnnotationReplyFeature() return annotation_reply_feature.query( - app_record=app_record, - message=message, - query=query, - user_id=user_id, - invoke_from=invoke_from + app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from ) diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index a286c349b2715b..96dc7dda79af6d 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig): """ Chatbot App Config Entity. """ + pass class ChatAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - conversation: Optional[Conversation] = None, - override_config_dict: Optional[dict] = None) -> ChatAppConfig: + def get_app_config( + cls, + app_model: App, + app_model_config: AppModelConfig, + conversation: Optional[Conversation] = None, + override_config_dict: Optional[dict] = None, + ) -> ChatAppConfig: """ Convert app model config to chat app config :param app_model: app model @@ -51,7 +55,7 @@ def get_app_config(cls, app_model: App, config_dict = app_model_config_dict.copy() else: if not override_config_dict: - raise Exception('override_config_dict is required when config_from is ARGS') + raise Exception("override_config_dict is required when config_from is ARGS") config_dict = override_config_dict @@ -63,19 +67,11 @@ def get_app_config(cls, app_model: App, app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -113,8 +109,9 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # dataset_query_variable - 
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # opening_statement @@ -123,7 +120,8 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: # suggested_questions_after_answer config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults( - config) + config + ) related_config_keys.extend(current_related_config_keys) # speech_to_text @@ -139,8 +137,9 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 5b896e28455340..e683dfef3f7f78 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,26 +1,27 @@ import logging -import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from configs import dify_config +from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db +from factories import file_factory from models.account import Account from models.model import App, EndUser @@ -28,13 +29,34 @@ class ChatAppGenerator(MessageBasedAppGenerator): + @overload def generate( - self, app_model: App, + self, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... 
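The two `@overload` stubs above let type checkers narrow `generate`'s return type from the `stream` flag: `Literal[True]` maps to a `Generator[str, None, None]`, `Literal[False]` to a `dict`, and the un-decorated `def generate` that follows carries the runtime logic. A minimal, self-contained sketch of the same pattern:

    from collections.abc import Generator
    from typing import Literal, Union, overload

    @overload
    def generate(stream: Literal[True] = True) -> Generator[str, None, None]: ...

    @overload
    def generate(stream: Literal[False] = ...) -> dict: ...

    def generate(stream: bool = True) -> Union[dict, Generator[str, None, None]]:
        if stream:
            return (token for token in ("Hello", " world"))  # streamed chunks
        return {"answer": "Hello world"}  # blocking response

Callers that pass `stream=False` get a plain `dict` statically, without any casts.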
+ + def generate( + self, + app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -44,57 +66,48 @@ def generate( :param invoke_from: invoke from source :param stream: is stream """ - if not args.get('query'): - raise ValueError('query is required') + if not args.get("query"): + raise ValueError("query is required") - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] - extras = { - "auto_generate_conversation_name": args.get('auto_generate_name', True) - } + extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)} # get conversation conversation = None - if args.get('conversation_id'): - conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user) + if args.get("conversation_id"): + conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user) # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = ChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # always enable retriever resource in debugger mode - override_model_config_dict["retriever_resource"] = { - "enabled": True - } + override_model_config_dict["retriever_resource"] = {"enabled": True} # parse files - files = args['files'] if args.get('files') else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, ) else: file_objs = [] @@ -104,33 +117,34 @@ def generate( app_model=app_model, app_model_config=app_model_config, conversation=conversation, - override_config_dict=override_model_config_dict + override_config_dict=override_model_config_dict, ) # get tracing instance - trace_manager = TraceQueueManager(app_model.id) + trace_manager = TraceQueueManager(app_id=app_model.id) # init application generate entity application_generate_entity = ChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, - inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config), + 
inputs=conversation.inputs + if conversation + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, - stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, + stream=stream, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity, conversation) + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -139,17 +153,20 @@ def generate( invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'conversation_id': conversation.id, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) worker_thread.start() @@ -163,16 +180,16 @@ def generate( stream=stream, ) - return ChatAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation_id: str, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation_id: str, + message_id: str, + ) -> None: """ Generate worker in a new thread. 
:param flask_app: Flask app @@ -194,20 +211,19 @@ def _generate_worker(self, flask_app: Flask, application_generate_entity=application_generate_entity, queue_manager=queue_manager, conversation=conversation, - message=message + message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 89a498eb3607f9..425f1ab7ef4cc6 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -11,7 +11,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Conversation, Message @@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner): Chat Application Runner """ - def run(self, application_generate_entity: ChatAppGenerateEntity, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message) -> None: + def run( + self, + application_generate_entity: ChatAppGenerateEntity, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -58,7 +61,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) memory = None @@ -66,13 +69,10 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, # get memory of conversation (read-only) model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) # organize all inputs and template to prompt messages # Include: prompt template, inputs, query(optional), files(optional) @@ -84,7 +84,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, inputs=inputs, files=files, query=query, - memory=memory + memory=memory, ) # moderation @@ -96,15 +96,15 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, 
app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -115,13 +115,13 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, message=message, query=query, user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from + invoke_from=application_generate_entity.invoke_from, ) if annotation_reply: queue_manager.publish( QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), - PublishFrom.APPLICATION_MANAGER + PublishFrom.APPLICATION_MANAGER, ) self.direct_output( @@ -129,7 +129,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=annotation_reply.content, - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -141,7 +141,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -152,7 +152,7 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_retrieval = DatasetRetrieval(application_generate_entity) @@ -181,29 +181,26 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, files=files, query=query, context=context, - memory=memory + memory=memory, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -218,7 +215,5 @@ def run(self, application_generate_entity: ChatAppGenerateEntity, # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 625e14c9c39712..0fa7af0a7fa36d 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -23,15 +23,15 @@ def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingRes :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'conversation_id': blocking_response.data.conversation_id, - 'mode': blocking_response.data.mode, - 'answer': 
blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "conversation_id": blocking_response.data.conversation_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -45,14 +45,15 @@ def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingR """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -63,14 +64,14 @@ def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStrea sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -81,8 +82,9 @@ def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStrea yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. 
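The converters above are rewritten to double-quoted keys while keeping the same shape: ping events become a bare "ping" string, everything else a JSON-encoded chunk. A hedged sketch of that generator shape, using simplified stand-in types rather than Dify's real entities:

import json
from collections.abc import Generator, Iterable
from dataclasses import dataclass


@dataclass
class PingChunk:
    pass


@dataclass
class MessageChunk:
    event: str
    message_id: str
    created_at: int
    answer: str


def convert_stream(chunks: Iterable[PingChunk | MessageChunk]) -> Generator[str, None, None]:
    for chunk in chunks:
        if isinstance(chunk, PingChunk):
            yield "ping"  # keepalive, no payload
            continue
        yield json.dumps(
            {
                "event": chunk.event,
                "message_id": chunk.message_id,
                "created_at": chunk.created_at,
                "answer": chunk.answer,
            }
        )


# usage:
# for line in convert_stream([MessageChunk("message", "m-1", 1700000000, "hi")]):
#     print(line)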
:param stream_response: stream response @@ -93,20 +95,20 @@ def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStr sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'conversation_id': chunk.conversation_id, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "conversation_id": chunk.conversation_id, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index a7711983249a33..1193c4b7a43632 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig): """ Completion App Config Entity. """ + pass class CompletionAppConfigManager(BaseAppConfigManager): @classmethod - def get_app_config(cls, app_model: App, - app_model_config: AppModelConfig, - override_config_dict: Optional[dict] = None) -> CompletionAppConfig: + def get_app_config( + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None + ) -> CompletionAppConfig: """ Convert app model config to completion app config :param app_model: app model @@ -51,19 +52,11 @@ def get_app_config(cls, app_model: App, app_model_config_from=config_from, app_model_config_id=app_model_config.id, app_model_config_dict=config_dict, - model=ModelConfigManager.convert( - config=config_dict - ), - prompt_template=PromptTemplateConfigManager.convert( - config=config_dict - ), - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=config_dict - ), - dataset=DatasetConfigManager.convert( - config=config_dict - ), - additional_features=cls.convert_features(config_dict, app_mode) + model=ModelConfigManager.convert(config=config_dict), + prompt_template=PromptTemplateConfigManager.convert(config=config_dict), + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), + dataset=DatasetConfigManager.convert(config=config_dict), + additional_features=cls.convert_features(config_dict, app_mode), ) app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert( @@ -101,8 +94,9 @@ def config_validate(cls, tenant_id: str, config: dict) -> dict: related_config_keys.extend(current_related_config_keys) # dataset_query_variable - config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode, - config) + config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults( + tenant_id, app_mode, config + ) related_config_keys.extend(current_related_config_keys) # text_to_speech @@ -114,8 +108,9 @@ def config_validate(cls, tenant_id: str, config: dict) -> 
dict: related_config_keys.extend(current_related_config_keys) # moderation validation - config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id, - config) + config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( + tenant_id, config + ) related_config_keys.extend(current_related_config_keys) related_config_keys = list(set(related_config_keys)) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index c4e1caf65a9679..22ee8b096762d3 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,28 +1,27 @@ import logging -import os import threading import uuid from collections.abc import Generator -from typing import Any, Union +from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from configs import dify_config from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser, Message +from factories import file_factory +from models import Account, App, EndUser, Message from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError @@ -30,12 +29,29 @@ class CompletionAppGenerator(MessageBasedAppGenerator): - def generate(self, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + ) -> dict: ... + + def generate( + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. 
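The new @overload declarations on generate() let type checkers tie the return type to the stream flag: Literal[True] maps to a generator, Literal[False] to a dict, while the single undecorated definition still does all the work at runtime. A minimal illustration of the same pattern, with a toy body:

from collections.abc import Generator
from typing import Literal, Union, overload


@overload
def generate(stream: Literal[True] = True) -> Generator[str, None, None]: ...


@overload
def generate(stream: Literal[False] = False) -> dict: ...


def generate(stream: bool = True) -> Union[dict, Generator[str, None, None]]:
    # Only this definition exists at runtime; the overloads are erased.
    if stream:
        return (chunk for chunk in ("a", "b"))
    return {"answer": "ab"}


# A checker now narrows generate(stream=False) to dict and
# generate() / generate(stream=True) to Generator[str, None, None].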
@@ -45,12 +61,12 @@ def generate(self, app_model: App, :param invoke_from: invoke from source :param stream: is stream """ - query = args['query'] + query = args["query"] if not isinstance(query, str): - raise ValueError('query must be a string') + raise ValueError("query must be a string") - query = query.replace('\x00', '') - inputs = args['inputs'] + query = query.replace("\x00", "") + inputs = args["inputs"] extras = {} @@ -58,41 +74,34 @@ def generate(self, app_model: App, conversation = None # get app model config - app_model_config = self._get_app_model_config( - app_model=app_model, - conversation=conversation - ) + app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation) # validate override model config override_model_config_dict = None - if args.get('model_config'): + if args.get("model_config"): if invoke_from != InvokeFrom.DEBUGGER: - raise ValueError('Only in App debug mode can override model config') + raise ValueError("Only in App debug mode can override model config") # validate config override_model_config_dict = CompletionAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=args.get('model_config') + tenant_id=app_model.tenant_id, config=args.get("model_config") ) # parse files - files = args['files'] if args.get('files') else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) + files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, ) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # get tracing instance @@ -103,21 +112,19 @@ def generate(self, app_model: App, task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), - inputs=self._get_cleaned_inputs(inputs, app_config), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, user_id=user.id, stream=stream, invoke_from=invoke_from, extras=extras, - trace_manager=trace_manager + trace_manager=trace_manager, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -126,16 +133,19 @@ def generate(self, app_model: App, invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": 
current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -149,15 +159,15 @@ def generate(self, app_model: App, stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message_id: str) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: CompletionAppGenerateEntity, + queue_manager: AppQueueManager, + message_id: str, + ) -> None: """ Generate worker in a new thread. :param flask_app: Flask app @@ -176,20 +186,19 @@ def _generate_worker(self, flask_app: Flask, runner.run( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - message=message + message=message, ) - except GenerateTaskStoppedException: + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: @@ -198,12 +207,14 @@ def _generate_worker(self, flask_app: Flask, finally: db.session.close() - def generate_more_like_this(self, app_model: App, - message_id: str, - user: Union[Account, EndUser], - invoke_from: InvokeFrom, - stream: bool = True) \ - -> Union[dict, Generator[dict, None, None]]: + def generate_more_like_this( + self, + app_model: App, + message_id: str, + user: Union[Account, EndUser], + invoke_from: InvokeFrom, + stream: bool = True, + ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. 
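Replacing the os.environ.get("DEBUG") string checks with dify_config.DEBUG moves the string-to-bool conversion into one typed settings object instead of repeating it at every call site. A sketch of that idea assuming a pydantic-settings style class (the exact shape of dify_config is not shown in this diff):

from pydantic_settings import BaseSettings


class SketchConfig(BaseSettings):
    # pydantic coerces env strings such as "true", "1", or "yes" to a bool
    # once, at load time, instead of ad-hoc .lower() == "true" checks.
    DEBUG: bool = False


config = SketchConfig()

if config.DEBUG:
    print("debug logging enabled")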
@@ -213,13 +224,17 @@ def generate_more_like_this(self, app_model: App, :param invoke_from: invoke from source :param stream: is stream """ - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -232,29 +247,26 @@ def generate_more_like_this(self, app_model: App, app_model_config = message.app_model_config override_model_config_dict = app_model_config.to_dict() - model_dict = override_model_config_dict['model'] - completion_params = model_dict.get('completion_params') - completion_params['temperature'] = 0.9 - model_dict['completion_params'] = completion_params - override_model_config_dict['model'] = model_dict + model_dict = override_model_config_dict["model"] + completion_params = model_dict.get("completion_params") + completion_params["temperature"] = 0.9 + model_dict["completion_params"] = completion_params + override_model_config_dict["model"] = model_dict # parse files - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - message.files, - file_extra_config, - user + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + config=file_extra_config, ) else: file_objs = [] # convert to app config app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - override_config_dict=override_model_config_dict + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) # init application generate entity @@ -268,14 +280,11 @@ def generate_more_like_this(self, app_model: App, user_id=user.id, stream=stream, invoke_from=invoke_from, - extras={} + extras={}, ) # init generate records - ( - conversation, - message - ) = self._init_generate_records(application_generate_entity) + (conversation, message) = self._init_generate_records(application_generate_entity) # init queue manager queue_manager = MessageBasedAppQueueManager( @@ -284,16 +293,19 @@ def generate_more_like_this(self, app_model: App, invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, app_mode=conversation.mode, - message_id=message.id + message_id=message.id, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'message_id': message.id, - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": 
current_app._get_current_object(), + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "message_id": message.id, + }, + ) worker_thread.start() @@ -307,7 +319,4 @@ def generate_more_like_this(self, app_model: App, stream=stream, ) - return CompletionAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index f0e5f9ae173c39..908d74ff539a5a 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -9,7 +9,7 @@ ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance -from core.moderation.base import ModerationException +from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from extensions.ext_database import db from models.model import App, Message @@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner): Completion Application Runner """ - def run(self, application_generate_entity: CompletionAppGenerateEntity, - queue_manager: AppQueueManager, - message: Message) -> None: + def run( + self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message + ) -> None: """ Run application :param application_generate_entity: application generate entity @@ -54,7 +54,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # organize all inputs and template to prompt messages @@ -65,7 +65,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, prompt_template_entity=app_config.prompt_template, inputs=inputs, files=files, - query=query + query=query, ) # moderation @@ -77,15 +77,15 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, app_generate_entity=application_generate_entity, inputs=inputs, query=query, - message_id=message.id + message_id=message.id, ) - except ModerationException as e: + except ModerationError as e: self.direct_output( queue_manager=queue_manager, app_generate_entity=application_generate_entity, prompt_messages=prompt_messages, text=str(e), - stream=application_generate_entity.stream + stream=application_generate_entity.stream, ) return @@ -97,7 +97,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, app_id=app_record.id, external_data_tools=external_data_tools, inputs=inputs, - query=query + query=query, ) # get context from datasets @@ -108,7 +108,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, app_record.id, message.id, application_generate_entity.user_id, - application_generate_entity.invoke_from + application_generate_entity.invoke_from, ) dataset_config = app_config.dataset @@ -126,7 +126,7 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, invoke_from=application_generate_entity.invoke_from, show_retrieve_source=app_config.additional_features.show_retrieve_source, hit_callback=hit_callback, - message_id=message.id + message_id=message.id, ) # reorganize all inputs and template to prompt messages @@ -139,29 +139,26 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, inputs=inputs, 
files=files, query=query, - context=context + context=context, ) # check hosting moderation hosting_moderation_result = self.check_hosting_moderation( application_generate_entity=application_generate_entity, queue_manager=queue_manager, - prompt_messages=prompt_messages + prompt_messages=prompt_messages, ) if hosting_moderation_result: return # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit - self.recalc_llm_max_tokens( - model_config=application_generate_entity.model_conf, - prompt_messages=prompt_messages - ) + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, - model=application_generate_entity.model_conf.model + model=application_generate_entity.model_conf.model, ) db.session.close() @@ -176,8 +173,5 @@ def run(self, application_generate_entity: CompletionAppGenerateEntity, # handle invoke result self._handle_invoke_result( - invoke_result=invoke_result, - queue_manager=queue_manager, - stream=application_generate_entity.stream + invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream ) - \ No newline at end of file diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index 14db74dbd04b95..697f0273a5673e 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -23,14 +23,14 @@ def convert_blocking_full_response(cls, blocking_response: CompletionAppBlocking :return: """ response = { - 'event': 'message', - 'task_id': blocking_response.task_id, - 'id': blocking_response.data.id, - 'message_id': blocking_response.data.message_id, - 'mode': blocking_response.data.mode, - 'answer': blocking_response.data.answer, - 'metadata': blocking_response.data.metadata, - 'created_at': blocking_response.data.created_at + "event": "message", + "task_id": blocking_response.task_id, + "id": blocking_response.data.id, + "message_id": blocking_response.data.message_id, + "mode": blocking_response.data.mode, + "answer": blocking_response.data.answer, + "metadata": blocking_response.data.metadata, + "created_at": blocking_response.data.created_at, } return response @@ -44,14 +44,15 @@ def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlocki """ response = cls.convert_blocking_full_response(blocking_response) - metadata = response.get('metadata', {}) - response['metadata'] = cls._get_simple_metadata(metadata) + metadata = response.get("metadata", {}) + response["metadata"] = cls._get_simple_metadata(metadata) return response @classmethod - def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. 
:param stream_response: stream response @@ -62,13 +63,13 @@ def convert_stream_full_response(cls, stream_response: Generator[CompletionAppSt sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -79,8 +80,9 @@ def convert_stream_full_response(cls, stream_response: Generator[CompletionAppSt yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[CompletionAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. :param stream_response: stream response @@ -91,19 +93,19 @@ def convert_stream_simple_response(cls, stream_response: Generator[CompletionApp sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'message_id': chunk.message_id, - 'created_at': chunk.created_at + "event": sub_stream_response.event.value, + "message_id": chunk.message_id, + "created_at": chunk.created_at, } if isinstance(sub_stream_response, MessageEndStreamResponse): sub_stream_response_dict = sub_stream_response.to_dict() - metadata = sub_stream_response_dict.get('metadata', {}) - sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata) + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) if isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 49ef7b7b4010eb..bae64368e3f384 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,13 +1,14 @@ import json import logging from collections.abc import Generator +from datetime import datetime, timezone from typing import Optional, Union from sqlalchemy import and_ from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -25,7 +26,8 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db -from models.account import Account +from models import Account +from models.enums import CreatedByRole from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from 
services.errors.conversation import ConversationCompletedError, ConversationNotExistsError @@ -34,13 +36,13 @@ class MessageBasedAppGenerator(BaseAppGenerator): - def _handle_response( - self, application_generate_entity: Union[ + self, + application_generate_entity: Union[ ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity + AdvancedChatAppGenerateEntity, ], queue_manager: AppQueueManager, conversation: Conversation, @@ -50,7 +52,7 @@ def _handle_response( ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Handle response. @@ -69,24 +71,25 @@ def _handle_response( conversation=conversation, message=message, user=user, - stream=stream + stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e - def _get_conversation_by_user(self, app_model: App, conversation_id: str, - user: Union[Account, EndUser]) -> Conversation: + def _get_conversation_by_user( + self, app_model: App, conversation_id: str, user: Union[Account, EndUser] + ) -> Conversation: conversation_filter = [ Conversation.id == conversation_id, Conversation.app_id == app_model.id, - Conversation.status == 'normal' + Conversation.status == "normal", ] if isinstance(user, Account): @@ -99,19 +102,18 @@ def _get_conversation_by_user(self, app_model: App, conversation_id: str, if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() return conversation - def _get_app_model_config(self, app_model: App, - conversation: Optional[Conversation] = None) \ - -> AppModelConfig: + def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: if conversation: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) + .first() + ) if not app_model_config: raise AppModelConfigBrokenError() @@ -126,15 +128,16 @@ def _get_app_model_config(self, app_model: App, return app_model_config - def _init_generate_records(self, - application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity - ], - conversation: Optional[Conversation] = None) \ - -> tuple[Conversation, Message]: + def _init_generate_records( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, + CompletionAppGenerateEntity, + AgentChatAppGenerateEntity, + AdvancedChatAppGenerateEntity, + ], + conversation: Optional[Conversation] = None, + ) -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity @@ -146,11 +149,11 @@ def _init_generate_records(self, # get from source end_user_id = None account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - from_source = 'api' + if 
application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + from_source = "api" end_user_id = application_generate_entity.user_id else: - from_source = 'console' + from_source = "console" account_id = application_generate_entity.user_id if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): @@ -163,8 +166,11 @@ def _init_generate_records(self, model_provider = application_generate_entity.model_conf.provider model_id = application_generate_entity.model_conf.model override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \ - and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]: + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in { + AppMode.AGENT_CHAT, + AppMode.CHAT, + AppMode.COMPLETION, + }: override_model_configs = app_config.app_model_config_dict # get conversation introduction @@ -178,12 +184,12 @@ def _init_generate_records(self, model_id=model_id, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=app_config.app_mode.value, - name='New conversation', + name="New conversation", inputs=application_generate_entity.inputs, introduction=introduction, system_instruction="", system_instruction_tokens=0, - status='normal', + status="normal", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, @@ -193,6 +199,9 @@ def _init_generate_records(self, db.session.add(conversation) db.session.commit() db.session.refresh(conversation) + else: + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + db.session.commit() message = Message( app_id=app_config.app_id, @@ -210,13 +219,14 @@ def _init_generate_records(self, answer_tokens=0, answer_unit_price=0, answer_price_unit=0, + parent_message_id=getattr(application_generate_entity, "parent_message_id", None), provider_response_latency=0, total_price=0, - currency='USD', + currency="USD", invoke_from=application_generate_entity.invoke_from.value, from_source=from_source, from_end_user_id=end_user_id, - from_account_id=account_id + from_account_id=account_id, ) db.session.add(message) @@ -226,13 +236,13 @@ def _init_generate_records(self, for file in application_generate_entity.files: message_file = MessageFile( message_id=message.id, - type=file.type.value, - transfer_method=file.transfer_method.value, - belongs_to='user', - url=file.url, + type=file.type, + transfer_method=file.transfer_method, + belongs_to="user", + url=file.remote_url, upload_file_id=file.related_id, - created_by_role=('account' if account_id else 'end_user'), - created_by=account_id or end_user_id, + created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER), + created_by=account_id or end_user_id or "", ) db.session.add(message_file) db.session.commit() @@ -265,11 +275,7 @@ def _get_conversation(self, conversation_id: str): :param conversation_id: conversation id :return: conversation """ - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if not conversation: raise ConversationNotExistsError() @@ -282,10 +288,6 @@ def _get_message(self, message_id: str) -> Message: :param message_id: message id :return: message """ - message = ( - db.session.query(Message) - .filter(Message.id == message_id) - 
.first() - ) + message = db.session.query(Message).filter(Message.id == message_id).first() return message diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index f4ff44dddac9ef..363c3c82bbc24e 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -12,12 +12,9 @@ class MessageBasedAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - conversation_id: str, - app_mode: str, - message_id: str) -> None: + def __init__( + self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str + ) -> None: super().__init__(task_id, user_id, invoke_from) self._conversation_id = str(conversation_id) @@ -30,7 +27,7 @@ def construct_queue_message(self, event: AppQueueEvent) -> QueueMessage: message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: @@ -45,17 +42,15 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: message_id=self._message_id, conversation_id=self._conversation_id, app_mode=self._app_mode, - event=event + event=event, ) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueAdvancedChatMessageEndEvent): + if isinstance( + event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() - + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 36d3696d601da4..b0aa21c7317b65 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig): """ Workflow App Config Entity. 
""" + pass @@ -26,13 +27,9 @@ def get_app_config(cls, app_model: App, workflow: Workflow) -> WorkflowAppConfig app_id=app_model.id, app_mode=app_mode, workflow_id=workflow.id, - sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert( - config=features_dict - ), - variables=WorkflowVariablesConfigManager.convert( - workflow=workflow - ), - additional_features=cls.convert_features(features_dict, app_mode) + sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict), + variables=WorkflowVariablesConfigManager.convert(workflow=workflow), + additional_features=cls.convert_features(features_dict, app_mode), ) return app_config @@ -49,10 +46,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: related_config_keys = [] # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, - is_vision=False - ) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) related_config_keys.extend(current_related_config_keys) # text_to_speech @@ -61,9 +55,7 @@ def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: # moderation validation config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults( - tenant_id=tenant_id, - config=config, - only_structure_validate=only_structure_validate + tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate ) related_config_keys.extend(current_related_config_keys) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index df40aec154a856..a0080ece2016a5 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -1,18 +1,18 @@ import contextvars import logging -import os import threading import uuid -from collections.abc import Generator -from typing import Union +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError import contexts +from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner @@ -20,74 +20,90 @@ from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse -from core.file.message_file_parser import MessageFileParser from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db -from models.account import Account -from models.model import App, EndUser -from models.workflow import Workflow +from factories import file_factory +from models import Account, App, EndUser, Workflow 
logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): + @overload def generate( - self, app_model: App, + self, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], args: dict, invoke_from: InvokeFrom, + stream: Literal[True] = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> Generator[str, None, None]: ... + + @overload + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, + ) -> dict: ... + + def generate( + self, + app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None, ): - """ - Generate App response. - - :param app_model: App - :param workflow: Workflow - :param user: account or end user - :param args: request args - :param invoke_from: invoke from source - :param stream: is stream - :param call_depth: call depth - """ - inputs = args['inputs'] + files: Sequence[Mapping[str, Any]] = args.get("files") or [] # parse files - files = args['files'] if args.get('files') else [] - message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = message_file_parser.validate_and_transform_files_arg( - files, - file_extra_config, - user - ) - else: - file_objs = [] + system_files = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + ) # convert to app config app_config = WorkflowAppConfigManager.get_app_config( app_model=app_model, - workflow=workflow + workflow=workflow, ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id - trace_manager = TraceQueueManager(app_model.id, user_id) + trace_manager = TraceQueueManager( + app_id=app_model.id, + user_id=user.id if isinstance(user, Account) else user.session_id, + ) + inputs: Mapping[str, Any] = args["inputs"] + workflow_run_id = str(uuid.uuid4()) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - inputs=self._get_cleaned_inputs(inputs, app_config), - files=file_objs, + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), + files=system_files, user_id=user.id, stream=stream, invoke_from=invoke_from, call_depth=call_depth, - trace_manager=trace_manager + trace_manager=trace_manager, + workflow_run_id=workflow_run_id, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -98,16 +114,20 @@ def generate( application_generate_entity=application_generate_entity, invoke_from=invoke_from, stream=stream, + workflow_thread_pool_id=workflow_thread_pool_id, ) def _generate( - self, app_model: App, + self, + *, + app_model: App, workflow: Workflow, user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + workflow_thread_pool_id: Optional[str] = None, + ) -> dict[str, Any] | Generator[str, None, None]: """ Generate App response. 
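_generate() above also gains a bare * in its signature, making every argument keyword-only; with this many similarly typed parameters that rules out silent positional mix-ups. The pattern in isolation, with hypothetical names:

def _generate(*, app_model: str, invoke_from: str, stream: bool = True) -> str:
    return f"{app_model}/{invoke_from}/stream={stream}"


_generate(app_model="demo-app", invoke_from="debugger")  # OK
# _generate("demo-app", "debugger")  # TypeError: takes 0 positional arguments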
@@ -117,22 +137,27 @@ def _generate( :param application_generate_entity: application generate entity :param invoke_from: invoke from source :param stream: is stream + :param workflow_thread_pool_id: workflow thread pool id """ # init queue manager queue_manager = WorkflowAppQueueManager( task_id=application_generate_entity.task_id, user_id=application_generate_entity.user_id, invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode + app_mode=app_model.mode, ) # new thread - worker_thread = threading.Thread(target=self._generate_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'application_generate_entity': application_generate_entity, - 'queue_manager': queue_manager, - 'context': contextvars.copy_context() - }) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": contextvars.copy_context(), + "workflow_thread_pool_id": workflow_thread_pool_id, + }, + ) worker_thread.start() @@ -145,17 +170,11 @@ def _generate( stream=stream, ) - return WorkflowAppGenerateResponseConverter.convert( - response=response, - invoke_from=invoke_from - ) + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) - def single_iteration_generate(self, app_model: App, - workflow: Workflow, - node_id: str, - user: Account, - args: dict, - stream: bool = True): + def single_iteration_generate( + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True + ) -> dict[str, Any] | Generator[str, Any, None]: """ Generate App response. @@ -167,20 +186,13 @@ def single_iteration_generate(self, app_model: App, :param stream: is stream """ if not node_id: - raise ValueError('node_id is required') + raise ValueError("node_id is required") - if args.get('inputs') is None: - raise ValueError('inputs is required') - - extras = { - "auto_generate_conversation_name": False - } + if args.get("inputs") is None: + raise ValueError("inputs is required") # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) # init application generate entity application_generate_entity = WorkflowAppGenerateEntity( @@ -191,11 +203,10 @@ def single_iteration_generate(self, app_model: App, user_id=user.id, stream=stream, invoke_from=InvokeFrom.DEBUGGER, - extras=extras, + extras={"auto_generate_conversation_name": False}, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( - node_id=node_id, - inputs=args['inputs'] - ) + node_id=node_id, inputs=args["inputs"] + ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) @@ -205,18 +216,23 @@ def single_iteration_generate(self, app_model: App, user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - stream=stream + stream=stream, ) - def _generate_worker(self, flask_app: Flask, - application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager, - context: contextvars.Context) -> None: + def _generate_worker( + self, + flask_app: Flask, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None, + ) -> 
None: """ Generate worker in a new thread. :param flask_app: Flask app :param application_generate_entity: application generate entity :param queue_manager: queue manager + :param workflow_thread_pool_id: workflow thread pool id :return: """ for var, val in context.items(): @@ -224,50 +240,40 @@ def _generate_worker(self, flask_app: Flask, with flask_app.app_context(): try: # workflow app - runner = WorkflowAppRunner() - if application_generate_entity.single_iteration_run: - single_iteration_run = application_generate_entity.single_iteration_run - runner.single_iteration_run( - app_id=application_generate_entity.app_config.app_id, - workflow_id=application_generate_entity.app_config.workflow_id, - queue_manager=queue_manager, - inputs=single_iteration_run.inputs, - node_id=single_iteration_run.node_id, - user_id=application_generate_entity.user_id - ) - else: - runner.run( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager - ) - except GenerateTaskStoppedException: + runner = WorkflowAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id, + ) + + runner.run() + except GenerateTaskStoppedError: pass except InvokeAuthorizationError: queue_manager.publish_error( - InvokeAuthorizationError('Incorrect API key provided'), - PublishFrom.APPLICATION_MANAGER + InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER ) except ValidationError as e: logger.exception("Validation Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except (ValueError, InvokeError) as e: - if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + if dify_config.DEBUG: logger.exception("Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) except Exception as e: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() - def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool = False) -> Union[ - WorkflowAppBlockingResponse, - Generator[WorkflowAppStreamResponse, None, None] - ]: + def _handle_response( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool = False, + ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ Handle response. 
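The worker hand-off above snapshots the caller's ContextVars with contextvars.copy_context() and replays them inside the thread (for var, val in context.items(): var.set(val)), because plain threads do not inherit context on their own. A runnable sketch of just that mechanism:

import contextvars
import threading

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id")


def worker(context: contextvars.Context) -> None:
    # Replay the snapshot so tenant_id is visible in this thread.
    for var, val in context.items():
        var.set(val)
    print(tenant_id.get())  # -> t-123


tenant_id.set("t-123")
thread = threading.Thread(target=worker, kwargs={"context": contextvars.copy_context()})
thread.start()
thread.join()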
:param application_generate_entity: application generate entity @@ -283,14 +289,14 @@ def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntit workflow=workflow, queue_manager=queue_manager, user=user, - stream=stream + stream=stream, ) try: return generate_task_pipeline.process() except ValueError as e: if e.args[0] == "I/O operation on closed file.": # ignore this error - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() else: logger.exception(e) raise e diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index f448138b53c0c2..76371f800ba1e5 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,4 +1,4 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, @@ -12,10 +12,7 @@ class WorkflowAppQueueManager(AppQueueManager): - def __init__(self, task_id: str, - user_id: str, - invoke_from: InvokeFrom, - app_mode: str) -> None: + def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: super().__init__(task_id, user_id, invoke_from) self._app_mode = app_mode @@ -27,20 +24,19 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: :param pub_from: :return: """ - message = WorkflowQueueMessage( - task_id=self._task_id, - app_mode=self._app_mode, - event=event - ) + message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event) self._q.put(message) - if isinstance(event, QueueStopEvent - | QueueErrorEvent - | QueueMessageEndEvent - | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent): + if isinstance( + event, + QueueStopEvent + | QueueErrorEvent + | QueueMessageEndEvent + | QueueWorkflowSucceededEvent + | QueueWorkflowFailedEvent, + ): self.stop_listen() if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): - raise GenerateTaskStoppedException() + raise GenerateTaskStoppedError() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 994919391e7ed5..faefcb0ed50629 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,132 +1,130 @@ import logging -import os from typing import Optional, cast +from configs import dify_config from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig -from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback -from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, ) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback +from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable -from core.workflow.nodes.base_node import UserFrom -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.enums import SystemVariableKey +from 
core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from models.enums import UserFrom from models.model import App, EndUser -from models.workflow import Workflow +from models.workflow import WorkflowType logger = logging.getLogger(__name__) -class WorkflowAppRunner: +class WorkflowAppRunner(WorkflowBasedAppRunner): """ Workflow Application Runner """ - def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None: + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None, + ) -> None: + """ + :param application_generate_entity: application generate entity + :param queue_manager: application queue manager + :param workflow_thread_pool_id: workflow thread pool id + """ + self.application_generate_entity = application_generate_entity + self.queue_manager = queue_manager + self.workflow_thread_pool_id = workflow_thread_pool_id + + def run(self) -> None: """ Run application :param application_generate_entity: application generate entity :param queue_manager: application queue manager :return: """ - app_config = application_generate_entity.app_config + app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) user_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: - end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first() + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: + end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id else: - user_id = application_generate_entity.user_id + user_id = self.application_generate_entity.user_id app_record = db.session.query(App).filter(App.id == app_config.app_id).first() if not app_record: - raise ValueError('App not found') + raise ValueError("App not found") workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) if not workflow: - raise ValueError('Workflow not initialized') - - inputs = application_generate_entity.inputs - files = application_generate_entity.files + raise ValueError("Workflow not initialized") db.session.close() - workflow_callbacks: list[WorkflowCallback] = [ - WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow) - ] - - if bool(os.environ.get('DEBUG', 'False').lower() == 'true'): + workflow_callbacks: list[WorkflowCallback] = [] + if dify_config.DEBUG: workflow_callbacks.append(WorkflowLoggingCallback()) - # Create a variable pool. 
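The hunk above also swaps ad-hoc `os.environ.get("DEBUG")` string checks for a typed `dify_config.DEBUG` flag. A hypothetical sketch of why parsing the environment once into a typed settings object is preferable (the `Config` class here is illustrative, not Dify's real `configs.dify_config`):

```python
# Sketch: centralize env parsing in one typed object instead of per-call-site string checks.
import os
from dataclasses import dataclass

# Before: every call site re-parses the raw string, and "True"/"1"/"yes" disagree.
debug_old = os.environ.get("DEBUG", "False").lower() == "true"

@dataclass(frozen=True)
class Config:
    DEBUG: bool = False

    @classmethod
    def from_env(cls) -> "Config":
        # parse once, in one place, with one truthiness rule
        return cls(DEBUG=os.environ.get("DEBUG", "false").lower() in {"1", "true", "yes"})

config = Config.from_env()
if config.DEBUG:
    print("debug callbacks enabled")
```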
- system_inputs = { - SystemVariable.FILES: files, - SystemVariable.USER_ID: user_id, - } - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - ) + # if only single iteration run is requested + if self.application_generate_entity.single_iteration_run: + graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( + workflow=workflow, + node_id=self.application_generate_entity.single_iteration_run.node_id, + user_inputs=self.application_generate_entity.single_iteration_run.inputs, + ) + else: + inputs = self.application_generate_entity.inputs + files = self.application_generate_entity.files + + # Create a variable pool. + system_inputs = { + SystemVariableKey.FILES: files, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.APP_ID: app_config.app_id, + SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, + SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, + } + + variable_pool = VariablePool( + system_variables=system_inputs, + user_inputs=inputs, + environment_variables=workflow.environment_variables, + conversation_variables=[], + ) + + # init graph + graph = self._init_graph(graph_config=workflow.graph_dict) # RUN WORKFLOW - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.run_workflow( - workflow=workflow, - user_id=application_generate_entity.user_id, - user_from=UserFrom.ACCOUNT - if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] - else UserFrom.END_USER, - invoke_from=application_generate_entity.invoke_from, - callbacks=workflow_callbacks, - call_depth=application_generate_entity.call_depth, + workflow_entry = WorkflowEntry( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + workflow_type=WorkflowType.value_of(workflow.type), + graph=graph, + graph_config=workflow.graph_dict, + user_id=self.application_generate_entity.user_id, + user_from=( + UserFrom.ACCOUNT + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else UserFrom.END_USER + ), + invoke_from=self.application_generate_entity.invoke_from, + call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, + thread_pool_id=self.workflow_thread_pool_id, ) - def single_iteration_run( - self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str - ) -> None: - """ - Single iteration run - """ - app_record = db.session.query(App).filter(App.id == app_id).first() - if not app_record: - raise ValueError('App not found') - - if not app_record.workflow_id: - raise ValueError('Workflow not initialized') - - workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id) - if not workflow: - raise ValueError('Workflow not initialized') - - workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)] - - workflow_engine_manager = WorkflowEngineManager() - workflow_engine_manager.single_step_run_iteration_workflow_node( - workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks - ) - - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id ==
app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) + generator = workflow_entry.run(callbacks=workflow_callbacks) - # return workflow - return workflow + for event in generator: + self._handle_event(workflow_entry, event) diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index 88bde58ba049ba..08d00ee1805aa2 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -35,8 +35,9 @@ def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlocking return cls.convert_blocking_full_response(blocking_response) @classmethod - def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_full_response( + cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream full response. :param stream_response: stream response @@ -47,12 +48,12 @@ def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStre sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): @@ -63,8 +64,9 @@ def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStre yield json.dumps(response_chunk) @classmethod - def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \ - -> Generator[str, None, None]: + def convert_stream_simple_response( + cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] + ) -> Generator[str, None, None]: """ Convert stream simple response. 
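The converter methods in this file wrap each stream event in a dict with a stable `event` key, pass `ping` keep-alives through as bare text, and emit one JSON document per chunk. Reduced to a hedged, standalone sketch that uses plain dicts in place of the real `WorkflowAppStreamResponse` objects:

```python
# Sketch of the stream-conversion pattern: one JSON document per event, pings passed through.
import json
from collections.abc import Generator, Iterable

def convert_stream(chunks: Iterable[dict]) -> Generator[str, None, None]:
    for chunk in chunks:
        if chunk.get("event") == "ping":
            yield "ping"  # keep-alives are emitted as bare text, not JSON
            continue
        response_chunk = {
            "event": chunk["event"],
            "workflow_run_id": chunk.get("workflow_run_id"),
        }
        response_chunk.update(chunk.get("data", {}))  # flatten event payload into the envelope
        yield json.dumps(response_chunk)

for line in convert_stream([{"event": "ping"}, {"event": "text_chunk", "data": {"text": "hi"}}]):
    print(line)
```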
:param stream_response: stream response @@ -75,12 +77,12 @@ def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppSt sub_stream_response = chunk.stream_response if isinstance(sub_stream_response, PingStreamResponse): - yield 'ping' + yield "ping" continue response_chunk = { - 'event': sub_stream_response.event.value, - 'workflow_run_id': chunk.workflow_run_id, + "event": sub_stream_response.event.value, + "workflow_run_id": chunk.workflow_run_id, } if isinstance(sub_stream_response, ErrorStreamResponse): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 5022eb0438d13b..aaa4824fe8421c 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -15,10 +15,13 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, - QueueMessageReplaceEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, QueuePingEvent, QueueStopEvent, QueueTextChunkEvent, @@ -32,19 +35,16 @@ MessageAudioStreamResponse, StreamResponse, TextChunkStreamResponse, - TextReplaceStreamResponse, WorkflowAppBlockingResponse, WorkflowAppStreamResponse, WorkflowFinishStreamResponse, - WorkflowStreamGenerateNodes, + WorkflowStartStreamResponse, WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.node_entities import NodeType -from core.workflow.enums import SystemVariable -from core.workflow.nodes.end.end_node import EndNode +from core.workflow.enums import SystemVariableKey from extensions.ext_database import db from models.account import Account from models.model import EndUser @@ -54,6 +54,7 @@ WorkflowAppLogCreatedFrom, WorkflowNodeExecution, WorkflowRun, + WorkflowRunStatus, ) logger = logging.getLogger(__name__) @@ -63,18 +64,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa """ WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _workflow: Workflow _user: Union[Account, EndUser] _task_state: WorkflowTaskState _application_generate_entity: WorkflowAppGenerateEntity - _workflow_system_variables: dict[SystemVariable, Any] - _iteration_nested_relations: dict[str, list[str]] - - def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] + + def __init__( + self, + application_generate_entity: WorkflowAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. 
:param application_generate_entity: application generate entity @@ -92,15 +97,15 @@ def __init__(self, application_generate_entity: WorkflowAppGenerateEntity, self._workflow = workflow self._workflow_system_variables = { - SystemVariable.FILES: application_generate_entity.files, - SystemVariable.USER_ID: user_id + SystemVariableKey.FILES: application_generate_entity.files, + SystemVariableKey.USER_ID: user_id, + SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, + SystemVariableKey.WORKFLOW_ID: workflow.id, + SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, } - self._task_state = WorkflowTaskState( - iteration_nested_node_ids=[] - ) - self._stream_generate_nodes = self._get_stream_generate_nodes() - self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict) + self._task_state = WorkflowTaskState() + self._wip_workflow_node_executions = {} def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -111,16 +116,13 @@ def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStr db.session.refresh(self._user) db.session.close() - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ - -> WorkflowAppBlockingResponse: + def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse: """ To blocking response. 
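`process()` above hands one generator to either a blocking or a streaming consumer: streaming returns the generator as-is, while blocking drains it, keeps only the terminal event, and raises if the queue stops without one. A simplified sketch of that split, using plain strings in place of `StreamResponse` objects:

```python
# Sketch: one event generator, two consumption modes (blocking vs. streaming).
from collections.abc import Iterable
from typing import Union

def to_blocking(events: Iterable[str]) -> str:
    final = None
    for ev in events:
        if ev.startswith("finish:"):       # keep only the terminal event
            final = ev.removeprefix("finish:")
    if final is None:
        raise RuntimeError("Queue listening stopped unexpectedly.")
    return final

def process(events: Iterable[str], stream: bool) -> Union[str, Iterable[str]]:
    # streaming callers iterate lazily; blocking callers get the drained final result
    return events if stream else to_blocking(events)

print(process(iter(["chunk:a", "finish:done"]), stream=False))  # -> "done"
```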
:return: @@ -129,66 +131,69 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, WorkflowFinishStreamResponse): - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - response = WorkflowAppBlockingResponse( task_id=self._application_generate_entity.task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=stream_response.data.id, data=WorkflowAppBlockingResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - status=workflow_run.status, - outputs=workflow_run.outputs_dict, - error=workflow_run.error, - elapsed_time=workflow_run.elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(workflow_run.finished_at.timestamp()) - ) + id=stream_response.data.id, + workflow_id=stream_response.data.workflow_id, + status=stream_response.data.status, + outputs=stream_response.data.outputs, + error=stream_response.data.error, + elapsed_time=stream_response.data.elapsed_time, + total_tokens=stream_response.data.total_tokens, + total_steps=stream_response.data.total_steps, + created_at=int(stream_response.data.created_at), + finished_at=int(stream_response.data.finished_at), + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[WorkflowAppStreamResponse, None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[WorkflowAppStreamResponse, None, None]: """ To stream response. 
:return: """ + workflow_run_id = None for stream_response in generator: - yield WorkflowAppStreamResponse( - workflow_run_id=self._task_state.workflow_run_id, - stream_response=stream_response - ) + if isinstance(stream_response, WorkflowStartStreamResponse): + workflow_run_id = stream_response.workflow_run_id + + yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - - publisher = None + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: + tts_publisher = None task_id = self._application_generate_entity.task_id tenant_id = self._application_generate_entity.app_config.tenant_id features_dict = self._workflow.features_dict - if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ - 'text_to_speech'].get('autoPlay') == 'enabled': - publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) - for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + if ( + features_dict.get("text_to_speech") + and features_dict["text_to_speech"].get("enabled") + and features_dict["text_to_speech"].get("autoPlay") == "enabled" + ): + tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice")) + + for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id=task_id) + audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -198,9 +203,9 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan start_listener_time = time.time() while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: - if not publisher: + if not tts_publisher: break - audio_trunk = publisher.checkAndGetAudio() + audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -211,107 +216,179 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan else: yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) except Exception as e: - logger.error(e) + logger.exception(e) break - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) - + if tts_publisher: + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, ) -> Generator[StreamResponse, None, None]: """ Process stream response. 
:return: """ - for message in self._queue_manager.listen(): - if publisher: - publisher.publish(message=message) - event = message.event + graph_runtime_state = None + workflow_run = None + + for queue_message in self._queue_manager.listen(): + event = queue_message.event - if isinstance(event, QueueErrorEvent): + if isinstance(event, QueuePingEvent): + yield self._ping_stream_response() + elif isinstance(event, QueueErrorEvent): err = self._handle_error(event) yield self._error_to_stream_response(err) break elif isinstance(event, QueueWorkflowStartedEvent): - workflow_run = self._handle_workflow_start() + # override graph runtime state + graph_runtime_state = event.graph_runtime_state + + # init workflow run + workflow_run = self._handle_workflow_run_start() yield self._workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, QueueNodeStartedEvent): - workflow_node_execution = self._handle_node_start(event) + if not workflow_run: + raise Exception("Workflow run not initialized.") + + workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) - # search stream_generate_routes if node id is answer start at node - if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes: - self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id] + response = self._workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) - # generate stream outputs when node started - yield from self._generate_stream_outputs_when_node_started() + if response: + yield response + elif isinstance(event, QueueNodeSucceededEvent): + workflow_node_execution = self._handle_workflow_node_execution_success(event) - yield self._workflow_node_start_to_stream_response( + response = self._workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, ) - elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent): - workflow_node_execution = self._handle_node_finished(event) - yield self._workflow_node_finish_to_stream_response( + if response: + yield response + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): + workflow_node_execution = self._handle_workflow_node_execution_failed(event) + + response = self._workflow_node_finish_to_stream_response( + event=event, task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + 
) + elif isinstance(event, QueueIterationStartEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) + elif isinstance(event, QueueIterationNextEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") - if isinstance(event, QueueNodeFailedEvent): - yield from self._handle_iteration_exception( - task_id=self._application_generate_entity.task_id, - error=f'Child node failed: {event.error}' - ) - elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent): - if isinstance(event, QueueIterationNextEvent): - # clear ran node execution infos of current iteration - iteration_relations = self._iteration_nested_relations.get(event.node_id) - if iteration_relations: - for node_id in iteration_relations: - self._task_state.ran_node_execution_infos.pop(node_id, None) - - yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event) - self._handle_iteration_operation(event) - elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent): - workflow_run = self._handle_workflow_finished( - event, trace_manager=trace_manager + yield self._workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueIterationCompletedEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + yield self._workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event + ) + elif isinstance(event, QueueWorkflowSucceededEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, ) # save workflow app log self._save_workflow_app_log(workflow_run) yield self._workflow_finish_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_failed( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + ) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) elif isinstance(event, 
QueueTextChunkEvent): delta_text = event.text if delta_text is None: continue - if not self._is_stream_out_support( - event=event - ): - continue + # only publish tts message at text chunk streaming + if tts_publisher: + tts_publisher.publish(message=queue_message) self._task_state.answer += delta_text - yield self._text_chunk_to_stream_response(delta_text) - elif isinstance(event, QueueMessageReplaceEvent): - yield self._text_replace_to_stream_response(event.text) - elif isinstance(event, QueuePingEvent): - yield self._ping_stream_response() + yield self._text_chunk_to_stream_response( + delta_text, from_variable_selector=event.from_variable_selector + ) else: continue - if publisher: - publisher.publish(None) - + if tts_publisher: + tts_publisher.publish(None) def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ @@ -329,20 +406,22 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: # not save log for debugging return - workflow_app_log = WorkflowAppLog( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - workflow_run_id=workflow_run.id, - created_from=created_from.value, - created_by_role=('account' if isinstance(self._user, Account) else 'end_user'), - created_by=self._user.id, - ) + workflow_app_log = WorkflowAppLog() + workflow_app_log.tenant_id = workflow_run.tenant_id + workflow_app_log.app_id = workflow_run.app_id + workflow_app_log.workflow_id = workflow_run.workflow_id + workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.created_from = created_from.value + workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" + workflow_app_log.created_by = self._user.id + db.session.add(workflow_app_log) db.session.commit() db.session.close() - def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: + def _text_chunk_to_stream_response( + self, text: str, from_variable_selector: Optional[list[str]] = None + ) -> TextChunkStreamResponse: """ Handle completed event. :param text: text @@ -350,184 +429,7 @@ def _text_chunk_to_stream_response(self, text: str) -> TextChunkStreamResponse: """ response = TextChunkStreamResponse( task_id=self._application_generate_entity.task_id, - data=TextChunkStreamResponse.Data(text=text) + data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector), ) return response - - def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse: - """ - Text replace to stream response. - :param text: text - :return: - """ - return TextReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - text=TextReplaceStreamResponse.Data(text=text) - ) - - def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]: - """ - Get stream generate nodes. 
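The event loop spanning the last few hunks dispatches on event type with `isinstance` chains, ignores unknown events, and stops listening on terminal events. Stripped of the workflow details, the shape is roughly as follows (stand-in event classes, not the real queue entities):

```python
# Sketch of isinstance-based event dispatch with a terminal event that ends the loop.
from dataclasses import dataclass

@dataclass
class PingEvent: ...

@dataclass
class TextChunkEvent:
    text: str

@dataclass
class StopEvent:
    reason: str

def dispatch(events):
    for event in events:
        if isinstance(event, PingEvent):
            yield "ping"
        elif isinstance(event, TextChunkEvent):
            yield f"text: {event.text}"
        elif isinstance(event, StopEvent):
            yield f"stopped: {event.reason}"
            break                      # terminal event: stop listening
        else:
            continue                   # unknown events are ignored, as in the pipeline

print(list(dispatch([PingEvent(), TextChunkEvent("hi"), StopEvent("user")])))
```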
- :return: - """ - # find all answer nodes - graph = self._workflow.graph_dict - end_node_configs = [ - node for node in graph['nodes'] - if node.get('data', {}).get('type') == NodeType.END.value - ] - - # parse stream output node value selectors of end nodes - stream_generate_routes = {} - for node_config in end_node_configs: - # get generate route for stream output - end_node_id = node_config['id'] - generate_nodes = EndNode.extract_generate_nodes(graph, node_config) - start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id) - if not start_node_ids: - continue - - for start_node_id in start_node_ids: - stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes( - end_node_id=end_node_id, - stream_node_ids=generate_nodes - ) - - return stream_generate_routes - - def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: - """ - Get end start at node id. - :param graph: graph - :param target_node_id: target node ID - :return: - """ - nodes = graph.get('nodes') - edges = graph.get('edges') - - # fetch all ingoing edges from source node - ingoing_edges = [] - for edge in edges: - if edge.get('target') == target_node_id: - ingoing_edges.append(edge) - - if not ingoing_edges: - return [] - - start_node_ids = [] - for ingoing_edge in ingoing_edges: - source_node_id = ingoing_edge.get('source') - source_node = next((node for node in nodes if node.get('id') == source_node_id), None) - if not source_node: - continue - - node_type = source_node.get('data', {}).get('type') - node_iteration_id = source_node.get('data', {}).get('iteration_id') - iteration_start_node_id = None - if node_iteration_id: - iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None) - iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id') - - if node_type in [ - NodeType.IF_ELSE.value, - NodeType.QUESTION_CLASSIFIER.value - ]: - start_node_id = target_node_id - start_node_ids.append(start_node_id) - elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): - start_node_id = source_node_id - start_node_ids.append(start_node_id) - else: - sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id) - if sub_start_node_ids: - start_node_ids.extend(sub_start_node_ids) - - return start_node_ids - - def _generate_stream_outputs_when_node_started(self) -> Generator: - """ - Generate stream outputs. 
- :return: - """ - if self._task_state.current_stream_generate_state: - stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids - - for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items(): - if node_id not in stream_node_ids: - continue - - node_execution_info = self._task_state.ran_node_execution_infos[node_id] - - # get chunk node execution - route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first() - - if not route_chunk_node_execution: - continue - - outputs = route_chunk_node_execution.outputs_dict - - if not outputs: - continue - - # get value from outputs - text = outputs.get('text') - - if text: - self._task_state.answer += text - yield self._text_chunk_to_stream_response(text) - - db.session.close() - - def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool: - """ - Is stream out support - :param event: queue text chunk event - :return: - """ - if not event.metadata: - return False - - if 'node_id' not in event.metadata: - return False - - node_id = event.metadata.get('node_id') - node_type = event.metadata.get('node_type') - stream_output_value_selector = event.metadata.get('value_selector') - if not stream_output_value_selector: - return False - - if not self._task_state.current_stream_generate_state: - return False - - if node_id not in self._task_state.current_stream_generate_state.stream_node_ids: - return False - - if node_type != NodeType.LLM: - # only LLM support chunk stream output - return False - - return True - - def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]: - """ - Get iteration nested relations. - :param graph: graph - :return: - """ - nodes = graph.get('nodes') - - iteration_ids = [node.get('id') for node in nodes - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]] - - return { - iteration_id: [ - node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id - ] for iteration_id in iteration_ids - } diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py deleted file mode 100644 index 4472a7e9b5a85c..00000000000000 --- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py +++ /dev/null @@ -1,200 +0,0 @@ -from typing import Any, Optional - -from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.queue_entities import ( - AppQueueEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueNodeFailedEvent, - QueueNodeStartedEvent, - QueueNodeSucceededEvent, - QueueTextChunkEvent, - QueueWorkflowFailedEvent, - QueueWorkflowStartedEvent, - QueueWorkflowSucceededEvent, -) -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from models.workflow import Workflow - - -class WorkflowEventTriggerCallback(WorkflowCallback): - - def __init__(self, queue_manager: AppQueueManager, workflow: Workflow): - self._queue_manager = queue_manager - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self._queue_manager.publish( - QueueWorkflowStartedEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_succeeded(self) -> None: - 
""" - Workflow run succeeded - """ - self._queue_manager.publish( - QueueWorkflowSucceededEvent(), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self._queue_manager.publish( - QueueWorkflowFailedEvent( - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self._queue_manager.publish( - QueueNodeStartedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - node_run_index=node_run_index, - predecessor_node_id=predecessor_node_id - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self._queue_manager.publish( - QueueNodeSucceededEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - process_data=process_data, - outputs=outputs, - execution_metadata=execution_metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self._queue_manager.publish( - QueueNodeFailedEvent( - node_id=node_id, - node_type=node_type, - node_data=node_data, - inputs=inputs, - outputs=outputs, - process_data=process_data, - error=error - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - self._queue_manager.publish( - QueueTextChunkEvent( - text=text, - metadata={ - "node_id": node_id, - **metadata - } - ), PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self._queue_manager.publish( - QueueIterationStartEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - node_data=node_data, - inputs=inputs, - predecessor_node_id=predecessor_node_id, - metadata=metadata - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any]) -> None: - """ - Publish iteration next - """ - self._queue_manager.publish( - QueueIterationNextEvent( - node_id=node_id, - node_type=node_type, - index=index, - node_run_index=node_run_index, - output=output - ), - PublishFrom.APPLICATION_MANAGER - ) - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self._queue_manager.publish( - QueueIterationCompletedEvent( - node_id=node_id, - node_type=node_type, - node_run_index=node_run_index, - outputs=outputs - ), - 
PublishFrom.APPLICATION_MANAGER - ) - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - pass diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py new file mode 100644 index 00000000000000..9a01e8a253f97b --- /dev/null +++ b/api/core/app/apps/workflow_app_runner.py @@ -0,0 +1,404 @@ +from collections.abc import Mapping +from typing import Any, Optional, cast + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.queue_entities import ( + AppQueueEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, + QueueRetrieverResourcesEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeInIterationFailedEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes import NodeType +from core.workflow.nodes.iteration import IterationNodeData +from core.workflow.nodes.node_mapping import node_type_classes_mapping +from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db +from models.model import App +from models.workflow import Workflow + + +class WorkflowBasedAppRunner(AppRunner): + def __init__(self, queue_manager: AppQueueManager): + self.queue_manager = queue_manager + + def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: + """ + Init graph + """ + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") + + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") + + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in workflow graph must be a list") + # init graph + graph = Graph.init(graph_config=graph_config) + + if not graph: + raise ValueError("graph not found in workflow") + + return graph + + def _get_graph_and_variable_pool_of_single_iteration( + self, + workflow: Workflow, + node_id: str, + user_inputs: dict, + ) -> tuple[Graph, VariablePool]: + """ + Get variable pool of single iteration + """ + # fetch workflow graph + graph_config = workflow.graph_dict + if not graph_config: + raise ValueError("workflow graph not found") + + graph_config = cast(dict[str, Any], graph_config) + + if "nodes" not in graph_config or "edges" not in graph_config: + raise ValueError("nodes or edges not found in workflow graph") + + if not isinstance(graph_config.get("nodes"), list): + raise ValueError("nodes in workflow graph must be a list") + + if not isinstance(graph_config.get("edges"), list): + raise ValueError("edges in 
workflow graph must be a list") + + # filter nodes only in iteration + node_configs = [ + node + for node in graph_config.get("nodes", []) + if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id + ] + + graph_config["nodes"] = node_configs + + node_ids = [node.get("id") for node in node_configs] + + # filter edges only in iteration + edge_configs = [ + edge + for edge in graph_config.get("edges", []) + if (edge.get("source") is None or edge.get("source") in node_ids) + and (edge.get("target") is None or edge.get("target") in node_ids) + ] + + graph_config["edges"] = edge_configs + + # init graph + graph = Graph.init(graph_config=graph_config, root_node_id=node_id) + + if not graph: + raise ValueError("graph not found in workflow") + + # fetch node config from node id + iteration_node_config = None + for node in node_configs: + if node.get("id") == node_id: + iteration_node_config = node + break + + if not iteration_node_config: + raise ValueError("iteration node id not found in workflow graph") + + # Get node class + node_type = NodeType(iteration_node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping[node_type] + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, config=iteration_node_config + ) + except NotImplementedError: + variable_mapping = {} + + WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=IterationNodeData(**iteration_node_config.get("data", {})), + ) + + return graph, variable_pool + + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: + """ + Handle event + :param workflow_entry: workflow entry + :param event: event + """ + if isinstance(event, GraphRunStartedEvent): + self._publish_event( + QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state) + ) + elif isinstance(event, GraphRunSucceededEvent): + self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) + elif isinstance(event, GraphRunFailedEvent): + self._publish_event(QueueWorkflowFailedEvent(error=event.error)) + elif isinstance(event, NodeRunStartedEvent): + self._publish_event( + QueueNodeStartedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + node_run_index=event.route_node_state.index, + predecessor_node_id=event.predecessor_node_id, + in_iteration_id=event.in_iteration_id, + parallel_mode_run_id=event.parallel_mode_run_id, + ) + ) + elif isinstance(event, NodeRunSucceededEvent): + self._publish_event( + QueueNodeSucceededEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + 
start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeRunFailedEvent): + self._publish_event( + QueueNodeFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error + else "Unknown error", + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeInIterationFailedEvent): + self._publish_event( + QueueNodeInIterationFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + error=event.error, + ) + ) + elif isinstance(event, NodeRunStreamChunkEvent): + self._publish_event( + QueueTextChunkEvent( + text=event.chunk_content, + from_variable_selector=event.from_variable_selector, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, NodeRunRetrieverResourceEvent): + self._publish_event( + QueueRetrieverResourcesEvent( + retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id + ) + ) + elif isinstance(event, ParallelBranchRunStartedEvent): + self._publish_event( + QueueParallelBranchRunStartedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, 
ParallelBranchRunSucceededEvent): + self._publish_event( + QueueParallelBranchRunSucceededEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + ) + ) + elif isinstance(event, ParallelBranchRunFailedEvent): + self._publish_event( + QueueParallelBranchRunFailedEvent( + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + in_iteration_id=event.in_iteration_id, + error=event.error, + ) + ) + elif isinstance(event, IterationRunStartedEvent): + self._publish_event( + QueueIterationStartEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + predecessor_node_id=event.predecessor_node_id, + metadata=event.metadata, + ) + ) + elif isinstance(event, IterationRunNextEvent): + self._publish_event( + QueueIterationNextEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + index=event.index, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + output=event.pre_iteration_output, + parallel_mode_run_id=event.parallel_mode_run_id, + ) + ) + elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): + self._publish_event( + QueueIterationCompletedEvent( + node_execution_id=event.iteration_id, + node_id=event.iteration_node_id, + node_type=event.iteration_node_type, + node_data=event.iteration_node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, + inputs=event.inputs, + outputs=event.outputs, + metadata=event.metadata, + steps=event.steps, + error=event.error if isinstance(event, IterationRunFailedEvent) else None, + ) + ) + + def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: + """ + Get workflow + """ + # fetch workflow by workflow_id + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id + ) + .first() + ) + + # return workflow + return workflow + + def _publish_event(self, event: AppQueueEvent) -> None: + self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) diff --git a/api/core/app/apps/workflow_logging_callback.py b/api/core/app/apps/workflow_logging_callback.py deleted file mode 100644 index 2e6431d6d05f4d..00000000000000 --- 
a/api/core/app/apps/workflow_logging_callback.py +++ /dev/null @@ -1,155 +0,0 @@ -from typing import Optional - -from core.app.entities.queue_entities import AppQueueEvent -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType - -_TEXT_COLOR_MAPPING = { - "blue": "36;1", - "yellow": "33;1", - "pink": "38;5;200", - "green": "32;1", - "red": "31;1", -} - - -class WorkflowLoggingCallback(WorkflowCallback): - - def __init__(self) -> None: - self.current_node_id = None - - def on_workflow_run_started(self) -> None: - """ - Workflow run started - """ - self.print_text("\n[on_workflow_run_started]", color='pink') - - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - self.print_text("\n[on_workflow_run_succeeded]", color='green') - - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - self.print_text("\n[on_workflow_run_failed]", color='red') - - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - self.print_text("\n[on_workflow_node_execute_started]", color='yellow') - self.print_text(f"Node ID: {node_id}", color='yellow') - self.print_text(f"Type: {node_type.value}", color='yellow') - self.print_text(f"Index: {node_run_index}", color='yellow') - if predecessor_node_id: - self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow') - - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - self.print_text("\n[on_workflow_node_execute_succeeded]", color='green') - self.print_text(f"Node ID: {node_id}", color='green') - self.print_text(f"Type: {node_type.value}", color='green') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green') - self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}", - color='green') - - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - self.print_text("\n[on_workflow_node_execute_failed]", color='red') - self.print_text(f"Node ID: {node_id}", color='red') - self.print_text(f"Type: {node_type.value}", color='red') - self.print_text(f"Error: {error}", color='red') - self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red') - self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red') - self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red') - - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: 
- """ - Publish text chunk - """ - if not self.current_node_id or self.current_node_id != node_id: - self.current_node_id = node_id - self.print_text('\n[on_node_text_chunk]') - self.print_text(f"Node ID: {node_id}") - self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}") - - self.print_text(text, color="pink", end="") - - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: dict = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - self.print_text("\n[on_workflow_iteration_started]", color='blue') - self.print_text(f"Node ID: {node_id}", color='blue') - - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[dict]) -> None: - """ - Publish iteration next - """ - self.print_text("\n[on_workflow_iteration_next]", color='blue') - - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - self.print_text("\n[on_workflow_iteration_completed]", color='blue') - - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event - """ - self.print_text("\n[on_workflow_event]", color='blue') - self.print_text(f"Event: {jsonable_encoder(event)}", color='blue') - - def print_text( - self, text: str, color: Optional[str] = None, end: str = "\n" - ) -> None: - """Print text with highlighting and no end characters.""" - text_to_print = self._get_colored_text(text, color) if color else text - print(f'{text_to_print}', end=end) - - def _get_colored_text(self, text: str, color: str) -> str: - """Get colored text.""" - color_str = _TEXT_COLOR_MAPPING[color] - return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 6a1ab230416d0c..31c3a996e19286 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -1,12 +1,13 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator +from constants import UUID_NIL from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.file.file_obj import FileVar +from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity from core.ops.ops_trace_manager import TraceQueueManager @@ -15,13 +16,14 @@ class InvokeFrom(Enum): """ Invoke From. """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - EXPLORE = 'explore' - DEBUGGER = 'debugger' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + EXPLORE = "explore" + DEBUGGER = "debugger" @classmethod - def value_of(cls, value: str) -> 'InvokeFrom': + def value_of(cls, value: str): """ Get value of given mode. 
@@ -31,7 +33,7 @@ def value_of(cls, value: str) -> 'InvokeFrom': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid invoke from value {value}') + raise ValueError(f"invalid invoke from value {value}") def to_source(self) -> str: """ @@ -40,21 +42,22 @@ def to_source(self) -> str: :return: source """ if self == InvokeFrom.WEB_APP: - return 'web_app' + return "web_app" elif self == InvokeFrom.DEBUGGER: - return 'dev' + return "dev" elif self == InvokeFrom.EXPLORE: - return 'explore_app' + return "explore_app" elif self == InvokeFrom.SERVICE_API: - return 'api' + return "api" - return 'dev' + return "dev" class ModelConfigWithCredentialsEntity(BaseModel): """ Model Config With Credentials Entity. """ + provider: str model: str model_schema: AIModelEntity @@ -72,13 +75,15 @@ class AppGenerateEntity(BaseModel): """ App Generate Entity. """ + task_id: str # app config app_config: AppConfig + file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] - files: list[FileVar] = [] + files: Sequence[File] user_id: str # extras @@ -102,6 +107,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ Chat Application Generate Entity. """ + # app config app_config: EasyUIBasedAppConfig model_conf: ModelConfigWithCredentialsEntity @@ -112,57 +118,90 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): model_config = ConfigDict(protected_namespaces=()) -class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): +class ConversationAppGenerateEntity(AppGenerateEntity): """ - Chat Application Generate Entity. + Base entity for conversation-based app generation. """ + conversation_id: Optional[str] = None + parent_message_id: Optional[str] = Field( + default=None, + description=( + "Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API. " + "For service API, we need to ensure its forward compatibility, " + "so passing in the parent_message_id as request arg is not supported for now. " + "It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages." + ), + ) + + @field_validator("parent_message_id") + @classmethod + def validate_parent_message_id(cls, v, info: ValidationInfo): + if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL: + raise ValueError("parent_message_id should be UUID_NIL for service API") + return v + + +class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity): + """ + Chat Application Generate Entity. + """ + + pass class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): """ Completion Application Generate Entity. """ + pass -class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): +class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity): """ Agent Chat Application Generate Entity. """ - conversation_id: Optional[str] = None + + pass -class AdvancedChatAppGenerateEntity(AppGenerateEntity): +class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ Advanced Chat Application Generate Entity. """ + # app config app_config: WorkflowUIBasedAppConfig - conversation_id: Optional[str] = None + workflow_run_id: Optional[str] = None query: str class SingleIterationRunEntity(BaseModel): """ Single Iteration Run Entity. """ + node_id: str inputs: dict single_iteration_run: Optional[SingleIterationRunEntity] = None + class WorkflowAppGenerateEntity(AppGenerateEntity): """ Workflow Application Generate Entity.
""" + # app config app_config: WorkflowUIBasedAppConfig + workflow_run_id: Optional[str] = None class SingleIterationRunEntity(BaseModel): """ Single Iteration Run Entity. """ + node_id: str inputs: dict diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 15348251f2de35..f1542ec5d8c578 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,17 +1,21 @@ +from datetime import datetime from enum import Enum from typing import Any, Optional from pydantic import BaseModel, field_validator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNodeData class QueueEvent(str, Enum): """ QueueEvent enum """ + LLM_CHUNK = "llm_chunk" TEXT_CHUNK = "text_chunk" AGENT_MESSAGE = "agent_message" @@ -31,6 +35,9 @@ class QueueEvent(str, Enum): ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" MESSAGE_FILE = "message_file" + PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" + PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" + PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" ERROR = "error" PING = "ping" STOP = "stop" @@ -38,46 +45,74 @@ class QueueEvent(str, Enum): class AppQueueEvent(BaseModel): """ - QueueEvent entity + QueueEvent abstract entity """ + event: QueueEvent class QueueLLMChunkEvent(AppQueueEvent): """ QueueLLMChunkEvent entity + Only for basic mode apps """ + event: QueueEvent = QueueEvent.LLM_CHUNK chunk: LLMResultChunk + class QueueIterationStartEvent(AppQueueEvent): """ QueueIterationStartEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_START + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime node_run_index: int - inputs: dict = None + inputs: Optional[dict[str, Any]] = None predecessor_node_id: Optional[str] = None - metadata: Optional[dict] = None + metadata: Optional[dict[str, Any]] = None + class QueueIterationNextEvent(AppQueueEvent): """ QueueIterationNextEvent entity """ + event: QueueEvent = QueueEvent.ITERATION_NEXT index: int + node_execution_id: str node_id: str node_type: NodeType - + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + parallel_mode_run_id: Optional[str] = None + """iteratoin run in parallel mode run id""" node_run_index: int - output: Optional[Any] = None # output for the current iteration + output: Optional[Any] = None # 
output for the current iteration - @field_validator('output', mode='before') + @field_validator("output", mode="before") @classmethod def set_output(cls, v): """ @@ -87,41 +122,66 @@ def set_output(cls, v): return None if isinstance(v, int | float | str | bool | dict | list): return v - raise ValueError('output must be a valid type') + raise ValueError("output must be a valid type") + class QueueIterationCompletedEvent(AppQueueEvent): """ QueueIterationCompletedEvent entity """ - event:QueueEvent = QueueEvent.ITERATION_COMPLETED + event: QueueEvent = QueueEvent.ITERATION_COMPLETED + + node_execution_id: str node_id: str node_type: NodeType - + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + start_at: datetime + node_run_index: int - outputs: dict + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + + error: Optional[str] = None + class QueueTextChunkEvent(AppQueueEvent): """ QueueTextChunkEvent entity """ + event: QueueEvent = QueueEvent.TEXT_CHUNK text: str - metadata: Optional[dict] = None + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity """ + event: QueueEvent = QueueEvent.AGENT_MESSAGE chunk: LLMResultChunk - + class QueueMessageReplaceEvent(AppQueueEvent): """ QueueMessageReplaceEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_REPLACE text: str @@ -130,14 +190,18 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ QueueRetrieverResourcesEvent entity """ + event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES retriever_resources: list[dict] + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" class QueueAnnotationReplyEvent(AppQueueEvent): """ QueueAnnotationReplyEvent entity """ + event: QueueEvent = QueueEvent.ANNOTATION_REPLY message_annotation_id: str @@ -146,6 +210,7 @@ class QueueMessageEndEvent(AppQueueEvent): """ QueueMessageEndEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_END llm_result: Optional[LLMResult] = None @@ -154,6 +219,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent): """ QueueAdvancedChatMessageEndEvent entity """ + event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END @@ -161,20 +227,25 @@ class QueueWorkflowStartedEvent(AppQueueEvent): """ QueueWorkflowStartedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_STARTED + graph_runtime_state: GraphRuntimeState class QueueWorkflowSucceededEvent(AppQueueEvent): """ QueueWorkflowSucceededEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED + outputs: Optional[dict[str, Any]] = None class QueueWorkflowFailedEvent(AppQueueEvent): """ QueueWorkflowFailedEvent entity """ + event: QueueEvent = QueueEvent.WORKFLOW_FAILED error: str @@ -183,46 +254,119 @@ class QueueNodeStartedEvent(AppQueueEvent): """ QueueNodeStartedEvent entity """ + event: QueueEvent = QueueEvent.NODE_STARTED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData node_run_index: int = 1 
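# (Aside, not part of the patch: every event model in this hunk follows one shape -- subclass AppQueueEvent and pin the `event` field's default to the matching discriminator, so consumers can branch on `event` alone. The bare string literals under the parallel_* fields that follow are attribute docstrings: inert at runtime, but surfaced by IDEs and documentation tools. A trimmed sketch of the pattern:
#
#     from typing import Optional
#     from pydantic import BaseModel
#
#     class AppQueueEvent(BaseModel):
#         event: str
#
#     class NodeStartedEvent(AppQueueEvent):
#         event: str = "node_started"  # fixed discriminator default
#         node_id: str
#         parallel_id: Optional[str] = None
#         """parallel id if node is in parallel"""
#
#     assert NodeStartedEvent(node_id="n1").event == "node_started"
# )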
predecessor_node_id: Optional[str] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + parallel_mode_run_id: Optional[str] = None + """iteration run in parallel mode run id""" class QueueNodeSucceededEvent(AppQueueEvent): """ QueueNodeSucceededEvent entity """ + event: QueueEvent = QueueEvent.NODE_SUCCEEDED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData - - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None - execution_metadata: Optional[dict] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: Optional[str] = None +class QueueNodeInIterationFailedEvent(AppQueueEvent): + """ + QueueNodeInIterationFailedEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_FAILED + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + + error: str + + class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity """ + event: QueueEvent = QueueEvent.NODE_FAILED + node_execution_id: str node_id: str node_type: NodeType node_data: BaseNodeData - - inputs: Optional[dict] = None - outputs: Optional[dict] = None - process_data: Optional[dict] = None + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data:
Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: str @@ -231,6 +375,7 @@ class QueueAgentThoughtEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.AGENT_THOUGHT agent_thought_id: str @@ -239,6 +384,7 @@ class QueueMessageFileEvent(AppQueueEvent): """ QueueAgentThoughtEvent entity """ + event: QueueEvent = QueueEvent.MESSAGE_FILE message_file_id: str @@ -247,6 +393,7 @@ class QueueErrorEvent(AppQueueEvent): """ QueueErrorEvent entity """ + event: QueueEvent = QueueEvent.ERROR error: Any = None @@ -255,6 +402,7 @@ class QueuePingEvent(AppQueueEvent): """ QueuePingEvent entity """ + event: QueueEvent = QueueEvent.PING @@ -262,10 +410,12 @@ class QueueStopEvent(AppQueueEvent): """ QueueStopEvent entity """ + class StopBy(Enum): """ Stop by enum """ + USER_MANUAL = "user-manual" ANNOTATION_REPLY = "annotation-reply" OUTPUT_MODERATION = "output-moderation" @@ -274,11 +424,25 @@ class StopBy(Enum): event: QueueEvent = QueueEvent.STOP stopped_by: StopBy + def get_stop_reason(self) -> str: + """ + To stop reason + """ + reason_mapping = { + QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.", + QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.", + QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.", + QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.", + } + + return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.") + class QueueMessage(BaseModel): """ - QueueMessage entity + QueueMessage abstract entity """ + task_id: str app_mode: str event: AppQueueEvent @@ -288,6 +452,7 @@ class MessageQueueMessage(QueueMessage): """ MessageQueueMessage entity """ + message_id: str conversation_id: str @@ -296,4 +461,57 @@ class WorkflowQueueMessage(QueueMessage): """ WorkflowQueueMessage entity """ + pass + + +class QueueParallelBranchRunStartedEvent(AppQueueEvent): + """ + QueueParallelBranchRunStartedEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunSucceededEvent(AppQueueEvent): + """ + QueueParallelBranchRunSucceededEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class QueueParallelBranchRunFailedEvent(AppQueueEvent): + """ + QueueParallelBranchRunFailedEvent entity + """ + + event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED + + parallel_id: str + parallel_start_node_id: str + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + error: str diff --git 
a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7bc55989843305..7e9aad54be57e4 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,46 +1,19 @@ +from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional from pydantic import BaseModel, ConfigDict -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.answer.entities import GenerateRouteChunk from models.workflow import WorkflowNodeExecutionStatus -class WorkflowStreamGenerateNodes(BaseModel): - """ - WorkflowStreamGenerateNodes entity - """ - end_node_id: str - stream_node_ids: list[str] - - -class ChatflowStreamGenerateRoute(BaseModel): - """ - ChatflowStreamGenerateRoute entity - """ - answer_node_id: str - generate_route: list[GenerateRouteChunk] - current_route_position: int = 0 - - -class NodeExecutionInfo(BaseModel): - """ - NodeExecutionInfo entity - """ - workflow_node_execution_id: str - node_type: NodeType - start_at: float - - class TaskState(BaseModel): """ TaskState entity """ + metadata: dict = {} @@ -48,6 +21,7 @@ class EasyUITaskState(TaskState): """ EasyUITaskState entity """ + llm_result: LLMResult @@ -55,34 +29,15 @@ class WorkflowTaskState(TaskState): """ WorkflowTaskState entity """ - answer: str = "" - - workflow_run_id: Optional[str] = None - start_at: Optional[float] = None - total_tokens: int = 0 - total_steps: int = 0 - - ran_node_execution_infos: dict[str, NodeExecutionInfo] = {} - latest_node_execution_info: Optional[NodeExecutionInfo] = None - - current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None - - iteration_nested_node_ids: list[str] = None - - -class AdvancedChatTaskState(WorkflowTaskState): - """ - AdvancedChatTaskState entity - """ - usage: LLMUsage - current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None + answer: str = "" class StreamEvent(Enum): """ Stream event """ + PING = "ping" ERROR = "error" MESSAGE = "message" @@ -97,6 +52,8 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + PARALLEL_BRANCH_STARTED = "parallel_branch_started" + PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" ITERATION_STARTED = "iteration_started" ITERATION_NEXT = "iteration_next" ITERATION_COMPLETED = "iteration_completed" @@ -108,6 +65,7 @@ class StreamResponse(BaseModel): """ StreamResponse entity """ + event: StreamEvent task_id: str @@ -119,6 +77,7 @@ class ErrorStreamResponse(StreamResponse): """ ErrorStreamResponse entity """ + event: StreamEvent = StreamEvent.ERROR err: Exception model_config = ConfigDict(arbitrary_types_allowed=True) @@ -128,15 +87,18 @@ class MessageStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE id: str answer: str + from_variable_selector: Optional[list[str]] = None class MessageAudioStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE audio: str @@ -145,6 +107,7 @@ class MessageAudioEndStreamResponse(StreamResponse): """ MessageStreamResponse entity """ + event: StreamEvent = StreamEvent.TTS_MESSAGE_END audio: str @@ 
-153,15 +116,18 @@ class MessageEndStreamResponse(StreamResponse): """ MessageEndStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_END id: str metadata: dict = {} + files: Optional[Sequence[Mapping[str, Any]]] = None class MessageFileStreamResponse(StreamResponse): """ MessageFileStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_FILE id: str type: str @@ -173,6 +139,7 @@ class MessageReplaceStreamResponse(StreamResponse): """ MessageReplaceStreamResponse entity """ + event: StreamEvent = StreamEvent.MESSAGE_REPLACE answer: str @@ -181,6 +148,7 @@ class AgentThoughtStreamResponse(StreamResponse): """ AgentThoughtStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_THOUGHT id: str position: int @@ -196,6 +164,7 @@ class AgentMessageStreamResponse(StreamResponse): """ AgentMessageStreamResponse entity """ + event: StreamEvent = StreamEvent.AGENT_MESSAGE id: str answer: str @@ -210,6 +179,7 @@ class Data(BaseModel): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -230,6 +200,7 @@ class Data(BaseModel): """ Data entity """ + id: str workflow_id: str sequence_number: int @@ -242,7 +213,7 @@ class Data(BaseModel): created_by: Optional[dict] = None created_at: int finished_at: int - files: Optional[list[dict]] = [] + files: Optional[Sequence[Mapping[str, Any]]] = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED workflow_run_id: str @@ -258,6 +229,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -267,6 +239,12 @@ class Data(BaseModel): inputs: Optional[dict] = None created_at: int extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + parallel_run_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -286,8 +264,13 @@ def to_ignore_detail_dict(self): "predecessor_node_id": self.data.predecessor_node_id, "inputs": None, "created_at": self.data.created_at, - "extras": {} - } + "extras": {}, + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, + }, } @@ -300,6 +283,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -315,7 +299,12 @@ class Data(BaseModel): execution_metadata: Optional[dict] = None created_at: int finished_at: int - files: Optional[list[dict]] = [] + files: Optional[Sequence[Mapping[str, Any]]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_FINISHED workflow_run_id: str @@ -342,11 +331,62 @@ def to_ignore_detail_dict(self): "execution_metadata": None, "created_at": self.data.created_at, "finished_at": self.data.finished_at, - "files": [] - } + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, + }, } +class 
ParallelBranchStartStreamResponse(StreamResponse): + """ + ParallelBranchStartStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + parallel_id: str + parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + created_at: int + + event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED + workflow_run_id: str + data: Data + + +class ParallelBranchFinishedStreamResponse(StreamResponse): + """ + ParallelBranchFinishedStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + parallel_id: str + parallel_branch_id: str + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + status: str + error: Optional[str] = None + created_at: int + + event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED + workflow_run_id: str + data: Data + + class IterationNodeStartStreamResponse(StreamResponse): """ NodeStartStreamResponse entity @@ -356,6 +396,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -364,6 +405,8 @@ class Data(BaseModel): extras: dict = {} metadata: dict = {} inputs: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_STARTED workflow_run_id: str @@ -379,6 +422,7 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str @@ -387,6 +431,9 @@ class Data(BaseModel): created_at: int pre_iteration_output: Optional[Any] = None extras: dict = {} + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parallel_mode_run_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str @@ -402,14 +449,15 @@ class Data(BaseModel): """ Data entity """ + id: str node_id: str node_type: str title: str outputs: Optional[dict] = None created_at: int - extras: dict = None - inputs: dict = None + extras: Optional[dict] = None + inputs: Optional[dict] = None status: WorkflowNodeExecutionStatus error: Optional[str] = None elapsed_time: float @@ -417,6 +465,8 @@ class Data(BaseModel): execution_metadata: Optional[dict] = None finished_at: int steps: int + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_COMPLETED workflow_run_id: str @@ -432,7 +482,9 @@ class Data(BaseModel): """ Data entity """ + text: str + from_variable_selector: Optional[list[str]] = None event: StreamEvent = StreamEvent.TEXT_CHUNK data: Data @@ -447,6 +499,7 @@ class Data(BaseModel): """ Data entity """ + text: str event: StreamEvent = StreamEvent.TEXT_REPLACE @@ -457,6 +510,7 @@ class PingStreamResponse(StreamResponse): """ PingStreamResponse entity """ + event: StreamEvent = StreamEvent.PING @@ -464,6 +518,7 @@ class AppStreamResponse(BaseModel): """ AppStreamResponse entity """ + stream_response: StreamResponse @@ -471,6 +526,7 @@ class ChatbotAppStreamResponse(AppStreamResponse): """ ChatbotAppStreamResponse entity """ + conversation_id: str message_id: str created_at: int @@ -480,6 +536,7 @@ class CompletionAppStreamResponse(AppStreamResponse): """ CompletionAppStreamResponse entity """ + message_id: str created_at: int @@ -488,13 +545,15 @@ class WorkflowAppStreamResponse(AppStreamResponse): """ WorkflowAppStreamResponse entity """ - workflow_run_id: str + + workflow_run_id: Optional[str] = None class 
AppBlockingResponse(BaseModel): """ AppBlockingResponse entity """ + task_id: str def to_dict(self) -> dict: @@ -510,6 +569,7 @@ class Data(BaseModel): """ Data entity """ + id: str mode: str conversation_id: str @@ -530,6 +590,7 @@ class Data(BaseModel): """ Data entity """ + id: str mode: str message_id: str @@ -549,6 +610,7 @@ class Data(BaseModel): """ Data entity """ + id: str workflow_id: str status: str @@ -562,25 +624,3 @@ class Data(BaseModel): workflow_run_id: str data: Data - - -class WorkflowIterationState(BaseModel): - """ - WorkflowIterationState entity - """ - - class Data(BaseModel): - """ - Data entity - """ - parent_iteration_id: Optional[str] = None - iteration_id: str - current_index: int - iteration_steps_boundary: list[int] = None - node_execution_id: str - started_at: float - inputs: dict = None - total_tokens: int = 0 - node_data: BaseNodeData - - current_iterations: dict[str, Data] = None diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 19ff94de5e8d58..77b6bb554c65ec 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -13,11 +13,9 @@ class AnnotationReplyFeature: - def query(self, app_record: App, - message: Message, - query: str, - user_id: str, - invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: + def query( + self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom + ) -> Optional[MessageAnnotation]: """ Query app annotations to reply :param app_record: app record @@ -27,8 +25,9 @@ def query(self, app_record: App, :param invoke_from: invoke from :return: """ - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_record.id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() + ) if not annotation_setting: return None @@ -41,55 +40,50 @@ def query(self, app_record: App, embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" ) dataset = Dataset( id=app_record.id, tenant_id=app_record.tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) documents = vector.search_by_vector( - query=query, - top_k=1, - score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]} ) if documents: - annotation_id = documents[0].metadata['annotation_id'] - score = documents[0].metadata['score'] + annotation_id = documents[0].metadata["annotation_id"] + score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: - if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: - from_source = 'api' + if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: + from_source 
= "api" else: - from_source = 'console' + from_source = "console" # insert annotation history - AppAnnotationService.add_annotation_history(annotation.id, - app_record.id, - annotation.question, - annotation.content, - query, - user_id, - message.id, - from_source, - score) + AppAnnotationService.add_annotation_history( + annotation.id, + app_record.id, + annotation.question, + annotation.content, + query, + user_id, + message.id, + from_source, + score, + ) return annotation except Exception as e: - logger.warning(f'Query annotation failed, exception: {str(e)}.') + logger.warning(f"Query annotation failed, exception: {str(e)}.") return None return None diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index b8f3e0e1f65b74..ba14b61201e72f 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -8,8 +8,9 @@ class HostingModerationFeature: - def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, - prompt_messages: list[PromptMessage]) -> bool: + def check( + self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage] + ) -> bool: """ Check hosting moderation :param application_generate_entity: application generate entity @@ -23,9 +24,6 @@ def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity, if isinstance(prompt_message.content, str): text += prompt_message.content + "\n" - moderation_result = moderation.check_moderation( - model_config, - text - ) + moderation_result = moderation.check_moderation(model_config, text) return moderation_result diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index f11e8021f0b1cc..227182f5ab0923 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,7 +19,7 @@ class RateLimit: _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict = {} - def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int): + def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: instance = super().__new__(cls) cls._instance_dict[client_id] = instance @@ -27,13 +27,13 @@ def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int): def __init__(self, client_id: str, max_active_requests: int): self.max_active_requests = max_active_requests - if hasattr(self, 'initialized'): + if hasattr(self, "initialized"): return self.initialized = True self.client_id = client_id self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id) self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id) - self.last_recalculate_time = float('-inf') + self.last_recalculate_time = float("-inf") self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): @@ -46,7 +46,7 @@ def flush_cache(self, use_local_value=False): pipe.execute() else: with redis_client.pipeline() as pipe: - self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8')) + self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) # flush max active requests (in-transit request list) @@ -54,8 
+54,11 @@ def flush_cache(self, use_local_value=False): return request_details = redis_client.hgetall(self.active_requests_key) redis_client.expire(self.active_requests_key, timedelta(days=1)) - timeout_requests = [k for k, v in request_details.items() if - time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME] + timeout_requests = [ + k + for k, v in request_details.items() + if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME + ] if timeout_requests: redis_client.hdel(self.active_requests_key, *timeout_requests) @@ -69,8 +72,10 @@ def enter(self, request_id: Optional[str] = None) -> str: active_requests_count = redis_client.hlen(self.active_requests_key) if active_requests_count >= self.max_active_requests: - raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum " - "concurrent requests allowed is {}.".format(self.max_active_requests)) + raise AppInvokeQuotaExceededError( + "Too many requests. Please try again later. The current maximum " + "concurrent requests allowed is {}.".format(self.max_active_requests) + ) redis_client.hset(self.active_requests_key, request_id, str(time.time())) return request_id @@ -116,5 +121,5 @@ def close(self): if not self.closed: self.closed = True self.rate_limit.exit(self.request_id) - if self.generator is not None and hasattr(self.generator, 'close'): + if self.generator is not None and hasattr(self.generator, "close"): self.generator.close() diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py deleted file mode 100644 index 7de06dfb9639fd..00000000000000 --- a/api/core/app/segments/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -from .segment_group import SegmentGroup -from .segments import ( - ArrayAnySegment, - ArraySegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - SecretVariable, - StringVariable, - Variable, -) - -__all__ = [ - 'IntegerVariable', - 'FloatVariable', - 'ObjectVariable', - 'SecretVariable', - 'StringVariable', - 'ArrayAnyVariable', - 'Variable', - 'SegmentType', - 'SegmentGroup', - 'Segment', - 'NoneSegment', - 'NoneVariable', - 'IntegerSegment', - 'FloatSegment', - 'ObjectSegment', - 'ArrayAnySegment', - 'StringSegment', - 'ArrayStringVariable', - 'ArrayNumberVariable', - 'ArrayObjectVariable', - 'ArraySegment', -] diff --git a/api/core/app/segments/exc.py b/api/core/app/segments/exc.py deleted file mode 100644 index d15d6d500ffa4a..00000000000000 --- a/api/core/app/segments/exc.py +++ /dev/null @@ -1,2 +0,0 @@ -class VariableError(Exception): - pass diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py deleted file mode 100644 index e6e9ce97747ce1..00000000000000 --- a/api/core/app/segments/factory.py +++ /dev/null @@ -1,76 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from configs import dify_config - -from .exc import VariableError -from .segments import ( - ArrayAnySegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - FloatVariable, - IntegerVariable, - ObjectVariable, - SecretVariable, - StringVariable, - Variable, 
-) - - -def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: - if (value_type := mapping.get('value_type')) is None: - raise VariableError('missing value type') - if not mapping.get('name'): - raise VariableError('missing name') - if (value := mapping.get('value')) is None: - raise VariableError('missing value') - match value_type: - case SegmentType.STRING: - result = StringVariable.model_validate(mapping) - case SegmentType.SECRET: - result = SecretVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, int): - result = IntegerVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, float): - result = FloatVariable.model_validate(mapping) - case SegmentType.NUMBER if not isinstance(value, float | int): - raise VariableError(f'invalid number value {value}') - case SegmentType.OBJECT if isinstance(value, dict): - result = ObjectVariable.model_validate(mapping) - case SegmentType.ARRAY_STRING if isinstance(value, list): - result = ArrayStringVariable.model_validate(mapping) - case SegmentType.ARRAY_NUMBER if isinstance(value, list): - result = ArrayNumberVariable.model_validate(mapping) - case SegmentType.ARRAY_OBJECT if isinstance(value, list): - result = ArrayObjectVariable.model_validate(mapping) - case _: - raise VariableError(f'not supported value type {value_type}') - if result.size > dify_config.MAX_VARIABLE_SIZE: - raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}') - return result - - -def build_segment(value: Any, /) -> Segment: - if value is None: - return NoneSegment() - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if isinstance(value, list): - return ArrayAnySegment(value=value) - raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py deleted file mode 100644 index de6c7966525c06..00000000000000 --- a/api/core/app/segments/parser.py +++ /dev/null @@ -1,18 +0,0 @@ -import re - -from core.workflow.entities.variable_pool import VariablePool - -from . import SegmentGroup, factory - -VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') - - -def convert_template(*, template: str, variable_pool: VariablePool): - parts = re.split(VARIABLE_PATTERN, template) - segments = [] - for part in filter(lambda x: x, parts): - if '.' 
in part and (value := variable_pool.get(part.split('.'))): - segments.append(value) - else: - segments.append(factory.build_segment(part)) - return SegmentGroup(value=segments) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py deleted file mode 100644 index b4ff09b6d39ac9..00000000000000 --- a/api/core/app/segments/segment_group.py +++ /dev/null @@ -1,22 +0,0 @@ -from .segments import Segment -from .types import SegmentType - - -class SegmentGroup(Segment): - value_type: SegmentType = SegmentType.GROUP - value: list[Segment] - - @property - def text(self): - return ''.join([segment.text for segment in self.value]) - - @property - def log(self): - return ''.join([segment.log for segment in self.value]) - - @property - def markdown(self): - return ''.join([segment.markdown for segment in self.value]) - - def to_object(self): - return [segment.to_object() for segment in self.value] diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py deleted file mode 100644 index 5c713cac6747f9..00000000000000 --- a/api/core/app/segments/segments.py +++ /dev/null @@ -1,129 +0,0 @@ -import json -import sys -from collections.abc import Mapping, Sequence -from typing import Any - -from pydantic import BaseModel, ConfigDict, field_validator - -from .types import SegmentType - - -class Segment(BaseModel): - model_config = ConfigDict(frozen=True) - - value_type: SegmentType - value: Any - - @field_validator('value_type') - def validate_value_type(cls, value): - """ - This validator checks if the provided value is equal to the default value of the 'value_type' field. - If the value is different, a ValueError is raised. - """ - if value != cls.model_fields['value_type'].default: - raise ValueError("Cannot modify 'value_type'") - return value - - @property - def text(self) -> str: - return str(self.value) - - @property - def log(self) -> str: - return str(self.value) - - @property - def markdown(self) -> str: - return str(self.value) - - @property - def size(self) -> int: - return sys.getsizeof(self.value) - - def to_object(self) -> Any: - return self.value - - -class NoneSegment(Segment): - value_type: SegmentType = SegmentType.NONE - value: None = None - - @property - def text(self) -> str: - return 'null' - - @property - def log(self) -> str: - return 'null' - - @property - def markdown(self) -> str: - return 'null' - - -class StringSegment(Segment): - value_type: SegmentType = SegmentType.STRING - value: str - - -class FloatSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER - value: float - - -class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER - value: int - - - - - -class ObjectSegment(Segment): - value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] - - @property - def text(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False) - - @property - def log(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) - - @property - def markdown(self) -> str: - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) - - -class ArraySegment(Segment): - @property - def markdown(self) -> str: - items = [] - for item in self.value: - if hasattr(item, 'to_markdown'): - items.append(item.to_markdown()) - else: - items.append(str(item)) - return '\n'.join(items) - - -class ArrayAnySegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] - - -class 
ArrayStringSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] - - -class ArrayNumberSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] - - -class ArrayObjectSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] - diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py deleted file mode 100644 index cdd2b0b4b09191..00000000000000 --- a/api/core/app/segments/types.py +++ /dev/null @@ -1,15 +0,0 @@ -from enum import Enum - - -class SegmentType(str, Enum): - NONE = 'none' - NUMBER = 'number' - STRING = 'string' - SECRET = 'secret' - ARRAY_ANY = 'array[any]' - ARRAY_STRING = 'array[string]' - ARRAY_NUMBER = 'array[number]' - ARRAY_OBJECT = 'array[object]' - OBJECT = 'object' - - GROUP = 'group' diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index a3c1fb58245a11..51d610e2cbedc6 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline: _task_state: TaskState _application_generate_entity: AppGenerateEntity - def __init__(self, application_generate_entity: AppGenerateEntity, - queue_manager: AppQueueManager, - user: Union[Account, EndUser], - stream: bool) -> None: + def __init__( + self, + application_generate_entity: AppGenerateEntity, + queue_manager: AppQueueManager, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. :param application_generate_entity: application generate entity @@ -50,7 +53,7 @@ def __init__(self, application_generate_entity: AppGenerateEntity, self._output_moderation_handler = self._init_output_moderation() self._stream = stream - def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception: + def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None): """ Handle error event. :param event: event @@ -61,48 +64,49 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non e = event.error if isinstance(e, InvokeAuthorizationError): - err = InvokeAuthorizationError('Incorrect API key provided') - elif isinstance(e, InvokeError) or isinstance(e, ValueError): + err = InvokeAuthorizationError("Incorrect API key provided") + elif isinstance(e, InvokeError | ValueError): err = e else: - err = Exception(e.description if getattr(e, 'description', None) is not None else str(e)) + err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) if message: - message = db.session.query(Message).filter(Message.id == message.id).first() - err_desc = self._error_to_desc(err) - message.status = 'error' - message.error = err_desc + refetch_message = db.session.query(Message).filter(Message.id == message.id).first() - db.session.commit() + if refetch_message: + err_desc = self._error_to_desc(err) + refetch_message.status = "error" + refetch_message.error = err_desc + + db.session.commit() return err - def _error_to_desc(cls, e: Exception) -> str: + def _error_to_desc(self, e: Exception) -> str: """ Error to desc. :param e: exception :return: """ if isinstance(e, QuotaExceededError): - return ("Your quota for Dify Hosted Model Provider has been exhausted. 
" - "Please go to Settings -> Model Provider to complete your own provider credentials.") + return ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) - message = getattr(e, 'description', str(e)) + message = getattr(e, "description", str(e)) if not message: - message = 'Internal Server Error, please contact support.' + message = "Internal Server Error, please contact support." return message - def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse: + def _error_to_stream_response(self, e: Exception): """ Error to stream response. :param e: exception :return: """ - return ErrorStreamResponse( - task_id=self._application_generate_entity.task_id, - err=e - ) + return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e) def _ping_stream_response(self) -> PingStreamResponse: """ @@ -123,11 +127,8 @@ def _init_output_moderation(self) -> Optional[OutputModeration]: return OutputModeration( tenant_id=app_config.tenant_id, app_id=app_config.app_id, - rule=ModerationRule( - type=sensitive_word_avoidance.type, - config=sensitive_word_avoidance.config - ), - queue_manager=self._queue_manager + rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config), + queue_manager=self._queue_manager, ) def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]: @@ -141,8 +142,7 @@ def _handle_output_moderation_when_task_finished(self, completion: str) -> Optio self._output_moderation_handler.stop_thread() completion = self._output_moderation_handler.moderation_completion( - completion=completion, - public_event=False + completion=completion, public_event=False ) self._output_moderation_handler = None diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 8d91a507a9e8ee..917649f34e769c 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan """ EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ + _task_state: EasyUITaskState - _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ] - - def __init__(self, application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity - ], - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool) -> None: + _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity] + + def __init__( + self, + application_generate_entity: Union[ + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity + ], + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool, + ) -> None: """ Initialize GenerateTaskPipeline. 
:param application_generate_entity: application generate entity @@ -101,18 +99,18 @@ def __init__(self, application_generate_entity: Union[ model=self._model_config.model, prompt_messages=[], message=AssistantPromptMessage(content=""), - usage=LLMUsage.empty_usage() + usage=LLMUsage.empty_usage(), ) ) self._conversation_name_generate_thread = None def process( - self, + self, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, - Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None] + Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], ]: """ Process generate task pipeline. @@ -125,22 +123,18 @@ def process( if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread self._conversation_name_generate_thread = self._generate_conversation_name( - self._conversation, - self._application_generate_entity.query + self._conversation, self._application_generate_entity.query ) - generator = self._wrapper_process_stream_response( - trace_manager=self._application_generate_entity.trace_manager - ) + generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) if self._stream: return self._to_stream_response(generator) else: return self._to_blocking_response(generator) - def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[ - ChatbotAppBlockingResponse, - CompletionAppBlockingResponse - ]: + def _to_blocking_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]: """ Process blocking response. :return: @@ -149,11 +143,9 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): - extras = { - 'usage': jsonable_encoder(self._task_state.llm_result.usage) - } + extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata if self._conversation.mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( @@ -164,8 +156,8 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) else: response = ChatbotAppBlockingResponse( @@ -177,18 +169,19 @@ def _to_blocking_response(self, generator: Generator[StreamResponse, None, None] message_id=self._message.id, answer=self._task_state.llm_result.message.content, created_at=int(self._message.created_at.timestamp()), - **extras - ) + **extras, + ), ) return response else: continue - raise Exception('Queue listening stopped unexpectedly.') + raise Exception("Queue listening stopped unexpectedly.") - def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: + def _to_stream_response( + self, generator: Generator[StreamResponse, None, None] + ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]: """ To stream response. 
:return: @@ -198,37 +191,41 @@ def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) yield CompletionAppStreamResponse( message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) else: yield ChatbotAppStreamResponse( conversation_id=self._conversation.id, message_id=self._message.id, created_at=int(self._message.created_at.timestamp()), - stream_response=stream_response + stream_response=stream_response, ) - def _listenAudioMsg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.checkAndGetAudio() + audio_msg: AudioTrunk = publisher.check_and_get_audio() if audio_msg and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None - def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ - Generator[StreamResponse, None, None]: - + def _wrapper_process_stream_response( + self, trace_manager: Optional[TraceQueueManager] = None + ) -> Generator[StreamResponse, None, None]: tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech') - if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'): - publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None)) + text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + if ( + text_to_speech_dict + and text_to_speech_dict.get("autoPlay") == "enabled" + and text_to_speech_dict.get("enabled") + ): + publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: - audio_response = self._listenAudioMsg(publisher, task_id) + audio_response = self._listen_audio_msg(publisher, task_id) if audio_response: yield audio_response else: @@ -240,7 +237,7 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: if publisher is None: break - audio = publisher.checkAndGetAudio() + audio = publisher.check_and_get_audio() if audio is None: # release cpu # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) @@ -250,14 +247,12 @@ def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueMan break else: start_listener_time = time.time() - yield MessageAudioStreamResponse(audio=audio.audio, - task_id=task_id) - yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id) + if publisher: + yield MessageAudioEndStreamResponse(audio="", task_id=task_id) def _process_stream_response( - self, - publisher: AppGeneratorTTSPublisher, - trace_manager: Optional[TraceQueueManager] = None + self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. 
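Aside: the TTS handling just above is a poll-and-drain loop. While text responses stream, _listen_audio_msg asks the publisher for any ready audio; once the text stream ends, the pipeline keeps draining until the publisher stays quiet past TTS_AUTO_PLAY_TIMEOUT. A minimal sketch of the drain phase (check_and_get_audio, the ~20 ms sleep, and the idle-timer reset mirror the diff; the constant's value and the bare generator are assumptions for illustration):

    import time

    TTS_AUTO_PLAY_TIMEOUT = 30.0  # seconds; assumed value for illustration

    def drain_audio(publisher):
        # Poll until no new chunk arrives for TTS_AUTO_PLAY_TIMEOUT seconds
        # or the publisher reports that synthesis is finished.
        start = time.time()
        while time.time() - start < TTS_AUTO_PLAY_TIMEOUT:
            audio = publisher.check_and_get_audio()
            if audio is None:
                time.sleep(0.02)  # ~20 ms: yield the CPU between polls
                continue
            if audio.status == "finish":
                break
            start = time.time()  # a fresh chunk resets the idle timer
            yield audio.audio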
@@ -333,9 +328,7 @@ def _process_stream_response( if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() - def _save_message( - self, trace_manager: Optional[TraceQueueManager] = None - ) -> None: + def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: """ Save message. :return: @@ -347,31 +340,32 @@ def _save_message( self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( - self._model_config.mode, - self._task_state.llm_result.prompt_messages + self._model_config.mode, self._task_state.llm_result.prompt_messages ) self._message.message_tokens = usage.prompt_tokens self._message.message_unit_price = usage.prompt_unit_price self._message.message_price_unit = usage.prompt_price_unit - self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ - if llm_result.message.content else '' + self._message.answer = ( + PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) + if llm_result.message.content + else "" + ) self._message.answer_tokens = usage.completion_tokens self._message.answer_unit_price = usage.completion_unit_price self._message.answer_price_unit = usage.completion_price_unit self._message.provider_response_latency = time.perf_counter() - self._start_at self._message.total_price = usage.total_price self._message.currency = usage.currency - self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \ - if self._task_state.metadata else None + self._message.message_metadata = ( + json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None + ) db.session.commit() if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.MESSAGE_TRACE, - conversation_id=self._conversation.id, - message_id=self._message.id + TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id ) ) @@ -379,11 +373,9 @@ def _save_message( self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [ - AppMode.AGENT_CHAT, - AppMode.CHAT - ] and self._application_generate_entity.conversation_id is None, - extras=self._application_generate_entity.extras + is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} + and self._application_generate_entity.conversation_id is None, + extras=self._application_generate_entity.extras, ) def _handle_stop(self, event: QueueStopEvent) -> None: @@ -395,22 +387,17 @@ def _handle_stop(self, event: QueueStopEvent) -> None: model = model_config.model model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) # calculate num tokens prompt_tokens = 0 if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: - prompt_tokens = model_instance.get_llm_num_tokens( - self._task_state.llm_result.prompt_messages - ) + prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages) completion_tokens = 0 if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: - completion_tokens = model_instance.get_llm_num_tokens( - [self._task_state.llm_result.message] 
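`_handle_stop` below recomputes usage depending on why generation stopped: prompt tokens are skipped for annotation replies, and completion tokens are only counted when the user stopped manually. A small sketch of that branching, with stand-in names for the enum and token counter:

```python
from enum import Enum


class StopBy(Enum):
    USER_MANUAL = "user-manual"
    ANNOTATION_REPLY = "annotation-reply"
    OUTPUT_MODERATION = "output-moderation"


def usage_on_stop(stopped_by: StopBy, prompt_messages: list, answer_message, count_tokens) -> tuple[int, int]:
    """Return (prompt_tokens, completion_tokens) to bill for a stopped run."""
    prompt_tokens = 0
    if stopped_by != StopBy.ANNOTATION_REPLY:
        prompt_tokens = count_tokens(prompt_messages)
    completion_tokens = 0
    if stopped_by == StopBy.USER_MANUAL:
        completion_tokens = count_tokens([answer_message])
    return prompt_tokens, completion_tokens
```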
- ) + completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message]) credentials = model_config.credentials @@ -418,10 +405,7 @@ def _handle_stop(self, event: QueueStopEvent) -> None: model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) self._task_state.llm_result.usage = model_type_instance._calc_response_usage( - model, - credentials, - prompt_tokens, - completion_tokens + model, credentials, prompt_tokens, completion_tokens ) def _message_end_to_stream_response(self) -> MessageEndStreamResponse: @@ -429,16 +413,14 @@ def _message_end_to_stream_response(self) -> MessageEndStreamResponse: Message end to stream response. :return: """ - self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage) + self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) extras = {} if self._task_state.metadata: - extras['metadata'] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata return MessageEndStreamResponse( - task_id=self._application_generate_entity.task_id, - id=self._message.id, - **extras + task_id=self._application_generate_entity.task_id, id=self._message.id, **extras ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -449,9 +431,7 @@ def _agent_message_to_stream_response(self, answer: str, message_id: str) -> Age :return: """ return AgentMessageStreamResponse( - task_id=self._application_generate_entity.task_id, - id=message_id, - answer=answer + task_id=self._application_generate_entity.task_id, id=message_id, answer=answer ) def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]: @@ -461,9 +441,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op :return: """ agent_thought: MessageAgentThought = ( - db.session.query(MessageAgentThought) - .filter(MessageAgentThought.id == event.agent_thought_id) - .first() + db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) db.session.refresh(agent_thought) db.session.close() @@ -478,7 +456,7 @@ def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Op tool=agent_thought.tool, tool_labels=agent_thought.tool_labels, tool_input=agent_thought.tool_input, - message_files=agent_thought.files + message_files=agent_thought.files, ) return None @@ -500,15 +478,15 @@ def _handle_output_moderation_chunk(self, text: str) -> bool: prompt_messages=self._task_state.llm_result.prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) - ) + message=AssistantPromptMessage(content=self._task_state.llm_result.message.content), + ), ) - ), PublishFrom.TASK_PIPELINE + ), + PublishFrom.TASK_PIPELINE, ) self._queue_manager.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), - PublishFrom.TASK_PIPELINE + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE ) return True else: diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 76c50809cf340c..236eebf0b85ff6 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -1,14 +1,15 @@ +import logging from threading import 
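When output moderation rewrites the answer, `_handle_output_moderation_chunk` publishes one chunk carrying the full replacement text followed by a stop event, and returns True so the caller stops streaming. A rough sketch, with `publish` standing in for the queue manager and plain dicts for the queue events:

```python
def handle_output_moderation_chunk(publish, final_answer: str) -> bool:
    # One chunk carrying the fully moderated answer replaces the stream so far.
    publish({"event": "llm_chunk", "text": final_answer})
    # Then the queue is told to stop with an output-moderation reason.
    publish({"event": "stop", "stopped_by": "output-moderation"})
    return True  # the caller stops yielding further model output
```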
Thread from typing import Optional, Union from flask import Flask, current_app +from configs import dify_config from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, ChatAppGenerateEntity, CompletionAppGenerateEntity, - InvokeFrom, ) from core.app.entities.queue_entities import ( QueueAnnotationReplyEvent, @@ -16,11 +17,11 @@ QueueRetrieverResourcesEvent, ) from core.app.entities.task_entities import ( - AdvancedChatTaskState, EasyUITaskState, MessageFileStreamResponse, MessageReplaceStreamResponse, MessageStreamResponse, + WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator from core.tools.tool_file_manager import ToolFileManager @@ -31,12 +32,9 @@ class MessageCycleManage: _application_generate_entity: Union[ - ChatAppGenerateEntity, - CompletionAppGenerateEntity, - AgentChatAppGenerateEntity, - AdvancedChatAppGenerateEntity + ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity ] - _task_state: Union[EasyUITaskState, AdvancedChatTaskState] + _task_state: Union[EasyUITaskState, WorkflowTaskState] def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: """ @@ -45,17 +43,23 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) -> :param query: query :return: thread """ + if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): + return None + is_first_message = self._application_generate_entity.conversation_id is None extras = self._application_generate_entity.extras - auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True) + auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True) if auto_generate_conversation_name and is_first_message: # start generate thread - thread = Thread(target=self._generate_conversation_name_worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'conversation_id': conversation.id, - 'query': query - }) + thread = Thread( + target=self._generate_conversation_name_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "conversation_id": conversation.id, + "query": query, + }, + ) thread.start() @@ -63,17 +67,13 @@ def _generate_conversation_name(self, conversation: Conversation, query: str) -> return None - def _generate_conversation_name_worker(self, - flask_app: Flask, - conversation_id: str, - query: str): + def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): with flask_app.app_context(): # get conversation and message - conversation = ( - db.session.query(Conversation) - .filter(Conversation.id == conversation_id) - .first() - ) + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + + if not conversation: + return if conversation.mode != AppMode.COMPLETION.value: app_model = conversation.app @@ -84,7 +84,9 @@ def _generate_conversation_name_worker(self, try: name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query) conversation.name = name - except: + except Exception as e: + if dify_config.DEBUG: + logging.exception(f"generate conversation name failed: {e}") pass db.session.merge(conversation) @@ -100,12 +102,9 @@ def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account - 
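The conversation-name generator runs on a background thread, so the worker needs its own Flask application context; the PR also adds a guard for a conversation row that has disappeared. A sketch under those assumptions, with hypothetical `load_conversation`/`generate_name`/`save` callables in place of Dify's DB and LLM calls:

```python
from threading import Thread
from typing import Callable

from flask import Flask


def start_naming_thread(
    flask_app: Flask,
    conversation_id: str,
    query: str,
    load_conversation: Callable,  # hypothetical DB lookup
    generate_name: Callable,      # hypothetical LLM call
    save: Callable,               # hypothetical commit
) -> Thread:
    def worker() -> None:
        # The worker runs outside the request, so it needs its own app context.
        with flask_app.app_context():
            conversation = load_conversation(conversation_id)
            if not conversation:  # guard added by the PR: row may be gone
                return
            try:
                conversation.name = generate_name(query)
            except Exception:
                pass  # naming is best-effort; the PR only logs this in DEBUG
            save(conversation)

    thread = Thread(target=worker)
    thread.start()
    return thread
```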
self._task_state.metadata['annotation_reply'] = { - 'id': annotation.id, - 'account': { - 'id': annotation.account_id, - 'name': account.name if account else 'Dify user' - } + self._task_state.metadata["annotation_reply"] = { + "id": annotation.id, + "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, } return annotation @@ -119,28 +118,7 @@ def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> No :return: """ if self._application_generate_entity.app_config.additional_features.show_retrieve_source: - self._task_state.metadata['retriever_resources'] = event.retriever_resources - - def _get_response_metadata(self) -> dict: - """ - Get response metadata by invoke from. - :return: - """ - metadata = {} - - # show_retrieve_source - if 'retriever_resources' in self._task_state.metadata: - metadata['retriever_resources'] = self._task_state.metadata['retriever_resources'] - - # show annotation reply - if 'annotation_reply' in self._task_state.metadata: - metadata['annotation_reply'] = self._task_state.metadata['annotation_reply'] - - # show usage - if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: - metadata['usage'] = self._task_state.metadata['usage'] - - return metadata + self._task_state.metadata["retriever_resources"] = event.retriever_resources def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ @@ -148,27 +126,23 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti :param event: event :return: """ - message_file: MessageFile = ( - db.session.query(MessageFile) - .filter(MessageFile.id == event.message_file_id) - .first() - ) + message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() if message_file: # get tool file id - tool_file_id = message_file.url.split('/')[-1] + tool_file_id = message_file.url.split("/")[-1] # trim extension - tool_file_id = tool_file_id.split('.')[0] + tool_file_id = tool_file_id.split(".")[0] # get extension - if '.' in message_file.url: + if "." in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' if len(extension) > 10: - extension = '.bin' + extension = ".bin" else: - extension = '.bin' + extension = ".bin" # add sign url to local file - if message_file.url.startswith('http'): + if message_file.url.startswith("http"): url = message_file.url else: url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) @@ -177,13 +151,15 @@ def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Opti task_id=self._application_generate_entity.task_id, id=message_file.id, type=message_file.type, - belongs_to=message_file.belongs_to or 'user', - url=url + belongs_to=message_file.belongs_to or "user", + url=url, ) return None - def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse: + def _message_to_stream_response( + self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None + ) -> MessageStreamResponse: """ Message to stream response. 
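The message-file branch above rebuilds a signed URL for local tool files from the id and extension embedded in the stored URL, while remote `http` URLs pass through untouched. A sketch of that parsing, with `sign_file` standing in for `ToolFileManager.sign_file`:

```python
def resolve_file_url(stored_url: str, sign_file) -> str:
    """Return a directly usable URL for a message file."""
    if stored_url.startswith("http"):
        return stored_url  # remote file: pass through unchanged
    # Local tool file: recover the id (last path segment, extension trimmed).
    tool_file_id = stored_url.split("/")[-1].split(".")[0]
    if "." in stored_url:
        extension = "." + stored_url.split(".")[-1]
        if len(extension) > 10:  # implausibly long suffix: treat as binary
            extension = ".bin"
    else:
        extension = ".bin"
    return sign_file(tool_file_id=tool_file_id, extension=extension)
```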
:param answer: answer @@ -193,7 +169,8 @@ def _message_to_stream_response(self, answer: str, message_id: str) -> MessageSt return MessageStreamResponse( task_id=self._application_generate_entity.task_id, id=message_id, - answer=answer + answer=answer, + from_variable_selector=from_variable_selector, ) def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: @@ -202,7 +179,4 @@ def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStre :param answer: answer :return: """ - return MessageReplaceStreamResponse( - task_id=self._application_generate_entity.task_id, - answer=answer - ) + return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4935c43ac437e4..b89edf9079f043 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,94 +1,112 @@ import json import time +from collections.abc import Mapping, Sequence from datetime import datetime, timezone -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast -from core.app.entities.app_invoke_entities import InvokeFrom +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueStopEvent, - QueueWorkflowFailedEvent, - QueueWorkflowSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, ) from core.app.entities.task_entities import ( - NodeExecutionInfo, + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, NodeFinishStreamResponse, NodeStartStreamResponse, + ParallelBranchFinishedStreamResponse, + ParallelBranchStartStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, + WorkflowTaskState, ) -from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage -from core.file.file_obj import FileVar +from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.enums import SystemVariableKey +from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( - CreatedByRole, Workflow, WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, WorkflowRunStatus, - WorkflowRunTriggeredFrom, ) -from services.workflow_service 
import WorkflowService - - -class WorkflowCycleManage(WorkflowIterationCycleManage): - def _init_workflow_run(self, workflow: Workflow, - triggered_from: WorkflowRunTriggeredFrom, - user: Union[Account, EndUser], - user_inputs: dict, - system_inputs: Optional[dict] = None) -> WorkflowRun: - """ - Init workflow run - :param workflow: Workflow instance - :param triggered_from: triggered from - :param user: account or end user - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :return: - """ - max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \ - .filter(WorkflowRun.tenant_id == workflow.tenant_id) \ - .filter(WorkflowRun.app_id == workflow.app_id) \ - .scalar() or 0 + + +class WorkflowCycleManage: + _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] + _workflow: Workflow + _user: Union[Account, EndUser] + _task_state: WorkflowTaskState + _workflow_system_variables: dict[SystemVariableKey, Any] + _wip_workflow_node_executions: dict[str, WorkflowNodeExecution] + + def _handle_workflow_run_start(self) -> WorkflowRun: + max_sequence = ( + db.session.query(db.func.max(WorkflowRun.sequence_number)) + .filter(WorkflowRun.tenant_id == self._workflow.tenant_id) + .filter(WorkflowRun.app_id == self._workflow.app_id) + .scalar() + or 0 + ) new_sequence_number = max_sequence + 1 - inputs = {**user_inputs} - for key, value in (system_inputs or {}).items(): - if key.value == 'conversation': + inputs = {**self._application_generate_entity.inputs} + for key, value in (self._workflow_system_variables or {}).items(): + if key.value == "conversation": continue - inputs[f'sys.{key.value}'] = value - inputs = WorkflowEngineManager.handle_special_values(inputs) + inputs[f"sys.{key.value}"] = value + + inputs = WorkflowEntry.handle_special_values(inputs) + + triggered_from = ( + WorkflowRunTriggeredFrom.DEBUGGING + if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER + else WorkflowRunTriggeredFrom.APP_RUN + ) # init workflow run - workflow_run = WorkflowRun( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - sequence_number=new_sequence_number, - workflow_id=workflow.id, - type=workflow.type, - triggered_from=triggered_from.value, - version=workflow.version, - graph=workflow.graph, - inputs=json.dumps(inputs), - status=WorkflowRunStatus.RUNNING.value, - created_by_role=(CreatedByRole.ACCOUNT.value - if isinstance(user, Account) else CreatedByRole.END_USER.value), - created_by=user.id + workflow_run = WorkflowRun() + workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] + if workflow_run_id: + workflow_run.id = workflow_run_id + workflow_run.tenant_id = self._workflow.tenant_id + workflow_run.app_id = self._workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = self._workflow.id + workflow_run.type = self._workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = self._workflow.version + workflow_run.graph = self._workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING.value + workflow_run.created_by_role = ( + CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value ) + workflow_run.created_by = self._user.id db.session.add(workflow_run) db.session.commit() @@ -97,33 +115,39 @@ def _init_workflow_run(self, workflow: Workflow, return workflow_run - def 
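`_handle_workflow_run_start` above allocates the next per-app sequence number from a MAX() query and flattens system variables into the persisted inputs under a `sys.` prefix, skipping the conversation variable. A sketch of both pieces, with plain string keys standing in for the `SystemVariableKey` enum:

```python
from typing import Any


def build_run_inputs(user_inputs: dict[str, Any], system_inputs: dict[str, Any]) -> dict[str, Any]:
    """Flatten system variables into the stored inputs under a sys. prefix."""
    inputs = {**user_inputs}
    for key, value in system_inputs.items():
        if key == "conversation":  # skipped above as well
            continue
        inputs[f"sys.{key}"] = value
    return inputs


def next_sequence_number(existing_max: int | None) -> int:
    # Comes from a MAX() query above; `or 0` guards the first run (MAX is NULL).
    return (existing_max or 0) + 1
```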
_workflow_run_success( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_success( + self, + workflow_run: WorkflowRun, + start_at: float, total_tokens: int, total_steps: int, - outputs: Optional[str] = None, + outputs: Mapping[str, Any] | None = None, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run success :param workflow_run: workflow run + :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param outputs: outputs :param conversation_id: conversation id :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + + outputs = WorkflowEntry.handle_special_values(outputs) + workflow_run.status = WorkflowRunStatus.SUCCEEDED.value - workflow_run.outputs = outputs - workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) + workflow_run.outputs = json.dumps(outputs or {}) + workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() db.session.refresh(workflow_run) - db.session.close() if trace_manager: trace_manager.add_trace_task( @@ -135,34 +159,64 @@ def _workflow_run_success( ) ) + db.session.close() + return workflow_run - def _workflow_run_failed( - self, workflow_run: WorkflowRun, + def _handle_workflow_run_failed( + self, + workflow_run: WorkflowRun, + start_at: float, total_tokens: int, total_steps: int, status: WorkflowRunStatus, error: str, conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowRun: """ Workflow run failed :param workflow_run: workflow run + :param start_at: start time :param total_tokens: total tokens :param total_steps: total steps :param status: status :param error: error message :return: """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + workflow_run.status = status.value workflow_run.error = error - workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id) + workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() + + running_workflow_node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, + WorkflowNodeExecution.app_id == workflow_run.app_id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == workflow_run.id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + ) + .all() + ) + + for workflow_node_execution in running_workflow_node_executions: + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.elapsed_time = ( + workflow_node_execution.finished_at - workflow_node_execution.created_at + ).total_seconds() + db.session.commit() + db.session.refresh(workflow_run) db.session.close() @@ -178,117 +232,142 @@ 
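The failure path now also closes out any node executions still marked RUNNING, stamping them with the run's error and an elapsed time derived from their own `created_at`. A self-contained sketch of that cleanup, with a dataclass standing in for the SQLAlchemy model:

```python
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class NodeExecution:
    created_at: datetime
    status: str = "running"
    error: str | None = None
    finished_at: datetime | None = None
    elapsed_time: float = 0.0


def fail_running_executions(executions: list[NodeExecution], error: str) -> None:
    """Close out executions the failed run left in the RUNNING state."""
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    for execution in executions:
        if execution.status != "running":
            continue
        execution.status = "failed"
        execution.error = error
        execution.finished_at = now
        execution.elapsed_time = (now - execution.created_at).total_seconds()
```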
def _workflow_run_failed( return workflow_run - def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution: - """ - Init workflow node execution from workflow run - :param workflow_run: workflow run - :param node_id: node id - :param node_type: node type - :param node_title: node title - :param node_run_index: run index - :param predecessor_node_id: predecessor node id if exists - :return: - """ + def _handle_node_execution_start( + self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + ) -> WorkflowNodeExecution: # init workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() + with Session(db.engine, expire_on_commit=False) as session: + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.predecessor_node_id = event.predecessor_node_id + workflow_node_execution.index = event.node_run_index + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) + session.add(workflow_node_execution) + session.commit() + session.refresh(workflow_node_execution) + + self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution return workflow_node_execution - def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: """ Workflow node execution success - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param 
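Started node executions are now cached in `_wip_workflow_node_executions`, keyed by `node_execution_id`, so the success and failure handlers can retrieve them without re-querying; a missing key raises, mirroring `_refetch_workflow_node_execution` later in this file. A minimal sketch of that cache:

```python
class WipNodeExecutions:
    """In-progress node executions, keyed by node_execution_id."""

    def __init__(self) -> None:
        self._wip: dict[str, object] = {}

    def start(self, node_execution_id: str, execution: object) -> None:
        self._wip[node_execution_id] = execution

    def refetch(self, node_execution_id: str) -> object:
        execution = self._wip.get(node_execution_id)
        if not execution:
            raise Exception(f"Workflow node execution not found: {node_execution_id}")
        return execution

    def finish(self, node_execution_id: str) -> object:
        # Success/failure handlers pop the entry once the row is finalized.
        return self._wip.pop(node_execution_id)
```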
inputs: inputs - :param process_data: process data - :param outputs: outputs - :param execution_metadata: execution metadata + :param event: queue node succeeded event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) + outputs = WorkflowEntry.handle_special_values(event.outputs) + execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) + finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + elapsed_time = (finished_at - event.start_at).total_seconds() + + db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( + { + WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value, + WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, + WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, + WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, + WorkflowNodeExecution.execution_metadata: execution_metadata, + WorkflowNodeExecution.finished_at: finished_at, + WorkflowNodeExecution.elapsed_time: elapsed_time, + } + ) + + db.session.commit() + db.session.close() + process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None - workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.execution_metadata = execution_metadata + workflow_node_execution.finished_at = finished_at + workflow_node_execution.elapsed_time = elapsed_time - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() + self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) return workflow_node_execution - def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution, - start_at: float, - error: str, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None - ) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_failed( + self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent + ) -> WorkflowNodeExecution: """ Workflow node execution failed - :param workflow_node_execution: workflow node execution - :param start_at: start time - :param error: error message + :param event: queue node failed event :return: """ - inputs = WorkflowEngineManager.handle_special_values(inputs) - outputs = WorkflowEngineManager.handle_special_values(outputs) + workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) + + inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = 
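The finish-time write above is issued as a single `query(...).update({...})` carrying all closing columns, rather than mutating the loaded ORM instance column by column. A self-contained SQLAlchemy sketch of the pattern, with a simplified stand-in table:

```python
from datetime import datetime, timezone

from sqlalchemy import Column, DateTime, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class NodeExecution(Base):
    __tablename__ = "node_executions"
    id = Column(String, primary_key=True)
    status = Column(String, default="running")
    outputs = Column(String, nullable=True)
    finished_at = Column(DateTime, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(NodeExecution(id="exec-1"))
    session.commit()

    # One UPDATE sets every closing column at once.
    session.query(NodeExecution).filter(NodeExecution.id == "exec-1").update(
        {
            NodeExecution.status: "succeeded",
            NodeExecution.outputs: "{}",
            NodeExecution.finished_at: datetime.now(timezone.utc).replace(tzinfo=None),
        }
    )
    session.commit()
```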
WorkflowEntry.handle_special_values(event.process_data) + outputs = WorkflowEntry.handle_special_values(event.outputs) + finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + elapsed_time = (finished_at - event.start_at).total_seconds() + execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) + db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( + { + WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, + WorkflowNodeExecution.error: event.error, + WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, + WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, + WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, + WorkflowNodeExecution.finished_at: finished_at, + WorkflowNodeExecution.elapsed_time: elapsed_time, + WorkflowNodeExecution.execution_metadata: execution_metadata, + } + ) + + db.session.commit() + db.session.close() + process_data = WorkflowEntry.handle_special_values(event.process_data) workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.error = event.error workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \ - if execution_metadata else None + workflow_node_execution.finished_at = finished_at + workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution.execution_metadata = execution_metadata - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() + self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) return workflow_node_execution - def _workflow_start_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowStartStreamResponse: + ################################################# + # to stream responses # + ################################################# + + def _workflow_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowStartStreamResponse: """ Workflow start to stream response. :param task_id: task id @@ -302,13 +381,14 @@ def _workflow_start_to_stream_response(self, task_id: str, id=workflow_run.id, workflow_id=workflow_run.workflow_id, sequence_number=workflow_run.sequence_number, - inputs=workflow_run.inputs_dict, - created_at=int(workflow_run.created_at.timestamp()) - ) + inputs=workflow_run.inputs_dict or {}, + created_at=int(workflow_run.created_at.timestamp()), + ), ) - def _workflow_finish_to_stream_response(self, task_id: str, - workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse: + def _workflow_finish_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun + ) -> WorkflowFinishStreamResponse: """ Workflow finish to stream response. 
:param task_id: task id @@ -348,14 +428,13 @@ def _workflow_finish_to_stream_response(self, task_id: str, created_by=created_by, created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict) - ) + files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}), + ), ) - def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution) \ - -> NodeStartStreamResponse: + def _workflow_node_start_to_stream_response( + self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution + ) -> Optional[NodeStartStreamResponse]: """ Workflow node start to stream response. :param event: queue node started event @@ -363,6 +442,9 @@ def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + return None + response = NodeStartStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -374,29 +456,43 @@ def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent, index=workflow_node_execution.index, predecessor_node_id=workflow_node_execution.predecessor_node_id, inputs=workflow_node_execution.inputs_dict, - created_at=int(workflow_node_execution.created_at.timestamp()) - ) + created_at=int(workflow_node_execution.created_at.timestamp()), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + parallel_run_id=event.parallel_mode_run_id, + ), ) # extras logic if event.node_type == NodeType.TOOL: node_data = cast(ToolNodeData, event.node_data) - response.data.extras['icon'] = ToolManager.get_tool_icon( + response.data.extras["icon"] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=node_data.provider_type, - provider_id=node_data.provider_id + provider_id=node_data.provider_id, ) return response - def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \ - -> NodeFinishStreamResponse: + def _workflow_node_finish_to_stream_response( + self, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[NodeFinishStreamResponse]: """ Workflow node finish to stream response. 
+ :param event: queue node succeeded or failed event :param task_id: task id :param workflow_node_execution: workflow node execution :return: """ + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + return None + return NodeFinishStreamResponse( task_id=task_id, workflow_run_id=workflow_node_execution.workflow_run_id, @@ -416,183 +512,158 @@ def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_e execution_metadata=workflow_node_execution.execution_metadata_dict, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict) - ) + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + ), ) - def _handle_workflow_start(self) -> WorkflowRun: - self._task_state.start_at = time.perf_counter() - - workflow_run = self._init_workflow_run( - workflow=self._workflow, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN, - user=self._user, - user_inputs=self._application_generate_entity.inputs, - system_inputs=self._workflow_system_variables + def _workflow_parallel_branch_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent + ) -> ParallelBranchStartStreamResponse: + """ + Workflow parallel branch start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: parallel branch run started event + :return: + """ + return ParallelBranchStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchStartStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + created_at=int(time.time()), + ), ) - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run - - def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution: - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_node_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=event.node_type, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - predecessor_node_id=event.predecessor_node_id + def _workflow_parallel_branch_finished_to_stream_response( + self, + task_id: str, + workflow_run: WorkflowRun, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, + ) -> ParallelBranchFinishedStreamResponse: + """ + Workflow parallel branch finished to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: parallel branch run succeeded or failed event + :return: + """ + return ParallelBranchFinishedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=ParallelBranchFinishedStreamResponse.Data( + 
parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", + error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, + created_at=int(time.time()), + ), ) - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=event.node_type, - start_at=time.perf_counter() + def _workflow_iteration_start_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent + ) -> IterationNodeStartStreamResponse: + """ + Workflow iteration start to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration start event + :return: + """ + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + metadata=event.metadata or {}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ), ) - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._task_state.total_steps += 1 - - db.session.close() - - return workflow_node_execution - - def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution: - current_node_execution = self._task_state.ran_node_execution_infos[event.node_id] - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first() - - execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None - - if self._iteration_state and self._iteration_state.current_iterations: - if not execution_metadata: - execution_metadata = {} - current_iteration_data = None - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if data.parent_iteration_id == None: - current_iteration_data = data - break - - if current_iteration_data: - execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id - execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index - - if isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_node_execution_success( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - inputs=event.inputs, - process_data=event.process_data, - outputs=event.outputs, - execution_metadata=execution_metadata - ) - - if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - self._task_state.total_tokens += ( - int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))) - - if self._iteration_state: - for iteration_node_id in self._iteration_state.current_iterations: - data = self._iteration_state.current_iterations[iteration_node_id] - if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - data.total_tokens += 
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) + def _workflow_iteration_next_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent + ) -> IterationNodeNextStreamResponse: + """ + Workflow iteration next to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration next event + :return: + """ + return IterationNodeNextStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeNextStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + index=event.index, + pre_iteration_output=event.output, + created_at=int(time.time()), + extras={}, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parallel_mode_run_id=event.parallel_mode_run_id, + ), + ) - if workflow_node_execution.node_type == NodeType.LLM.value: - outputs = workflow_node_execution.outputs_dict - usage_dict = outputs.get('usage', {}) - self._task_state.metadata['usage'] = usage_dict - else: - workflow_node_execution = self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=current_node_execution.start_at, - error=event.error, - inputs=event.inputs, - process_data=event.process_data, + def _workflow_iteration_completed_to_stream_response( + self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent + ) -> IterationNodeCompletedStreamResponse: + """ + Workflow iteration completed to stream response + :param task_id: task id + :param workflow_run: workflow run + :param event: iteration completed event + :return: + """ + return IterationNodeCompletedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_run.id, + data=IterationNodeCompletedStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, outputs=event.outputs, - execution_metadata=execution_metadata - ) - - db.session.close() - - return workflow_node_execution - - def _handle_workflow_finished( - self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent, - conversation_id: Optional[str] = None, - trace_manager: Optional[TraceQueueManager] = None - ) -> Optional[WorkflowRun]: - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.id == self._task_state.workflow_run_id).first() - if not workflow_run: - return None - - if conversation_id is None: - conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id') - if isinstance(event, QueueStopEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.STOPPED, - error='Workflow stopped.', - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - latest_node_execution_info = self._task_state.latest_node_execution_info - if latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first() - if (workflow_node_execution - and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value): - self._workflow_node_execution_failed( - workflow_node_execution=workflow_node_execution, - start_at=latest_node_execution_info.start_at, - error='Workflow stopped.' 
- ) - elif isinstance(event, QueueWorkflowFailedEvent): - workflow_run = self._workflow_run_failed( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - else: - if self._task_state.latest_node_execution_info: - workflow_node_execution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first() - outputs = workflow_node_execution.outputs - else: - outputs = None - - workflow_run = self._workflow_run_success( - workflow_run=workflow_run, - total_tokens=self._task_state.total_tokens, - total_steps=self._task_state.total_steps, - outputs=outputs, - conversation_id=conversation_id, - trace_manager=trace_manager - ) - - self._task_state.workflow_run_id = workflow_run.id - - db.session.close() - - return workflow_run + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + status=WorkflowNodeExecutionStatus.SUCCEEDED + if event.error is None + else WorkflowNodeExecutionStatus.FAILED, + error=None, + elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), + total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, + execution_metadata=event.metadata, + finished_at=int(time.time()), + steps=event.steps, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ), + ) - def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: + def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]: """ Fetch files from node outputs :param outputs_dict: node outputs dict @@ -601,15 +672,15 @@ def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]: if not outputs_dict: return [] - files = [] - for output_var, output_value in outputs_dict.items(): - file_vars = self._fetch_files_from_variable_value(output_value) - if file_vars: - files.extend(file_vars) + files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + # Remove None + files = [file for file in files if file] + # Flatten list + files = [file for sublist in files for file in sublist] return files - def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]: + def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: """ Fetch files from variable value :param value: variable value @@ -621,17 +692,17 @@ def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dic files = [] if isinstance(value, list): for item in value: - file_var = self._get_file_var_from_value(item) - if file_var: - files.append(file_var) + file = self._get_file_var_from_value(item) + if file: + files.append(file) elif isinstance(value, dict): - file_var = self._get_file_var_from_value(value) - if file_var: - files.append(file_var) + file = self._get_file_var_from_value(value) + if file: + files.append(file) return files - def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: + def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None: """ Get file var from value :param value: variable value @@ -640,10 +711,33 @@ def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]: if not value: return None - if 
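`_fetch_files_from_node_outputs` above is rewritten as a comprehension pipeline: collect a file list per output value, drop empties, then flatten. A sketch with `fetch_from_value` standing in for `_fetch_files_from_variable_value`:

```python
from typing import Any


def fetch_files(outputs: dict[str, Any], fetch_from_value) -> list[Any]:
    nested = [fetch_from_value(value) for value in outputs.values()]
    nested = [group for group in nested if group]  # remove None / empty lists
    return [file for group in nested for file in group]  # flatten
```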
isinstance(value, dict): - if '__variant' in value and value['__variant'] == FileVar.__name__: - return value - elif isinstance(value, FileVar): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + return value + elif isinstance(value, File): return value.to_dict() - return None + def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: + """ + Refetch workflow run + :param workflow_run_id: workflow run id + :return: + """ + workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + + if not workflow_run: + raise Exception(f"Workflow run not found: {workflow_run_id}") + + return workflow_run + + def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + """ + Refetch workflow node execution + :param node_execution_id: workflow node execution id + :return: + """ + workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id) + + if not workflow_node_execution: + raise Exception(f"Workflow node execution not found: {node_execution_id}") + + return workflow_node_execution diff --git a/api/core/app/task_pipeline/workflow_cycle_state_manager.py b/api/core/app/task_pipeline/workflow_cycle_state_manager.py index 8baa8ba09e4b00..e69de29bb2d1d6 100644 --- a/api/core/app/task_pipeline/workflow_cycle_state_manager.py +++ b/api/core/app/task_pipeline/workflow_cycle_state_manager.py @@ -1,16 +0,0 @@ -from typing import Any, Union - -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState -from core.workflow.enums import SystemVariable -from models.account import Account -from models.model import EndUser -from models.workflow import Workflow - - -class WorkflowCycleStateManager: - _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] - _workflow: Workflow - _user: Union[Account, EndUser] - _task_state: Union[AdvancedChatTaskState, WorkflowTaskState] - _workflow_system_variables: dict[SystemVariable, Any] diff --git a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py b/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py deleted file mode 100644 index aff187071417c7..00000000000000 --- a/api/core/app/task_pipeline/workflow_iteration_cycle_manage.py +++ /dev/null @@ -1,290 +0,0 @@ -import json -import time -from collections.abc import Generator -from datetime import datetime, timezone -from typing import Optional, Union - -from core.app.entities.queue_entities import ( - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, -) -from core.app.entities.task_entities import ( - IterationNodeCompletedStreamResponse, - IterationNodeNextStreamResponse, - IterationNodeStartStreamResponse, - NodeExecutionInfo, - WorkflowIterationState, -) -from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager -from core.workflow.entities.node_entities import NodeType -from core.workflow.workflow_engine_manager import WorkflowEngineManager -from extensions.ext_database import db -from models.workflow import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, -) - - -class WorkflowIterationCycleManage(WorkflowCycleStateManager): - _iteration_state: WorkflowIterationState = None - - def _init_iteration_state(self) -> WorkflowIterationState: - if not 
self._iteration_state: - self._iteration_state = WorkflowIterationState( - current_iterations={} - ) - - def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \ - -> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]: - """ - Handle iteration to stream response - :param task_id: task id - :param event: iteration event - :return: - """ - if isinstance(event, QueueIterationStartEvent): - return IterationNodeStartStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeStartStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - created_at=int(time.time()), - extras={}, - inputs=event.inputs, - metadata=event.metadata - ) - ) - elif isinstance(event, QueueIterationNextEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeNextStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeNextStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - index=event.index, - pre_iteration_output=event.output, - created_at=int(time.time()), - extras={} - ) - ) - elif isinstance(event, QueueIterationCompletedEvent): - current_iteration = self._iteration_state.current_iterations[event.node_id] - - return IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=current_iteration.node_data.title, - outputs=event.outputs, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.SUCCEEDED, - error=None, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) - - def _init_iteration_execution_from_workflow_run(self, - workflow_run: WorkflowRun, - node_id: str, - node_type: NodeType, - node_title: str, - node_run_index: int = 1, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None - ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=workflow_run.tenant_id, - app_id=workflow_run.app_id, - workflow_id=workflow_run.workflow_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - workflow_run_id=workflow_run.id, - predecessor_node_id=predecessor_node_id, - index=node_run_index, - node_id=node_id, - node_type=node_type.value, - inputs=json.dumps(inputs) if inputs else None, - title=node_title, - status=WorkflowNodeExecutionStatus.RUNNING.value, - created_by_role=workflow_run.created_by_role, - created_by=workflow_run.created_by, - execution_metadata=json.dumps({ - 'started_run_index': node_run_index + 1, - 'current_index': 0, - 'steps_boundary': [], - }), - created_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) - - db.session.add(workflow_node_execution) - db.session.commit() - db.session.refresh(workflow_node_execution) - db.session.close() - - return workflow_node_execution - - def 
_handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution: - if isinstance(event, QueueIterationStartEvent): - return self._handle_iteration_started(event) - elif isinstance(event, QueueIterationNextEvent): - return self._handle_iteration_next(event) - elif isinstance(event, QueueIterationCompletedEvent): - return self._handle_iteration_completed(event) - - def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution: - self._init_iteration_state() - - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first() - workflow_node_execution = self._init_iteration_execution_from_workflow_run( - workflow_run=workflow_run, - node_id=event.node_id, - node_type=NodeType.ITERATION, - node_title=event.node_data.title, - node_run_index=event.node_run_index, - inputs=event.inputs, - predecessor_node_id=event.predecessor_node_id - ) - - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info - self._task_state.latest_node_execution_info = latest_node_execution_info - - self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data( - parent_iteration_id=None, - iteration_id=event.node_id, - current_index=0, - iteration_steps_boundary=[], - node_execution_id=workflow_node_execution.id, - started_at=time.perf_counter(), - inputs=event.inputs, - total_tokens=0, - node_data=event.node_data - ) - - db.session.close() - - return workflow_node_execution - - def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution: - if event.node_id not in self._iteration_state.current_iterations: - return - current_iteration = self._iteration_state.current_iterations[event.node_id] - current_iteration.current_index = event.index - current_iteration.iteration_steps_boundary.append(event.node_run_index) - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['current_index'] = event.index - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - db.session.close() - - def _handle_iteration_completed(self, event: QueueIterationCompletedEvent): - if event.node_id not in self._iteration_state.current_iterations: - return - - current_iteration = self._iteration_state.current_iterations[event.node_id] - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - 
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict - if original_node_execution_metadata: - original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary - original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens - workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata) - - db.session.commit() - - # remove current iteration - self._iteration_state.current_iterations.pop(event.node_id, None) - - # set latest node execution info - latest_node_execution_info = NodeExecutionInfo( - workflow_node_execution_id=workflow_node_execution.id, - node_type=NodeType.ITERATION, - start_at=time.perf_counter() - ) - - self._task_state.latest_node_execution_info = latest_node_execution_info - - db.session.close() - - def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]: - """ - Handle iteration exception - """ - if not self._iteration_state or not self._iteration_state.current_iterations: - return - - for node_id, current_iteration in self._iteration_state.current_iterations.items(): - workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == current_iteration.node_execution_id - ).first() - - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at - - db.session.commit() - db.session.close() - - yield IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=self._task_state.workflow_run_id, - data=IterationNodeCompletedStreamResponse.Data( - id=node_id, - node_id=node_id, - node_type=NodeType.ITERATION.value, - title=current_iteration.node_data.title, - outputs={}, - created_at=int(time.time()), - extras={}, - inputs=current_iteration.inputs, - status=WorkflowNodeExecutionStatus.FAILED, - error=error, - elapsed_time=time.perf_counter() - current_iteration.started_at, - total_tokens=current_iteration.total_tokens, - execution_metadata={ - 'total_tokens': current_iteration.total_tokens, - }, - finished_at=int(time.time()), - steps=current_iteration.current_index - ) - ) diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 578996574739a8..d826edf6a0fc19 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,9 +1,9 @@ -import os from collections.abc import Mapping, Sequence from typing import Any, Optional, TextIO, Union from pydantic import BaseModel +from configs import dify_config from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.entities.tool_entities import ToolInvokeMessage @@ -16,31 +16,32 @@ "red": "31;1", } + def get_colored_text(text: str, color: str) -> str: """Get colored text.""" color_str = _TEXT_COLOR_MAPPING[color] return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" -def print_text( - text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None -) -> None: +def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: """Print text with highlighting and no end characters.""" text_to_print = get_colored_text(text, color) if color else text 
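The deleted iteration handlers above all update the stored execution metadata with the same decode-mutate-encode round trip. Condensed, with a plain object standing in for the `WorkflowNodeExecution` ORM row:

```python
import json


class Row:
    # stand-in for WorkflowNodeExecution; only the JSON blob matters here
    execution_metadata: str | None = None


def update_iteration_metadata(row: Row, steps_boundary: list[int], total_tokens: int) -> None:
    # decode the stored JSON (empty dict when unset), mutate, re-serialize
    metadata = json.loads(row.execution_metadata) if row.execution_metadata else {}
    metadata["steps_boundary"] = steps_boundary
    metadata["total_tokens"] = total_tokens
    row.execution_metadata = json.dumps(metadata)
```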
print(text_to_print, end=end, file=file) if file: file.flush() # ensure all printed content are written to file + class DifyAgentCallbackHandler(BaseModel): """Callback Handler that prints to std out.""" - color: Optional[str] = '' + + color: Optional[str] = "" current_loop: int = 1 def __init__(self, color: Optional[str] = None) -> None: super().__init__() """Initialize callback handler.""" # use a specific color is not specified - self.color = color or 'green' + self.color = color or "green" self.current_loop = 1 def on_tool_start( @@ -49,7 +50,8 @@ def on_tool_start( tool_inputs: Mapping[str, Any], ) -> None: """Do nothing.""" - print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) + if dify_config.DEBUG: + print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) def on_tool_end( self, @@ -58,14 +60,15 @@ def on_tool_end( tool_outputs: Sequence[ToolInvokeMessage], message_id: Optional[str] = None, timer: Optional[Any] = None, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> None: """If not the final action, print out observation.""" - print_text("\n[on_tool_end]\n", color=self.color) - print_text("Tool: " + tool_name + "\n", color=self.color) - print_text("Inputs: " + str(tool_inputs) + "\n", color=self.color) - print_text("Outputs: " + str(tool_outputs)[:1000] + "\n", color=self.color) - print_text("\n") + if dify_config.DEBUG: + print_text("\n[on_tool_end]\n", color=self.color) + print_text("Tool: " + tool_name + "\n", color=self.color) + print_text("Inputs: " + str(tool_inputs) + "\n", color=self.color) + print_text("Outputs: " + str(tool_outputs)[:1000] + "\n", color=self.color) + print_text("\n") if trace_manager: trace_manager.add_trace_task( @@ -79,37 +82,35 @@ def on_tool_end( ) ) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: """Do nothing.""" - print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red') + if dify_config.DEBUG: + print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") - def on_agent_start( - self, thought: str - ) -> None: + def on_agent_start(self, thought: str) -> None: """Run on agent start.""" - if thought: - print_text("\n[on_agent_start] \nCurrent Loop: " + \ - str(self.current_loop) + \ - "\nThought: " + thought + "\n", color=self.color) - else: - print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) - - def on_agent_finish( - self, color: Optional[str] = None, **kwargs: Any - ) -> None: + if dify_config.DEBUG: + if thought: + print_text( + "\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n", + color=self.color, + ) + else: + print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) + + def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: """Run on agent end.""" - print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) + if dify_config.DEBUG: + print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) self.current_loop += 1 @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not 
dify_config.DEBUG @property def ignore_chat_model(self) -> bool: """Whether to ignore chat model callbacks.""" - return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true' + return not dify_config.DEBUG diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8e1f496b226c14..1481578630f63b 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,4 +1,3 @@ - from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent @@ -11,11 +10,9 @@ class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" - def __init__(self, queue_manager: AppQueueManager, - app_id: str, - message_id: str, - user_id: str, - invoke_from: InvokeFrom) -> None: + def __init__( + self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom + ) -> None: self._queue_manager = queue_manager self._app_id = app_id self._message_id = message_id @@ -29,11 +26,12 @@ def on_query(self, query: str, dataset_id: str) -> None: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=self._app_id, - created_by_role=('account' - if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), - created_by=self._user_id + created_by_role=( + "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" + ), + created_by=self._user_id, ) db.session.add(dataset_query) @@ -43,18 +41,14 @@ def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) - # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() @@ -64,26 +58,25 @@ def return_retriever_resource_info(self, resource: list): for item in resource: dataset_retriever_resource = DatasetRetrieverResource( message_id=self._message_id, - position=item.get('position'), - dataset_id=item.get('dataset_id'), - dataset_name=item.get('dataset_name'), - document_id=item.get('document_id'), - document_name=item.get('document_name'), - data_source_type=item.get('data_source_type'), - segment_id=item.get('segment_id'), - score=item.get('score') if 'score' in item else None, - hit_count=item.get('hit_count') if 'hit_count' else None, - word_count=item.get('word_count') if 'word_count' in item else None, - segment_position=item.get('segment_position') if 'segment_position' in item else None, - index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, - content=item.get('content'), - retriever_from=item.get('retriever_from'), - created_by=self._user_id + 
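The callback-handler refactor above replaces per-call `os.environ` checks with the typed `dify_config.DEBUG` flag and gates every trace print behind it. The shape of that gate, with a module-level boolean standing in for the config object and the color table abbreviated to the one entry visible in the diff:

```python
_TEXT_COLOR_MAPPING = {"red": "31;1"}  # abbreviated; the handler defines more colors
DEBUG = False  # stand-in for dify_config.DEBUG


def get_colored_text(text: str, color: str) -> str:
    """Wrap text in the ANSI escape sequence for the given color."""
    return f"\u001b[{_TEXT_COLOR_MAPPING[color]}m{text}\u001b[0m"


def debug_print(text: str, color: str = "red") -> None:
    # every tracing hook checks the flag first, so production runs stay quiet
    if DEBUG:
        print(get_colored_text(text, color))
```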
position=item.get("position") or 0, + dataset_id=item.get("dataset_id"), + dataset_name=item.get("dataset_name"), + document_id=item.get("document_id"), + document_name=item.get("document_name"), + data_source_type=item.get("data_source_type"), + segment_id=item.get("segment_id"), + score=item.get("score") if "score" in item else None, + hit_count=item.get("hit_count") if "hit_count" in item else None, + word_count=item.get("word_count") if "word_count" in item else None, + segment_position=item.get("segment_position") if "segment_position" in item else None, + index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, + content=item.get("content"), + retriever_from=item.get("retriever_from"), + created_by=self._user_id, ) db.session.add(dataset_retriever_resource) db.session.commit() self._queue_manager.publish( - QueueRetrieverResourcesEvent(retriever_resources=resource), - PublishFrom.APPLICATION_MANAGER + QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/callback_handler/workflow_tool_callback_handler.py b/api/core/callback_handler/workflow_tool_callback_handler.py index 84bab7e1a3d22f..8ac12f72f29d6c 100644 --- a/api/core/callback_handler/workflow_tool_callback_handler.py +++ b/api/core/callback_handler/workflow_tool_callback_handler.py @@ -2,4 +2,4 @@ class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler): - """Callback Handler that prints to std out.""" \ No newline at end of file + """Callback Handler that prints to std out.""" diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py deleted file mode 100644 index b7e0cc0c2b2ae6..00000000000000 --- a/api/core/embedding/cached_embedding.py +++ /dev/null @@ -1,121 +0,0 @@ -import base64 -import logging -from typing import Optional, cast - -import numpy as np -from sqlalchemy.exc import IntegrityError - -from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.rag.datasource.entity.embedding import Embeddings -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from libs import helper -from models.dataset import Embedding - -logger = logging.getLogger(__name__) - - -class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None: - self._model_instance = model_instance - self._user = user - - def embed_documents(self, texts: list[str]) -> list[list[float]]: - """Embed search docs in batches of 10.""" - # use doc embedding cache or store if not exists - text_embeddings = [None for _ in range(len(texts))] - embedding_queue_indices = [] - for i, text in enumerate(texts): - hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider).first() - if embedding: - text_embeddings[i] = embedding.get_embedding() - else: - embedding_queue_indices.append(i) - if embedding_queue_indices: - embedding_queue_texts = [texts[i] for i in embedding_queue_indices] - embedding_queue_embeddings = [] - try: - model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) - model_schema = model_type_instance.get_model_schema(self._model_instance.model, - self._model_instance.credentials) - max_chunks = 
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 - for i in range(0, len(embedding_queue_texts), max_chunks): - batch_texts = embedding_queue_texts[i:i + max_chunks] - - embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, - user=self._user - ) - - for vector in embedding_result.embeddings: - try: - normalized_embedding = (vector / np.linalg.norm(vector)).tolist() - embedding_queue_embeddings.append(normalized_embedding) - except IntegrityError: - db.session.rollback() - except Exception as e: - logging.exception('Failed transform embedding: ', e) - cache_embeddings = [] - try: - for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): - text_embeddings[i] = embedding - hash = helper.generate_text_hash(texts[i]) - if hash not in cache_embeddings: - embedding_cache = Embedding(model_name=self._model_instance.model, - hash=hash, - provider_name=self._model_instance.provider) - embedding_cache.set_embedding(embedding) - db.session.add(embedding_cache) - cache_embeddings.append(hash) - db.session.commit() - except IntegrityError: - db.session.rollback() - except Exception as ex: - db.session.rollback() - logger.error('Failed to embed documents: ', ex) - raise ex - - return text_embeddings - - def embed_query(self, text: str) -> list[float]: - """Embed query text.""" - # use doc embedding cache or store if not exists - hash = helper.generate_text_hash(text) - embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' - embedding = redis_client.get(embedding_cache_key) - if embedding: - redis_client.expire(embedding_cache_key, 600) - return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) - try: - embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], - user=self._user - ) - - embedding_results = embedding_result.embeddings[0] - embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() - except Exception as ex: - raise ex - - try: - # encode embedding to base64 - embedding_vector = np.array(embedding_results) - vector_bytes = embedding_vector.tobytes() - # Transform to Base64 - encoded_vector = base64.b64encode(vector_bytes) - # Transform to string - encoded_str = encoded_vector.decode("utf-8") - redis_client.setex(embedding_cache_key, 600, encoded_str) - - except IntegrityError: - db.session.rollback() - except: - logging.exception('Failed to add embedding to redis') - - return embedding_results diff --git a/api/core/entities/agent_entities.py b/api/core/entities/agent_entities.py index 0cdf8670c492c2..656bf4aa724893 100644 --- a/api/core/entities/agent_entities.py +++ b/api/core/entities/agent_entities.py @@ -2,7 +2,7 @@ class PlanningStrategy(Enum): - ROUTER = 'router' - REACT_ROUTER = 'react_router' - REACT = 'react' - FUNCTION_CALL = 'function_call' + ROUTER = "router" + REACT_ROUTER = "react_router" + REACT = "react" + FUNCTION_CALL = "function_call" diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py new file mode 100644 index 00000000000000..9b4934646bc0e8 --- /dev/null +++ b/api/core/entities/embedding_type.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class EmbeddingInputType(Enum): + """ + Enum for embedding input type. 
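The deleted `CacheEmbedding` above cached query vectors in Redis as base64-encoded float bytes with a 600-second TTL. The serialization round trip it used looks like this:

```python
import base64

import numpy as np


def encode_embedding(vector: list[float]) -> str:
    # float64 bytes -> base64 text, the shape stored under the cache key
    return base64.b64encode(np.array(vector).tobytes()).decode("utf-8")


def decode_embedding(encoded: str) -> list[float]:
    return list(np.frombuffer(base64.b64decode(encoded), dtype="float"))


assert decode_embedding(encode_embedding([0.25, -0.5])) == [0.25, -0.5]
```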
+ """ + + DOCUMENT = "document" + QUERY = "query" diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py deleted file mode 100644 index 370aeee4633550..00000000000000 --- a/api/core/entities/message_entities.py +++ /dev/null @@ -1,29 +0,0 @@ -import enum -from typing import Any - -from pydantic import BaseModel - - -class PromptMessageFileType(enum.Enum): - IMAGE = 'image' - - @staticmethod - def value_of(value): - for member in PromptMessageFileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class PromptMessageFile(BaseModel): - type: PromptMessageFileType - data: Any = None - - -class ImagePromptMessageFile(PromptMessageFile): - class DETAIL(enum.Enum): - LOW = 'low' - HIGH = 'high' - - type: PromptMessageFileType = PromptMessageFileType.IMAGE - detail: DETAIL = DETAIL.LOW diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 22a21ecf9331ea..9ed5528e43b9b8 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -12,6 +12,7 @@ class ModelStatus(Enum): """ Enum class for model status. """ + ACTIVE = "active" NO_CONFIGURE = "no-configure" QUOTA_EXCEEDED = "quota-exceeded" @@ -23,6 +24,7 @@ class SimpleModelProviderEntity(BaseModel): """ Simple provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -40,7 +42,7 @@ def __init__(self, provider_entity: ProviderEntity) -> None: label=provider_entity.label, icon_small=provider_entity.icon_small, icon_large=provider_entity.icon_large, - supported_model_types=provider_entity.supported_model_types + supported_model_types=provider_entity.supported_model_types, ) @@ -48,6 +50,7 @@ class ProviderModelWithStatusEntity(ProviderModel): """ Model class for model response. """ + status: ModelStatus load_balancing_enabled: bool = False @@ -56,6 +59,7 @@ class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ Model with provider entity. """ + provider: SimpleModelProviderEntity @@ -63,6 +67,7 @@ class DefaultModelProviderEntity(BaseModel): """ Default model provider entity. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -74,6 +79,7 @@ class DefaultModelEntity(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: DefaultModelProviderEntity diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 778ef2e1ac42ad..807f09598c7607 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -47,6 +47,7 @@ class ProviderConfiguration(BaseModel): """ Model class for provider configuration. 
""" + tenant_id: str provider: ProviderEntity preferred_provider_type: ProviderType @@ -67,9 +68,13 @@ def __init__(self, **data): original_provider_configurate_methods[self.provider.provider].append(configurate_method) if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: - if (any(len(quota_configuration.restrict_models) > 0 - for quota_configuration in self.system_configuration.quota_configurations) - and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): + if ( + any( + len(quota_configuration.restrict_models) > 0 + for quota_configuration in self.system_configuration.quota_configurations + ) + and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods + ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: @@ -83,10 +88,9 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional if self.model_settings: # check if model is disabled by admin for model_setting in self.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: if not model_setting.enabled: - raise ValueError(f'Model {model} is disabled.') + raise ValueError(f"Model {model} is disabled.") if self.using_provider_type == ProviderType.SYSTEM: restrict_models = [] @@ -99,10 +103,12 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional copy_credentials = self.system_configuration.credentials.copy() if restrict_models: for restrict_model in restrict_models: - if (restrict_model.model_type == model_type - and restrict_model.model == model - and restrict_model.base_model_name): - copy_credentials['base_model_name'] = restrict_model.base_model_name + if ( + restrict_model.model_type == model_type + and restrict_model.model == model + and restrict_model.base_model_name + ): + copy_credentials["base_model_name"] = restrict_model.base_model_name return copy_credentials else: @@ -113,7 +119,7 @@ def get_current_credentials(self, model_type: ModelType, model: str) -> Optional credentials = model_configuration.credentials break - if self.custom_configuration.provider: + if not credentials and self.custom_configuration.provider: credentials = self.custom_configuration.provider.credentials return credentials @@ -128,20 +134,21 @@ def get_system_configuration_status(self) -> SystemConfigurationStatus: current_quota_type = self.system_configuration.current_quota_type current_quota_configuration = next( - (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), - None + (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None ) - return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \ - SystemConfigurationStatus.QUOTA_EXCEEDED + return ( + SystemConfigurationStatus.ACTIVE + if current_quota_configuration.is_valid + else SystemConfigurationStatus.QUOTA_EXCEEDED + ) def is_custom_configuration_available(self) -> bool: """ Check custom configuration available. 
:return: """ - return (self.custom_configuration.provider is not None - or len(self.custom_configuration.models) > 0) + return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: """ @@ -161,7 +168,8 @@ def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [], ) def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: @@ -171,17 +179,21 @@ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.provider_credential_schema.credential_form_schemas - if self.provider.provider_credential_schema else [] + if self.provider.provider_credential_schema + else [] ) if provider_record: @@ -189,9 +201,7 @@ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict # fix origin data if provider_record.encrypted_config: if not provider_record.encrypted_config.startswith("{"): - original_credentials = { - "openai_api_key": provider_record.encrypted_config - } + original_credentials = {"openai_api_key": provider_record.encrypted_config} else: original_credentials = json.loads(provider_record.encrypted_config) else: @@ -207,8 +217,7 @@ def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.provider_credentials_validate( - provider=self.provider.provider, - credentials=credentials + provider=self.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -239,15 +248,13 @@ def add_or_update_custom_credentials(self, credentials: dict) -> None: provider_name=self.provider.provider, provider_type=ProviderType.CUSTOM.value, encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER ) provider_model_credentials_cache.delete() @@ -260,12 +267,15 @@ def delete_custom_credentials(self) -> None: :return: """ # get provider - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == self.provider.provider, - Provider.provider_type == ProviderType.CUSTOM.value - ).first() + Provider.tenant_id == self.tenant_id, + Provider.provider_name == 
self.provider.provider, + Provider.provider_type == ProviderType.CUSTOM.value, + ) + .first() + ) # delete provider if provider_record: @@ -277,13 +287,14 @@ def delete_custom_credentials(self) -> None: provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) provider_model_credentials_cache.delete() - def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ - -> Optional[dict]: + def get_custom_model_credentials( + self, model_type: ModelType, model: str, obfuscated: bool = False + ) -> Optional[dict]: """ Get custom model credentials. @@ -305,13 +316,15 @@ def get_custom_model_credentials(self, model_type: ModelType, model: str, obfusc return self.obfuscated_credentials( credentials=credentials, credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [], ) return None - def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> tuple[ProviderModel, dict]: + def custom_model_credentials_validate( + self, model_type: ModelType, model: str, credentials: dict + ) -> tuple[ProviderModel, dict]: """ Validate custom model credentials. @@ -321,24 +334,29 @@ def custom_model_credentials_validate(self, model_type: ModelType, model: str, c :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # Get provider credential secret variables provider_credential_secret_variables = self.extract_secret_variables( self.provider.model_credential_schema.credential_form_schemas - if self.provider.model_credential_schema else [] + if self.provider.model_credential_schema + else [] ) if provider_model_record: try: - original_credentials = json.loads( - provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + original_credentials = ( + json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} + ) except JSONDecodeError: original_credentials = {} @@ -350,10 +368,7 @@ def custom_model_credentials_validate(self, model_type: ModelType, model: str, c credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) credentials = model_provider_factory.model_credentials_validate( - provider=self.provider.provider, - model_type=model_type, - model=model, - credentials=credentials + provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) for key, value in credentials.items(): @@ -388,7 +403,7 @@ def add_or_update_custom_model_credentials(self, model_type: ModelType, model: s model_name=model, model_type=model_type.to_origin_model_type(), encrypted_config=json.dumps(credentials), - is_valid=True + is_valid=True, ) db.session.add(provider_model_record) 
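`custom_credentials_validate` and its model-level twin both restore secret fields before validation: a secret echoed back unchanged from the console is replaced with the decrypted stored value. A sketch under assumptions — the placeholder literal and the `decrypt` callable are stand-ins, not the codebase's exact names:

```python
HIDDEN_VALUE = "[__HIDDEN__]"  # assumed placeholder literal


def restore_secrets(incoming: dict, original: dict, secret_keys: set[str], decrypt) -> dict:
    restored = dict(incoming)
    for key in secret_keys:
        # an untouched masked field comes back as the placeholder
        if restored.get(key) == HIDDEN_VALUE and key in original:
            restored[key] = decrypt(original[key])
    return restored
```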
db.session.commit() @@ -396,7 +411,7 @@ def add_or_update_custom_model_credentials(self, model_type: ModelType, model: s provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -409,13 +424,16 @@ def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> :return: """ # get provider model - provider_model_record = db.session.query(ProviderModel) \ + provider_model_record = ( + db.session.query(ProviderModel) .filter( - ProviderModel.tenant_id == self.tenant_id, - ProviderModel.provider_name == self.provider.provider, - ProviderModel.model_name == model, - ProviderModel.model_type == model_type.to_origin_model_type() - ).first() + ProviderModel.tenant_id == self.tenant_id, + ProviderModel.provider_name == self.provider.provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # delete provider model if provider_model_record: @@ -425,7 +443,7 @@ def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=self.tenant_id, identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + cache_type=ProviderCredentialsCacheType.MODEL, ) provider_model_credentials_cache.delete() @@ -437,13 +455,16 @@ def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSettin :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = True @@ -455,7 +476,7 @@ def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSettin provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=True + enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -469,13 +490,16 @@ def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetti :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.enabled = False @@ -487,7 +511,7 @@ def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetti 
provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - enabled=False + enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -501,13 +525,16 @@ def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optio :param model: model name :return: """ - return db.session.query(ProviderModelSetting) \ + return ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: """ @@ -516,24 +543,30 @@ def enable_model_load_balancing(self, model_type: ModelType, model: str) -> Prov :param model: model name :return: """ - load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \ + load_balancing_config_count = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).count() + LoadBalancingModelConfig.tenant_id == self.tenant_id, + LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .count() + ) if load_balancing_config_count <= 1: - raise ValueError('Model load balancing configuration must be more than 1.') + raise ValueError("Model load balancing configuration must be more than 1.") - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name == model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = True @@ -545,7 +578,7 @@ def enable_model_load_balancing(self, model_type: ModelType, model: str) -> Prov provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=True + load_balancing_enabled=True, ) db.session.add(model_setting) db.session.commit() @@ -559,13 +592,16 @@ def disable_model_load_balancing(self, model_type: ModelType, model: str) -> Pro :param model: model name :return: """ - model_setting = db.session.query(ProviderModelSetting) \ + model_setting = ( + db.session.query(ProviderModelSetting) .filter( - ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name == self.provider.provider, - ProviderModelSetting.model_type == model_type.to_origin_model_type(), - ProviderModelSetting.model_name 
== model - ).first() + ProviderModelSetting.tenant_id == self.tenant_id, + ProviderModelSetting.provider_name == self.provider.provider, + ProviderModelSetting.model_type == model_type.to_origin_model_type(), + ProviderModelSetting.model_name == model, + ) + .first() + ) if model_setting: model_setting.load_balancing_enabled = False @@ -577,7 +613,7 @@ def disable_model_load_balancing(self, model_type: ModelType, model: str) -> Pro provider_name=self.provider.provider, model_type=model_type.to_origin_model_type(), model_name=model, - load_balancing_enabled=False + load_balancing_enabled=False, ) db.session.add(model_setting) db.session.commit() @@ -617,11 +653,14 @@ def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: return # get preferred provider - preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ + preferred_model_provider = ( + db.session.query(TenantPreferredModelProvider) .filter( - TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name == self.provider.provider - ).first() + TenantPreferredModelProvider.tenant_id == self.tenant_id, + TenantPreferredModelProvider.provider_name == self.provider.provider, + ) + .first() + ) if preferred_model_provider: preferred_model_provider.preferred_provider_type = provider_type.value @@ -629,7 +668,7 @@ def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - preferred_provider_type=provider_type.value + preferred_provider_type=provider_type.value, ) db.session.add(preferred_model_provider) @@ -658,9 +697,7 @@ def obfuscated_credentials(self, credentials: dict, credential_form_schemas: lis :return: """ # Get provider credential secret variables - credential_secret_variables = self.extract_secret_variables( - credential_form_schemas - ) + credential_secret_variables = self.extract_secret_variables(credential_form_schemas) # Obfuscate provider credentials copy_credentials = credentials.copy() @@ -670,9 +707,9 @@ def obfuscated_credentials(self, credentials: dict, credential_form_schemas: lis return copy_credentials - def get_provider_model(self, model_type: ModelType, - model: str, - only_active: bool = False) -> Optional[ModelWithProviderEntity]: + def get_provider_model( + self, model_type: ModelType, model: str, only_active: bool = False + ) -> Optional[ModelWithProviderEntity]: """ Get provider model. :param model_type: model type @@ -688,8 +725,9 @@ def get_provider_model(self, model_type: ModelType, return None - def get_provider_models(self, model_type: Optional[ModelType] = None, - only_active: bool = False) -> list[ModelWithProviderEntity]: + def get_provider_models( + self, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get provider models. 
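`obfuscated_credentials` masks every secret-typed form field before credentials leave the server. A sketch of that masking; the `obfuscate` rule below is a stand-in for the project's actual token-masking helper:

```python
def obfuscate(token: str) -> str:
    # stand-in masking rule: keep a short prefix/suffix, star the middle
    if len(token) <= 8:
        return "*" * len(token)
    return token[:2] + "*" * (len(token) - 6) + token[-4:]


def obfuscated_credentials(credentials: dict, secret_keys: set[str]) -> dict:
    masked = dict(credentials)
    for key in secret_keys:
        if key in masked:
            masked[key] = obfuscate(str(masked[key]))
    return masked
```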
:param model_type: model type @@ -711,15 +749,11 @@ def get_provider_models(self, model_type: Optional[ModelType] = None, if self.using_provider_type == ProviderType.SYSTEM: provider_models = self._get_system_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) else: provider_models = self._get_custom_provider_models( - model_types=model_types, - provider_instance=provider_instance, - model_setting_map=model_setting_map + model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map ) if only_active: @@ -728,11 +762,12 @@ def get_provider_models(self, model_type: Optional[ModelType] = None, # resort provider_models return sorted(provider_models, key=lambda x: x.model_type.value) - def _get_system_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_system_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get system provider models. @@ -760,7 +795,7 @@ def _get_system_provider_models(self, model_properties=m.model_properties, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -783,23 +818,20 @@ def _get_system_provider_models(self, if should_use_custom_model: if original_provider_configurate_methods[self.provider.provider] == [ - ConfigurateMethod.CUSTOMIZABLE_MODEL]: + ConfigurateMethod.CUSTOMIZABLE_MODEL + ]: # only customizable model for restrict_model in restrict_models: copy_credentials = self.system_configuration.credentials.copy() if restrict_model.base_model_name: - copy_credentials['base_model_name'] = restrict_model.base_model_name + copy_credentials["base_model_name"] = restrict_model.base_model_name try: - custom_model_schema = ( - provider_instance.get_model_instance(restrict_model.model_type) - .get_customizable_model_schema_from_credentials( - restrict_model.model, - copy_credentials - ) - ) + custom_model_schema = provider_instance.get_model_instance( + restrict_model.model_type + ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -809,8 +841,10 @@ def _get_system_provider_models(self, continue status = ModelStatus.ACTIVE - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -825,7 +859,7 @@ def _get_system_provider_models(self, model_properties=custom_model_schema.model_properties, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), - status=status + status=status, ) ) @@ -839,11 +873,12 @@ def _get_system_provider_models(self, return provider_models - def 
_get_custom_provider_models(self, - model_types: list[ModelType], - provider_instance: ModelProvider, - model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \ - -> list[ModelWithProviderEntity]: + def _get_custom_provider_models( + self, + model_types: list[ModelType], + provider_instance: ModelProvider, + model_setting_map: dict[ModelType, dict[str, ModelSettings]], + ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -885,7 +920,7 @@ def _get_custom_provider_models(self, deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -895,15 +930,13 @@ def _get_custom_provider_models(self, continue try: - custom_model_schema = ( - provider_instance.get_model_instance(model_configuration.model_type) - .get_customizable_model_schema_from_credentials( - model_configuration.model, - model_configuration.credentials - ) + custom_model_schema = provider_instance.get_model_instance( + model_configuration.model_type + ).get_customizable_model_schema_from_credentials( + model_configuration.model, model_configuration.credentials ) except Exception as ex: - logger.warning(f'get custom model schema failed, {ex}') + logger.warning(f"get custom model schema failed, {ex}") continue if not custom_model_schema: @@ -911,8 +944,10 @@ def _get_custom_provider_models(self, status = ModelStatus.ACTIVE load_balancing_enabled = False - if (custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]): + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] if model_setting.enabled is False: status = ModelStatus.DISABLED @@ -931,7 +966,7 @@ def _get_custom_provider_models(self, deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=status, - load_balancing_enabled=load_balancing_enabled + load_balancing_enabled=load_balancing_enabled, ) ) @@ -942,17 +977,16 @@ class ProviderConfigurations(BaseModel): """ Model class for provider configuration dict. """ + tenant_id: str configurations: dict[str, ProviderConfiguration] = {} def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) - def get_models(self, - provider: Optional[str] = None, - model_type: Optional[ModelType] = None, - only_active: bool = False) \ - -> list[ModelWithProviderEntity]: + def get_models( + self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False + ) -> list[ModelWithProviderEntity]: """ Get available models. @@ -1019,10 +1053,10 @@ class ProviderModelBundle(BaseModel): """ Provider model bundle. 
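Both `_get_system_provider_models` and `_get_custom_provider_models` resolve a model's listing status the same way: start from ACTIVE and flip to DISABLED when an admin setting exists with `enabled=False`. Condensed, with plain booleans standing in for `ModelSettings`:

```python
from enum import Enum


class ModelStatus(Enum):
    ACTIVE = "active"
    DISABLED = "disabled"


def resolve_status(model_type: str, model: str,
                   setting_map: dict[str, dict[str, bool]]) -> ModelStatus:
    # setting_map mirrors model_setting_map: {model_type: {model: enabled}}
    enabled = setting_map.get(model_type, {}).get(model, True)
    return ModelStatus.ACTIVE if enabled else ModelStatus.DISABLED
```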
""" + configuration: ProviderConfiguration provider_instance: ModelProvider model_type_instance: AIModel # pydantic configs - model_config = ConfigDict(arbitrary_types_allowed=True, - protected_namespaces=()) + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 0d5b0a1b2c6ba6..44725623dc4bd4 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -8,18 +8,19 @@ class QuotaUnit(Enum): - TIMES = 'times' - TOKENS = 'tokens' - CREDITS = 'credits' + TIMES = "times" + TOKENS = "tokens" + CREDITS = "credits" class SystemConfigurationStatus(Enum): """ Enum class for system configuration status. """ - ACTIVE = 'active' - QUOTA_EXCEEDED = 'quota-exceeded' - UNSUPPORTED = 'unsupported' + + ACTIVE = "active" + QUOTA_EXCEEDED = "quota-exceeded" + UNSUPPORTED = "unsupported" class RestrictModel(BaseModel): @@ -35,6 +36,7 @@ class QuotaConfiguration(BaseModel): """ Model class for provider quota configuration. """ + quota_type: ProviderQuotaType quota_unit: QuotaUnit quota_limit: int @@ -47,6 +49,7 @@ class SystemConfiguration(BaseModel): """ Model class for provider system configuration. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -57,6 +60,7 @@ class CustomProviderConfiguration(BaseModel): """ Model class for provider custom configuration. """ + credentials: dict @@ -64,6 +68,7 @@ class CustomModelConfiguration(BaseModel): """ Model class for provider custom model configuration. """ + model: str model_type: ModelType credentials: dict @@ -76,6 +81,7 @@ class CustomConfiguration(BaseModel): """ Model class for provider custom configuration. """ + provider: Optional[CustomProviderConfiguration] = None models: list[CustomModelConfiguration] = [] @@ -84,6 +90,7 @@ class ModelLoadBalancingConfiguration(BaseModel): """ Class for model load balancing configuration. """ + id: str name: str credentials: dict @@ -93,6 +100,7 @@ class ModelSettings(BaseModel): """ Model class for model settings. """ + model: str model_type: ModelType enabled: bool = True diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 859a747c12157f..3b186476ebe977 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -3,6 +3,7 @@ class LLMError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -11,6 +12,7 @@ def __init__(self, description: Optional[str] = None) -> None: class LLMBadRequestError(LLMError): """Raised when the LLM returns bad request.""" + description = "Bad Request" @@ -18,6 +20,7 @@ class ProviderTokenNotInitError(Exception): """ Custom exception raised when the provider token is not initialized. """ + description = "Provider Token Not Init" def __init__(self, *args, **kwargs): @@ -28,6 +31,7 @@ class QuotaExceededError(Exception): """ Custom exception raised when the quota for a provider has been exceeded. """ + description = "Quota Exceeded" @@ -35,6 +39,7 @@ class AppInvokeQuotaExceededError(Exception): """ Custom exception raised when the quota for an app has been exceeded. 
""" + description = "App Invoke Quota Exceeded" @@ -42,4 +47,11 @@ class ModelCurrentlyNotSupportError(Exception): """ Custom exception raised when the model not support """ + description = "Model Currently Not Support" + + +class InvokeRateLimitError(Exception): + """Raised when the Invoke returns rate limit error.""" + + description = "Rate Limit Error" diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index 4db7a999736c5b..38cebb6b6b1c36 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -20,10 +20,7 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: :param params: the request params :return: the response json """ - headers = { - "Content-Type": "application/json", - "Authorization": "Bearer {}".format(self.api_key) - } + headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)} url = self.api_endpoint @@ -32,20 +29,17 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: proxies = None if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: proxies = { - 'http': dify_config.SSRF_PROXY_HTTP_URL, - 'https': dify_config.SSRF_PROXY_HTTPS_URL, + "http": dify_config.SSRF_PROXY_HTTP_URL, + "https": dify_config.SSRF_PROXY_HTTPS_URL, } response = requests.request( - method='POST', + method="POST", url=url, - json={ - 'point': point.value, - 'params': params - }, + json={"point": point.value, "params": params}, headers=headers, timeout=self.timeout, - proxies=proxies + proxies=proxies, ) except requests.exceptions.Timeout: raise ValueError("request timeout") @@ -53,9 +47,8 @@ def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: raise ValueError("request connection error") if response.status_code != 200: - raise ValueError("request error, status_code: {}, content: {}".format( - response.status_code, - response.text[:100] - )) + raise ValueError( + "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100]) + ) return response.json() diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 0296126d8b094f..97dbaf2026e790 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -3,6 +3,7 @@ import json import logging import os +from pathlib import Path from typing import Any, Optional from pydantic import BaseModel @@ -11,8 +12,8 @@ class ExtensionModule(enum.Enum): - MODERATION = 'moderation' - EXTERNAL_DATA_TOOL = 'external_data_tool' + MODERATION = "moderation" + EXTERNAL_DATA_TOOL = "external_data_tool" class ModuleExtension(BaseModel): @@ -41,12 +42,12 @@ def scan_extensions(cls): position_map = {} # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') + current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") current_dir_path = os.path.dirname(current_path) # traverse subdirectories for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith('__'): + if subdir_name.startswith("__"): continue subdir_path = os.path.join(current_dir_path, subdir_name) @@ -58,21 +59,20 @@ def scan_extensions(cls): # in the front-end page and business logic, there are special treatments. 
builtin = False position = None - if '__builtin__' in file_names: + if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, '__builtin__') + builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): - with open(builtin_file_path, encoding='utf-8') as f: - position = int(f.read().strip()) - position_map[extension_name] = position + position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) + position_map[extension_name] = position - if (extension_name + '.py') not in file_names: + if (extension_name + ".py") not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") continue # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + '.py') + py_path = os.path.join(subdir_path, extension_name + ".py") spec = importlib.util.spec_from_file_location(extension_name, py_path) if not spec or not spec.loader: raise Exception(f"Failed to load module {extension_name} from {py_path}") @@ -91,25 +91,29 @@ def scan_extensions(cls): json_data = {} if not builtin: - if 'schema.json' not in file_names: + if "schema.json" not in file_names: logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, 'schema.json') + json_path = os.path.join(subdir_path, "schema.json") json_data = {} if os.path.exists(json_path): - with open(json_path, encoding='utf-8') as f: + with open(json_path, encoding="utf-8") as f: json_data = json.load(f) - extensions.append(ModuleExtension( - extension_class=extension_class, - name=extension_name, - label=json_data.get('label'), - form_schema=json_data.get('form_schema'), - builtin=builtin, - position=position - )) - - sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name) + extensions.append( + ModuleExtension( + extension_class=extension_class, + name=extension_name, + label=json_data.get("label"), + form_schema=json_data.get("form_schema"), + builtin=builtin, + position=position, + ) + ) + + sorted_extensions = sort_to_dict_by_position_map( + position_map=position_map, data=extensions, name_func=lambda x: x.name + ) return sorted_extensions diff --git a/api/core/extension/extension.py b/api/core/extension/extension.py index 29e892c58ac550..3da170455e3398 100644 --- a/api/core/extension/extension.py +++ b/api/core/extension/extension.py @@ -6,10 +6,7 @@ class Extension: __module_extensions: dict[str, dict[str, ModuleExtension]] = {} - module_classes = { - ExtensionModule.MODERATION: Moderation, - ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool - } + module_classes = {ExtensionModule.MODERATION: Moderation, ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool} def init(self): for module, module_class in self.module_classes.items(): diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py index 58c82502ea4447..54ec97a4933a94 100644 --- a/api/core/external_data_tool/api/api.py +++ b/api/core/external_data_tool/api/api.py @@ -30,10 +30,11 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: raise ValueError("api_based_extension_id is required") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + 
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: raise ValueError("api_based_extension_id is invalid") @@ -50,47 +51,42 @@ def query(self, inputs: dict, query: Optional[str] = None) -> str: api_based_extension_id = self.config.get("api_based_extension_id") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == self.tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(self.variable)) + raise ValueError( + "[External data tool] API query failed, variable: {}, " + "error: api_based_extension_id is invalid".format(self.variable) + ) # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=self.tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=self.tenant_id, token=api_based_extension.api_key) try: # request api - requestor = APIBasedExtensionRequestor( - api_endpoint=api_based_extension.api_endpoint, - api_key=api_key - ) + requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key) except Exception as e: - raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format( - self.variable, - e - )) - - response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={ - 'app_id': self.app_id, - 'tool_variable': self.variable, - 'inputs': inputs, - 'query': query - }) - - if 'result' not in response_json: - raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response" - .format(self.variable)) - - if not isinstance(response_json['result'], str): - raise ValueError("[External data tool] API query failed, variable: {}, error: result is not string" - .format(self.variable)) - - return response_json['result'] + raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e)) + + response_json = requestor.request( + point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, + params={"app_id": self.app_id, "tool_variable": self.variable, "inputs": inputs, "query": query}, + ) + + if "result" not in response_json: + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result not found in response".format( + self.variable + ) + ) + + if not isinstance(response_json["result"], str): + raise ValueError( + "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable) + ) + + return response_json["result"] diff --git a/api/core/external_data_tool/external_data_fetch.py b/api/core/external_data_tool/external_data_fetch.py index 8601cb34e79582..84b94e117ff5f9 100644 --- a/api/core/external_data_tool/external_data_fetch.py +++ b/api/core/external_data_tool/external_data_fetch.py @@ -12,11 +12,14 @@ class ExternalDataFetch: - def fetch(self, tenant_id: str, - app_id: str, - external_data_tools: list[ExternalDataVariableEntity], - inputs: dict, - query: str) -> dict: + def fetch( + self, + tenant_id: str, + app_id: str, + external_data_tools: list[ExternalDataVariableEntity], + 
inputs: dict, + query: str, + ) -> dict: """ Fill in variable inputs from external data tools if exists. @@ -38,7 +41,7 @@ def fetch(self, tenant_id: str, app_id, tool, inputs, - query + query, ) futures[future] = tool @@ -50,12 +53,15 @@ def fetch(self, tenant_id: str, inputs.update(results) return inputs - def _query_external_data_tool(self, flask_app: Flask, - tenant_id: str, - app_id: str, - external_data_tool: ExternalDataVariableEntity, - inputs: dict, - query: str) -> tuple[Optional[str], Optional[str]]: + def _query_external_data_tool( + self, + flask_app: Flask, + tenant_id: str, + app_id: str, + external_data_tool: ExternalDataVariableEntity, + inputs: dict, + query: str, + ) -> tuple[Optional[str], Optional[str]]: """ Query external data tool. :param flask_app: flask app @@ -72,17 +78,10 @@ def _query_external_data_tool(self, flask_app: Flask, tool_config = external_data_tool.config external_data_tool_factory = ExternalDataToolFactory( - name=tool_type, - tenant_id=tenant_id, - app_id=app_id, - variable=tool_variable, - config=tool_config + name=tool_type, tenant_id=tenant_id, app_id=app_id, variable=tool_variable, config=tool_config ) # query external data tool - result = external_data_tool_factory.query( - inputs=inputs, - query=query - ) + result = external_data_tool_factory.query(inputs=inputs, query=query) return tool_variable, result diff --git a/api/core/external_data_tool/factory.py b/api/core/external_data_tool/factory.py index 979f243af65f61..28721098594962 100644 --- a/api/core/external_data_tool/factory.py +++ b/api/core/external_data_tool/factory.py @@ -5,14 +5,10 @@ class ExternalDataToolFactory: - def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) self.__extension_instance = extension_class( - tenant_id=tenant_id, - app_id=app_id, - variable=variable, - config=config + tenant_id=tenant_id, app_id=app_id, variable=variable, config=config ) @classmethod diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py index e69de29bb2d1d6..fe9e52258ac046 100644 --- a/api/core/file/__init__.py +++ b/api/core/file/__init__.py @@ -0,0 +1,19 @@ +from .constants import FILE_MODEL_IDENTITY +from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType +from .models import ( + File, + FileUploadConfig, + ImageConfig, +) + +__all__ = [ + "FileType", + "FileUploadConfig", + "FileTransferMethod", + "FileBelongsTo", + "File", + "ImageConfig", + "FileAttribute", + "ArrayFileAttribute", + "FILE_MODEL_IDENTITY", +] diff --git a/api/core/file/constants.py b/api/core/file/constants.py new file mode 100644 index 00000000000000..ce1d238e93742b --- /dev/null +++ b/api/core/file/constants.py @@ -0,0 +1 @@ +FILE_MODEL_IDENTITY = "__dify__file__" diff --git a/api/core/file/enums.py b/api/core/file/enums.py new file mode 100644 index 00000000000000..f4153f1676b620 --- /dev/null +++ b/api/core/file/enums.py @@ -0,0 +1,55 @@ +from enum import Enum + + +class FileType(str, Enum): + IMAGE = "image" + DOCUMENT = "document" + AUDIO = "audio" + VIDEO = "video" + CUSTOM = "custom" + + @staticmethod + def value_of(value): + for member in FileType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileTransferMethod(str, Enum): + REMOTE_URL = "remote_url" + LOCAL_FILE = "local_file" + TOOL_FILE = "tool_file" + + @staticmethod + def value_of(value): + 
for member in FileTransferMethod: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileBelongsTo(str, Enum): + USER = "user" + ASSISTANT = "assistant" + + @staticmethod + def value_of(value): + for member in FileBelongsTo: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class FileAttribute(str, Enum): + TYPE = "type" + SIZE = "size" + NAME = "name" + MIME_TYPE = "mime_type" + TRANSFER_METHOD = "transfer_method" + URL = "url" + EXTENSION = "extension" + + +class ArrayFileAttribute(str, Enum): + LENGTH = "length" diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py new file mode 100644 index 00000000000000..eb260a8f84fbbd --- /dev/null +++ b/api/core/file/file_manager.py @@ -0,0 +1,170 @@ +import base64 + +from configs import dify_config +from core.file import file_repository +from core.helper import ssrf_proxy +from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent +from extensions.ext_database import db +from extensions.ext_storage import storage + +from . import helpers +from .enums import FileAttribute +from .models import File, FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +def get_attr(*, file: File, attr: FileAttribute): + match attr: + case FileAttribute.TYPE: + return file.type.value + case FileAttribute.SIZE: + return file.size + case FileAttribute.NAME: + return file.filename + case FileAttribute.MIME_TYPE: + return file.mime_type + case FileAttribute.TRANSFER_METHOD: + return file.transfer_method.value + case FileAttribute.URL: + return file.remote_url + case FileAttribute.EXTENSION: + return file.extension + case _: + raise ValueError(f"Invalid file attribute: {attr}") + + +def to_prompt_message_content( + f: File, + /, + *, + image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, +): + """ + Convert a File object to an ImagePromptMessageContent, AudioPromptMessageContent, or VideoPromptMessageContent object. + + This function takes a File object and converts it to the appropriate PromptMessageContent + object, which can be used as a prompt for image, audio, or video-based AI models. + + Args: + f (File): The File object to convert. + image_detail_config (ImagePromptMessageContent.DETAIL): The detail level for image prompts. + If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW. + + Returns: + Union[ImagePromptMessageContent, AudioPromptMessageContent, VideoPromptMessageContent]: An object containing the file data and detail level + + Raises: + ValueError: If the file type is not supported or if required data is missing.
+ """ + match f.type: + case FileType.IMAGE: + if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": + data = _to_url(f) + else: + data = _to_base64_data_string(f) + + return ImagePromptMessageContent(data=data, detail=image_detail_config) + case FileType.AUDIO: + encoded_string = _file_to_encoded_string(f) + if f.extension is None: + raise ValueError("Missing file extension") + return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) + case FileType.VIDEO: + if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url": + data = _to_url(f) + else: + data = _to_base64_data_string(f) + return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) + case _: + raise ValueError(f"file type {f.type} is not supported") + + +def download(f: File, /): + if f.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file = file_repository.get_tool_file(session=db.session(), file=f) + return _download_file_content(tool_file.file_key) + elif f.transfer_method == FileTransferMethod.LOCAL_FILE: + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + return _download_file_content(upload_file.key) + # remote file + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) + response.raise_for_status() + return response.content + + +def _download_file_content(path: str, /): + """ + Download and return the contents of a file as bytes. + + This function loads the file from storage and ensures it's in bytes format. + + Args: + path (str): The path to the file in storage. + + Returns: + bytes: The contents of the file as a bytes object. + + Raises: + ValueError: If the loaded file is not a bytes object. + """ + data = storage.load(path, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {path} is not a bytes object") + return data + + +def _get_encoded_string(f: File, /): + match f.transfer_method: + case FileTransferMethod.REMOTE_URL: + response = ssrf_proxy.get(f.remote_url, follow_redirects=True) + response.raise_for_status() + content = response.content + encoded_string = base64.b64encode(content).decode("utf-8") + return encoded_string + case FileTransferMethod.LOCAL_FILE: + upload_file = file_repository.get_upload_file(session=db.session(), file=f) + data = _download_file_content(upload_file.key) + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + case FileTransferMethod.TOOL_FILE: + tool_file = file_repository.get_tool_file(session=db.session(), file=f) + data = _download_file_content(tool_file.file_key) + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string + case _: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + +def _to_base64_data_string(f: File, /): + encoded_string = _get_encoded_string(f) + return f"data:{f.mime_type};base64,{encoded_string}" + + +def _file_to_encoded_string(f: File, /): + match f.type: + case FileType.IMAGE: + return _to_base64_data_string(f) + case FileType.VIDEO: + return _to_base64_data_string(f) + case FileType.AUDIO: + return _get_encoded_string(f) + case _: + raise ValueError(f"file type {f.type} is not supported") + + +def _to_url(f: File, /): + if f.transfer_method == FileTransferMethod.REMOTE_URL: + if f.remote_url is None: + raise ValueError("Missing file remote_url") + return f.remote_url + elif f.transfer_method == FileTransferMethod.LOCAL_FILE: + if f.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=f.related_id) + elif 
f.transfer_method == FileTransferMethod.TOOL_FILE: + # add sign url + if f.related_id is None or f.extension is None: + raise ValueError("Missing file related_id or extension") + return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension) + else: + raise ValueError(f"Unsupported transfer method: {f.transfer_method}") diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py deleted file mode 100644 index 3959f4b4a0bb61..00000000000000 --- a/api/core/file/file_obj.py +++ /dev/null @@ -1,142 +0,0 @@ -import enum -from typing import Any, Optional - -from pydantic import BaseModel - -from core.file.tool_file_parser import ToolFileParser -from core.file.upload_file_parser import UploadFileParser -from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from extensions.ext_database import db - - -class FileExtraConfig(BaseModel): - """ - File Upload Entity. - """ - image_config: Optional[dict[str, Any]] = None - - -class FileType(enum.Enum): - IMAGE = 'image' - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(enum.Enum): - REMOTE_URL = 'remote_url' - LOCAL_FILE = 'local_file' - TOOL_FILE = 'tool_file' - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - -class FileBelongsTo(enum.Enum): - USER = 'user' - ASSISTANT = 'assistant' - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileVar(BaseModel): - id: Optional[str] = None # message file id - tenant_id: str - type: FileType - transfer_method: FileTransferMethod - url: Optional[str] = None # remote url - related_id: Optional[str] = None - extra_config: Optional[FileExtraConfig] = None - filename: Optional[str] = None - extension: Optional[str] = None - mime_type: Optional[str] = None - - def to_dict(self) -> dict: - return { - '__variant': self.__class__.__name__, - 'tenant_id': self.tenant_id, - 'type': self.type.value, - 'transfer_method': self.transfer_method.value, - 'url': self.preview_url, - 'remote_url': self.url, - 'related_id': self.related_id, - 'filename': self.filename, - 'extension': self.extension, - 'mime_type': self.mime_type, - } - - def to_markdown(self) -> str: - """ - Convert file to markdown - :return: - """ - preview_url = self.preview_url - if self.type == FileType.IMAGE: - text = f'![{self.filename or ""}]({preview_url})' - else: - text = f'[{self.filename or preview_url}]({preview_url})' - - return text - - @property - def data(self) -> Optional[str]: - """ - Get image data, file signed url or base64 data - depending on config MULTIMODAL_SEND_IMAGE_FORMAT - :return: - """ - return self._get_data() - - @property - def preview_url(self) -> Optional[str]: - """ - Get signed preview url - :return: - """ - return self._get_data(force_url=True) - - @property - def prompt_message_content(self) -> ImagePromptMessageContent: - if self.type == FileType.IMAGE: - image_config = self.extra_config.image_config - - return ImagePromptMessageContent( - data=self.data, - detail=ImagePromptMessageContent.DETAIL.HIGH - if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW - ) - - def _get_data(self, 
force_url: bool = False) -> Optional[str]: - from models.model import UploadFile - if self.type == FileType.IMAGE: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == self.related_id, - UploadFile.tenant_id == self.tenant_id - ).first()) - - return UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=force_url - ) - elif self.transfer_method == FileTransferMethod.TOOL_FILE: - extension = self.extension - # add sign url - return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=self.related_id, extension=extension) - - return None diff --git a/api/core/file/file_repository.py b/api/core/file/file_repository.py new file mode 100644 index 00000000000000..975e1e72db0e0a --- /dev/null +++ b/api/core/file/file_repository.py @@ -0,0 +1,32 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from models import ToolFile, UploadFile + +from .models import File + + +def get_upload_file(*, session: Session, file: File): + if file.related_id is None: + raise ValueError("Missing file related_id") + stmt = select(UploadFile).filter( + UploadFile.id == file.related_id, + UploadFile.tenant_id == file.tenant_id, + ) + record = session.scalar(stmt) + if not record: + raise ValueError(f"upload file {file.related_id} not found") + return record + + +def get_tool_file(*, session: Session, file: File): + if file.related_id is None: + raise ValueError("Missing file related_id") + stmt = select(ToolFile).filter( + ToolFile.id == file.related_id, + ToolFile.tenant_id == file.tenant_id, + ) + record = session.scalar(stmt) + if not record: + raise ValueError(f"tool file {file.related_id} not found") + return record diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py new file mode 100644 index 00000000000000..12123cf3f74630 --- /dev/null +++ b/api/core/file/helpers.py @@ -0,0 +1,48 @@ +import base64 +import hashlib +import hmac +import os +import time + +from configs import dify_config + + +def get_signed_file_url(upload_file_id: str) -> str: + url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + key = dify_config.SECRET_KEY.encode() + msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + +def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = 
base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py deleted file mode 100644 index 085ff07cfde921..00000000000000 --- a/api/core/file/message_file_parser.py +++ /dev/null @@ -1,220 +0,0 @@ -import re -from collections.abc import Mapping, Sequence -from typing import Any, Union -from urllib.parse import parse_qs, urlparse - -import requests - -from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar -from extensions.ext_database import db -from models.account import Account -from models.model import EndUser, MessageFile, UploadFile -from services.file_service import IMAGE_EXTENSIONS - - -class MessageFileParser: - - def __init__(self, tenant_id: str, app_id: str) -> None: - self.tenant_id = tenant_id - self.app_id = app_id - - def validate_and_transform_files_arg(self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, - user: Union[Account, EndUser]) -> list[FileVar]: - """ - validate and transform files arg - - :param files: - :param file_extra_config: - :param user: - :return: - """ - for file in files: - if not isinstance(file, dict): - raise ValueError('Invalid file format, must be dict') - if not file.get('type'): - raise ValueError('Missing file type') - FileType.value_of(file.get('type')) - if not file.get('transfer_method'): - raise ValueError('Missing file transfer method') - FileTransferMethod.value_of(file.get('transfer_method')) - if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value: - if not file.get('url'): - raise ValueError('Missing file url') - if not file.get('url').startswith('http'): - raise ValueError('Invalid file url') - if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'): - raise ValueError('Missing file upload_file_id') - if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'): - raise ValueError('Missing file tool_file_id') - - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_extra_config) - - # validate files - new_files = [] - for file_type, file_objs in type_file_objs.items(): - if file_type == FileType.IMAGE: - # parse and validate files - image_config = file_extra_config.image_config - - # check if image file feature is enabled - if not image_config: - continue - - # Validate number of files - if len(files) > image_config['number_limits']: - raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}") - - for file_obj in file_objs: - # Validate transfer method - if file_obj.transfer_method.value not in image_config['transfer_methods']: - raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}') - - # Validate file type - if file_obj.type != FileType.IMAGE: - raise ValueError(f'Invalid file type: {file_obj.type}') - - if file_obj.transfer_method == FileTransferMethod.REMOTE_URL: - # check remote url valid and is image - result, error = self._check_image_remote_url(file_obj.url) - if result is False: - raise ValueError(error) - elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE: - # get upload file from upload_file_id - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == 
file_obj.related_id, - UploadFile.tenant_id == self.tenant_id, - UploadFile.created_by == user.id, - UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - UploadFile.extension.in_(IMAGE_EXTENSIONS) - ).first()) - - # check upload file is belong to tenant and user - if not upload_file: - raise ValueError('Invalid upload file') - - new_files.append(file_obj) - - # return all file objs - return new_files - - def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): - """ - transform message files - - :param files: - :param file_extra_config: - :return: - """ - # transform files to file objs - type_file_objs = self._to_file_objs(files, file_extra_config) - - # return all file objs - return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - - def _to_file_objs(self, files: list[Union[dict, MessageFile]], - file_extra_config: FileExtraConfig) -> dict[FileType, list[FileVar]]: - """ - transform files to file objs - - :param files: - :param file_extra_config: - :return: - """ - type_file_objs: dict[FileType, list[FileVar]] = { - # Currently only support image - FileType.IMAGE: [] - } - - if not files: - return type_file_objs - - # group by file type and convert file args or message files to FileObj - for file in files: - if isinstance(file, MessageFile): - if file.belongs_to == FileBelongsTo.ASSISTANT.value: - continue - - file_obj = self._to_file_obj(file, file_extra_config) - if file_obj.type not in type_file_objs: - continue - - type_file_objs[file_obj.type].append(file_obj) - - return type_file_objs - - def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): - """ - transform file to file obj - - :param file: - :return: - """ - if isinstance(file, dict): - transfer_method = FileTransferMethod.value_of(file.get('transfer_method')) - if transfer_method != FileTransferMethod.TOOL_FILE: - return FileVar( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), - transfer_method=transfer_method, - url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=file_extra_config - ) - return FileVar( - tenant_id=self.tenant_id, - type=FileType.value_of(file.get('type')), - transfer_method=transfer_method, - url=None, - related_id=file.get('tool_file_id'), - extra_config=file_extra_config - ) - else: - return FileVar( - id=file.id, - tenant_id=self.tenant_id, - type=FileType.value_of(file.type), - transfer_method=FileTransferMethod.value_of(file.transfer_method), - url=file.url, - related_id=file.upload_file_id or None, - extra_config=file_extra_config - ) - - def _check_image_remote_url(self, url): - try: - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - - def is_s3_presigned_url(url): - try: - parsed_url = urlparse(url) - if 'amazonaws.com' not in parsed_url.netloc: - return False - query_params = parse_qs(parsed_url.query) - required_params = ['Signature', 'Expires'] - for param in required_params: - if param not in query_params: - return False - if not query_params['Expires'][0].isdigit(): - return False - signature = query_params['Signature'][0] - if not re.match(r'^[A-Za-z0-9+/]+={0,2}$', signature): - return False - return True - except Exception: - return False - - if is_s3_presigned_url(url): - response 
= requests.get(url, headers=headers, allow_redirects=True) - if response.status_code in {200, 304}: - return True, "" - - response = requests.head(url, headers=headers, allow_redirects=True) - if response.status_code in {200, 304}: - return True, "" - else: - return False, "URL does not exist." - except requests.RequestException as e: - return False, f"Error checking URL: {e}" diff --git a/api/core/file/models.py b/api/core/file/models.py new file mode 100644 index 00000000000000..0142893787e073 --- /dev/null +++ b/api/core/file/models.py @@ -0,0 +1,109 @@ +from collections.abc import Mapping, Sequence +from typing import Optional + +from pydantic import BaseModel, Field, model_validator + +from core.model_runtime.entities.message_entities import ImagePromptMessageContent + +from . import helpers +from .constants import FILE_MODEL_IDENTITY +from .enums import FileTransferMethod, FileType +from .tool_file_parser import ToolFileParser + + +class ImageConfig(BaseModel): + """ + NOTE: This part of validation is deprecated, but still used in app features "Image Upload". + """ + + number_limits: int = 0 + transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + detail: ImagePromptMessageContent.DETAIL | None = None + + +class FileUploadConfig(BaseModel): + """ + File Upload Entity. + """ + + image_config: Optional[ImageConfig] = None + allowed_file_types: Sequence[FileType] = Field(default_factory=list) + allowed_extensions: Sequence[str] = Field(default_factory=list) + allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) + number_limits: int = 0 + + +class File(BaseModel): + dify_model_identity: str = FILE_MODEL_IDENTITY + + id: Optional[str] = None # message file id + tenant_id: str + type: FileType + transfer_method: FileTransferMethod + remote_url: Optional[str] = None # remote url + related_id: Optional[str] = None + filename: Optional[str] = None + extension: Optional[str] = Field(default=None, description="File extension, should contain a dot") + mime_type: Optional[str] = None + size: int = -1 + + def to_dict(self) -> Mapping[str, str | int | None]: + data = self.model_dump(mode="json") + return { + **data, + "url": self.generate_url(), + } + + @property + def markdown(self) -> str: + url = self.generate_url() + if self.type == FileType.IMAGE: + text = f'![{self.filename or ""}]({url})' + else: + text = f"[{self.filename or url}]({url})" + + return text + + def generate_url(self) -> Optional[str]: + if self.type == FileType.IMAGE: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + else: + if self.transfer_method == FileTransferMethod.REMOTE_URL: + return self.remote_url + elif self.transfer_method == FileTransferMethod.LOCAL_FILE: + if self.related_id is None: + raise ValueError("Missing file related_id") + return helpers.get_signed_file_url(upload_file_id=self.related_id) + elif self.transfer_method == FileTransferMethod.TOOL_FILE: + assert self.related_id is not None + assert self.extension is not None + return 
ToolFileParser.get_tool_file_manager().sign_file( + tool_file_id=self.related_id, extension=self.extension + ) + + @model_validator(mode="after") + def validate_after(self): + match self.transfer_method: + case FileTransferMethod.REMOTE_URL: + if not self.remote_url: + raise ValueError("Missing file url") + if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): + raise ValueError("Invalid file url") + case FileTransferMethod.LOCAL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + case FileTransferMethod.TOOL_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") + return self diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index ea8605ac577e3a..a17b7be3675ab1 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,8 +1,12 @@ -tool_file_manager = { - 'manager': None -} +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.tools.tool_file_manager import ToolFileManager + +tool_file_manager: dict[str, Any] = {"manager": None} + class ToolFileParser: @staticmethod - def get_tool_file_manager() -> 'ToolFileManager': - return tool_file_manager['manager'] \ No newline at end of file + def get_tool_file_manager() -> "ToolFileManager": + return tool_file_manager["manager"] diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py deleted file mode 100644 index 737a11e426c745..00000000000000 --- a/api/core/file/upload_file_parser.py +++ /dev/null @@ -1,79 +0,0 @@ -import base64 -import hashlib -import hmac -import logging -import os -import time -from typing import Optional - -from configs import dify_config -from extensions.ext_storage import storage - -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] -IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) - - -class UploadFileParser: - @classmethod - def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: - if not upload_file: - return None - - if upload_file.extension not in IMAGE_EXTENSIONS: - return None - - if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == 'url' or force_url: - return cls.get_signed_temp_image_url(upload_file.id) - else: - # get image file base64 - try: - data = storage.load(upload_file.key) - except FileNotFoundError: - logging.error(f'File not found: {upload_file.key}') - return None - - encoded_string = base64.b64encode(data).decode('utf-8') - return f'data:{upload_file.mime_type};base64,{encoded_string}' - - @classmethod - def get_signed_temp_image_url(cls, upload_file_id) -> str: - """ - get signed url from upload file - - :param upload_file: UploadFile object - :return: - """ - base_url = dify_config.FILES_URL - image_preview_url = f'{base_url}/files/{upload_file_id}/image-preview' - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - - @classmethod - def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - - :param upload_file_id: file id - :param timestamp: timestamp - :param nonce: nonce - :param sign: signature - :return: - """ - data_to_sign = 
f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - # verify signature - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index afb2bbbbf317cd..4932284540738a 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,15 +1,13 @@ import logging -import time from enum import Enum from threading import Lock -from typing import Literal, Optional +from typing import Optional -from httpx import get, post +from httpx import Timeout, post from pydantic import BaseModel from yarl import URL from configs import dify_config -from core.helper.code_executor.entities import CodeDependency from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer @@ -17,15 +15,11 @@ logger = logging.getLogger(__name__) -# Code Executor -CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT -CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY -CODE_EXECUTION_TIMEOUT = (10, 60) - -class CodeExecutionException(Exception): +class CodeExecutionError(Exception): pass + class CodeExecutionResponse(BaseModel): class Data(BaseModel): stdout: Optional[str] = None @@ -37,9 +31,9 @@ class Data(BaseModel): class CodeLanguage(str, Enum): - PYTHON3 = 'python3' - JINJA2 = 'jinja2' - JAVASCRIPT = 'javascript' + PYTHON3 = "python3" + JINJA2 = "jinja2" + JAVASCRIPT = "javascript" class CodeExecutor: @@ -53,73 +47,77 @@ class CodeExecutor: } code_language_to_running_language = { - CodeLanguage.JAVASCRIPT: 'nodejs', + CodeLanguage.JAVASCRIPT: "nodejs", CodeLanguage.JINJA2: CodeLanguage.PYTHON3, CodeLanguage.PYTHON3: CodeLanguage.PYTHON3, } - supported_dependencies_languages: set[CodeLanguage] = { - CodeLanguage.PYTHON3 - } + supported_dependencies_languages: set[CodeLanguage] = {CodeLanguage.PYTHON3} @classmethod - def execute_code(cls, - language: CodeLanguage, - preload: str, - code: str, - dependencies: Optional[list[CodeDependency]] = None) -> str: + def execute_code(cls, language: CodeLanguage, preload: str, code: str) -> str: """ Execute code :param language: code language :param code: code :return: """ - url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run' + url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run" - headers = { - 'X-Api-Key': CODE_EXECUTION_API_KEY - } + headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY} data = { - 'language': cls.code_language_to_running_language.get(language), - 'code': code, - 'preload': preload, - 'enable_network': True + "language": cls.code_language_to_running_language.get(language), + "code": code, + "preload": preload, + "enable_network": True, } - if dependencies: - data['dependencies'] = [dependency.model_dump() for dependency in dependencies] - try: - response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT) + response = post( + str(url), + json=data, + headers=headers, + 
timeout=Timeout( + connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT, + read=dify_config.CODE_EXECUTION_READ_TIMEOUT, + write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT, + pool=None, + ), + ) if response.status_code == 503: - raise CodeExecutionException('Code execution service is unavailable') + raise CodeExecutionError("Code execution service is unavailable") elif response.status_code != 200: - raise Exception(f'Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running') - except CodeExecutionException as e: + raise Exception( + f"Failed to execute code, got status code {response.status_code}," + f" please check if the sandbox service is running" + ) + except CodeExecutionError as e: raise e except Exception as e: - raise CodeExecutionException('Failed to execute code, which is likely a network issue,' - ' please check if the sandbox service is running.' - f' ( Error: {str(e)} )') - + raise CodeExecutionError( + "Failed to execute code, which is likely a network issue," + " please check if the sandbox service is running." + f" ( Error: {str(e)} )" + ) + try: response = response.json() except: - raise CodeExecutionException('Failed to parse response') + raise CodeExecutionError("Failed to parse response") + + if (code := response.get("code")) != 0: + raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}") - if (code := response.get('code')) != 0: - raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") - response = CodeExecutionResponse(**response) - + if response.data.error: - raise CodeExecutionException(response.data.error) - - return response.data.stdout + raise CodeExecutionError(response.data.error) + + return response.data.stdout or "" @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: dict) -> dict: """ Execute code :param language: code language @@ -129,69 +127,13 @@ def execute_workflow_code_template(cls, language: CodeLanguage, code: str, input """ template_transformer = cls.code_template_transformers.get(language) if not template_transformer: - raise CodeExecutionException(f'Unsupported language {language}') + raise CodeExecutionError(f"Unsupported language {language}") - runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies) + runner, preload = template_transformer.transform_caller(code, inputs) try: - response = cls.execute_code(language, preload, runner, dependencies) - except CodeExecutionException as e: + response = cls.execute_code(language, preload, runner) + except CodeExecutionError as e: raise e return template_transformer.transform_response(response) - - @classmethod - def list_dependencies(cls, language: str) -> list[CodeDependency]: - if language not in cls.supported_dependencies_languages: - return [] - - with cls.dependencies_cache_lock: - if language in cls.dependencies_cache: - # check expiration - dependencies = cls.dependencies_cache[language] - if dependencies['expiration'] > time.time(): - return dependencies['data'] - # remove expired cache - del cls.dependencies_cache[language] - - dependencies = cls._get_dependencies(language) - with cls.dependencies_cache_lock: - cls.dependencies_cache[language] = { - 'data': dependencies, - 'expiration': time.time() + 60 - } - - return dependencies 
- - @classmethod - def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]: - """ - List dependencies - """ - url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies' - - headers = { - 'X-Api-Key': CODE_EXECUTION_API_KEY - } - - running_language = cls.code_language_to_running_language.get(language) - if isinstance(running_language, Enum): - running_language = running_language.value - - data = { - 'language': running_language, - } - - try: - response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT) - if response.status_code != 200: - raise Exception(f'Failed to list dependencies, got status code {response.status_code}, please check if the sandbox service is running') - response = response.json() - dependencies = response.get('data', {}).get('dependencies', []) - return [ - CodeDependency(**dependency) for dependency in dependencies - if dependency.get('name') not in Python3TemplateTransformer.get_standard_packages() - ] - except Exception as e: - logger.exception(f'Failed to list dependencies: {e}') - return [] \ No newline at end of file diff --git a/api/core/helper/code_executor/code_node_provider.py b/api/core/helper/code_executor/code_node_provider.py index 761c0e2b2524fa..e233a596b9da0e 100644 --- a/api/core/helper/code_executor/code_node_provider.py +++ b/api/core/helper/code_executor/code_node_provider.py @@ -2,8 +2,6 @@ from pydantic import BaseModel -from core.helper.code_executor.code_executor import CodeExecutor - class CodeNodeProvider(BaseModel): @staticmethod @@ -23,33 +21,14 @@ def get_default_code(cls) -> str: """ pass - @classmethod - def get_default_available_packages(cls) -> list[dict]: - return [p.model_dump() for p in CodeExecutor.list_dependencies(cls.get_language())] - @classmethod def get_default_config(cls) -> dict: return { "type": "code", "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] - }, - { - "variable": "arg2", - "value_selector": [] - } - ], + "variables": [{"variable": "arg1", "value_selector": []}, {"variable": "arg2", "value_selector": []}], "code_language": cls.get_language(), "code": cls.get_default_code(), - "outputs": { - "result": { - "type": "string", - "children": None - } - } + "outputs": {"result": {"type": "string", "children": None}}, }, - "available_dependencies": cls.get_default_available_packages(), } diff --git a/api/core/helper/code_executor/entities.py b/api/core/helper/code_executor/entities.py deleted file mode 100644 index cc10288521ad6b..00000000000000 --- a/api/core/helper/code_executor/entities.py +++ /dev/null @@ -1,6 +0,0 @@ -from pydantic import BaseModel - - -class CodeDependency(BaseModel): - name: str - version: str diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py index a157fcc6d147cd..ae324b83a95124 100644 --- a/api/core/helper/code_executor/javascript/javascript_code_provider.py +++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py @@ -18,4 +18,5 @@ def get_default_code(cls) -> str: result: arg1 + arg2 } } - """) + """ + ) diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index a4d2551972e3a4..d67a0903aa4d4c 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -21,5 +21,6 @@ def get_runner_script(cls) -> 
str: var output_json = JSON.stringify(output_obj) var result = `<>${{output_json}}<>` console.log(result) - """) + """ + ) return runner_script diff --git a/api/core/helper/code_executor/jinja2/jinja2_formatter.py b/api/core/helper/code_executor/jinja2/jinja2_formatter.py index 63f48a56e2a47f..db2eb5ebb6b19a 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_formatter.py +++ b/api/core/helper/code_executor/jinja2/jinja2_formatter.py @@ -3,15 +3,13 @@ class Jinja2Formatter: @classmethod - def format(cls, template: str, inputs: str) -> str: + def format(cls, template: str, inputs: dict) -> str: """ Format template :param template: template :param inputs: inputs :return: """ - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=template, inputs=inputs - ) + result = CodeExecutor.execute_workflow_code_template(language=CodeLanguage.JINJA2, code=template, inputs=inputs) - return result['result'] + return result["result"] diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index a8f8095d52b6d5..63d58edbc794e9 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -1,14 +1,9 @@ from textwrap import dedent -from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer class Jinja2TemplateTransformer(TemplateTransformer): - @classmethod - def get_standard_packages(cls) -> set[str]: - return {'jinja2'} | Python3TemplateTransformer.get_standard_packages() - @classmethod def transform_response(cls, response: str) -> dict: """ @@ -16,9 +11,7 @@ def transform_response(cls, response: str) -> dict: :param response: response :return: """ - return { - 'result': cls.extract_result_str_from_response(response) - } + return {"result": cls.extract_result_str_from_response(response)} @classmethod def get_runner_script(cls) -> str: diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index efcb8a9d1ebccf..9cca8af7c698bc 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -13,8 +13,9 @@ def get_language() -> str: def get_default_code(cls) -> str: return dedent( """ - def main(arg1: int, arg2: int) -> dict: + def main(arg1: str, arg2: str) -> dict: return { "result": arg1 + arg2, } - """) + """ + ) diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py index 4a5fa3509325d3..75a5a44d086c3c 100644 --- a/api/core/helper/code_executor/python3/python3_transformer.py +++ b/api/core/helper/code_executor/python3/python3_transformer.py @@ -4,30 +4,6 @@ class Python3TemplateTransformer(TemplateTransformer): - @classmethod - def get_standard_packages(cls) -> set[str]: - return { - 'base64', - 'binascii', - 'collections', - 'datetime', - 'functools', - 'hashlib', - 'hmac', - 'itertools', - 'json', - 'math', - 'operator', - 'os', - 'random', - 're', - 'string', - 'sys', - 'time', - 'traceback', - 'uuid', - } - @classmethod def get_runner_script(cls) -> str: runner_script = dedent(f""" diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index da7ef469d91abf..6f016f27bc874d 100644 --- 
a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -2,23 +2,15 @@ import re from abc import ABC, abstractmethod from base64 import b64encode -from typing import Optional - -from core.helper.code_executor.entities import CodeDependency class TemplateTransformer(ABC): - _code_placeholder: str = '{{code}}' - _inputs_placeholder: str = '{{inputs}}' - _result_tag: str = '<>' - - @classmethod - def get_standard_packages(cls) -> set[str]: - return set() + _code_placeholder: str = "{{code}}" + _inputs_placeholder: str = "{{inputs}}" + _result_tag: str = "<>" @classmethod - def transform_caller(cls, code: str, inputs: dict, - dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]: + def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]: """ Transform code to python runner :param code: code @@ -28,20 +20,13 @@ def transform_caller(cls, code: str, inputs: dict, runner_script = cls.assemble_runner_script(code, inputs) preload_script = cls.get_preload_script() - packages = dependencies or [] - standard_packages = cls.get_standard_packages() - for package in standard_packages: - if package not in packages: - packages.append(CodeDependency(name=package, version='')) - packages = list({dep.name: dep for dep in packages if dep.name}.values()) - - return runner_script, preload_script, packages + return runner_script, preload_script @classmethod def extract_result_str_from_response(cls, response: str) -> str: - result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL) + result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: - raise ValueError('Failed to parse result') + raise ValueError("Failed to parse result") result = result.group(1) return result @@ -65,7 +50,7 @@ def get_runner_script(cls) -> str: @classmethod def serialize_inputs(cls, inputs: dict) -> str: inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() - input_base64_encoded = b64encode(inputs_json_str).decode('utf-8') + input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded @classmethod @@ -82,4 +67,4 @@ def get_preload_script(cls) -> str: """ Get preload script """ - return '' + return "" diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 5e5deb86b47e54..96341a1b780a80 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -8,14 +8,15 @@ def obfuscated_token(token: str): if not token: return token if len(token) <= 8: - return '*' * 20 - return token[:6] + '*' * 12 + token[-2:] + return "*" * 20 + return token[:6] + "*" * 12 + token[-2:] def encrypt_token(tenant_id: str, token: str): from models.account import Tenant + if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): - raise ValueError(f'Tenant with id {tenant_id} not found') + raise ValueError(f"Tenant with id {tenant_id} not found") encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) return base64.b64encode(encrypted_token).decode() diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py index 29cb4acc7d03c2..5e274f8916869d 100644 --- a/api/core/helper/model_provider_cache.py +++ b/api/core/helper/model_provider_cache.py @@ -25,7 +25,7 @@ def get(self) -> Optional[dict]: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = 
cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 20feae8554f79d..b880590de28476 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -12,19 +12,20 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) -> bool: moderation_config = hosting_configuration.moderation_config - if (moderation_config and moderation_config.enabled is True - and 'openai' in hosting_configuration.provider_map - and hosting_configuration.provider_map['openai'].enabled is True + if ( + moderation_config + and moderation_config.enabled is True + and "openai" in hosting_configuration.provider_map + and hosting_configuration.provider_map["openai"].enabled is True ): using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type provider_name = model_config.provider - if using_provider_type == ProviderType.SYSTEM \ - and provider_name in moderation_config.providers: - hosting_openai_config = hosting_configuration.provider_map['openai'] + if using_provider_type == ProviderType.SYSTEM and provider_name in moderation_config.providers: + hosting_openai_config = hosting_configuration.provider_map["openai"] # 2000 text per chunk length = 2000 - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] if len(text_chunks) == 0: return True @@ -34,15 +35,13 @@ def check_moderation(model_config: ModelConfigWithCredentialsEntity, text: str) try: model_type_instance = OpenAIModerationModel() moderation_result = model_type_instance.invoke( - model='text-moderation-stable', - credentials=hosting_openai_config.credentials, - text=text_chunk + model="text-moderation-stable", credentials=hosting_openai_config.credentials, text=text_chunk ) if moderation_result is True: return True except Exception as ex: logger.exception(ex) - raise InvokeBadRequestError('Rate limit exceeded, please try again later.') + raise InvokeBadRequestError("Rate limit exceeded, please try again later.") return False diff --git a/api/core/helper/module_import_helper.py b/api/core/helper/module_import_helper.py index 2000577a406e6f..e6e149154870da 100644 --- a/api/core/helper/module_import_helper.py +++ b/api/core/helper/module_import_helper.py @@ -37,8 +37,9 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type] """ Get all the subclasses of the parent type from the module """ - classes = [x for _, x in vars(mod).items() - if isinstance(x, type) and x != parent_type and issubclass(x, parent_type)] + classes = [ + x for _, x in vars(mod).items() if isinstance(x, type) and x != parent_type and issubclass(x, parent_type) + ] return classes @@ -56,6 +57,6 @@ def load_single_subclass_from_source( case 1: return subclasses[0] case 0: - raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}') + raise Exception(f"Missing subclass of {parent_type.__name__} in {script_path}") case _: - raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}') \ No newline at end of file + raise Exception(f"Multiple subclasses of {parent_type.__name__} in {script_path}") diff --git a/api/core/helper/position_helper.py b/api/core/helper/position_helper.py index dd1534c791b313..3efdc8aa471697 
100644 --- a/api/core/helper/position_helper.py +++ b/api/core/helper/position_helper.py @@ -3,6 +3,7 @@ from collections.abc import Callable from typing import Any +from configs import dify_config from core.tools.utils.yaml_utils import load_yaml_file @@ -19,10 +20,91 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> return {name: index for index, name in enumerate(positions)} +def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for tools from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_TOOL_PINS_LIST, + ) + + +def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]: + """ + Get the mapping for providers from name to index from a YAML file. + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + position_map = get_position_map(folder_path, file_name=file_name) + return pin_position_map( + position_map, + pin_list=dify_config.POSITION_PROVIDER_PINS_LIST, + ) + + +def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]: + """ + Pin the items in the pin list to the beginning of the position map. + Overall logic: exclude > include > pin + :param original_position_map: the position map to be sorted and filtered + :param pin_list: the list of pins to be put at the beginning + :return: the sorted position map + """ + positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x]) + + # Add pins to position map + position_map = {name: idx for idx, name in enumerate(pin_list)} + + # Add remaining positions to position map + start_idx = len(position_map) + for name in positions: + if name not in position_map: + position_map[name] = start_idx + start_idx += 1 + + return position_map + + +def is_filtered( + include_set: set[str], + exclude_set: set[str], + data: Any, + name_func: Callable[[Any], str], +) -> bool: + """ + Check if the object should be filtered out. + Overall logic: exclude > include > pin + :param include_set: the set of names to be included + :param exclude_set: the set of names to be excluded + :param name_func: the function to get the name of the object + :param data: the data to be filtered + :return: True if the object should be filtered out, False otherwise + """ + if not data: + return False + if not include_set and not exclude_set: + return False + + name = name_func(data) + + if name in exclude_set: # exclude_set is prioritized + return True + if include_set and name not in include_set: # filter out only if include_set is not empty + return True + return False + + def sort_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> list[Any]: """ Sort the objects by the position map.
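pin_position_map and is_filtered above are self-contained, so a worked example pins down the intended behavior (the provider names are invented):

original = {"google": 0, "bing": 1, "duckduckgo": 2}

# pinned names claim the leading indexes; the rest keep their relative order
assert pin_position_map(original, pin_list=["duckduckgo"]) == {"duckduckgo": 0, "google": 1, "bing": 2}

# exclude beats include, and a non-empty include set acts as an allowlist
assert is_filtered(include_set=set(), exclude_set={"bing"}, data="bing", name_func=str) is True
assert is_filtered(include_set={"google"}, exclude_set=set(), data="bing", name_func=str) is True
assert is_filtered(include_set=set(), exclude_set=set(), data="bing", name_func=str) is False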
@@ -35,13 +117,13 @@ def sort_by_position_map( if not position_map or not data: return data - return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) + return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf"))) def sort_to_dict_by_position_map( - position_map: dict[str, int], - data: list[Any], - name_func: Callable[[Any], str], + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], ) -> OrderedDict[str, Any]: """ Sort the objects into an ordered dict by the position map. diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 14ca8e943c71fc..374bd9d57bc4b1 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -1,40 +1,55 @@ """ Proxy requests to avoid SSRF """ + import logging -import os import time import httpx -SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '') -SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') -SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') -SSRF_DEFAULT_MAX_RETRIES = int(os.getenv('SSRF_DEFAULT_MAX_RETRIES', '3')) +from configs import dify_config + +SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES -proxies = { - 'http://': SSRF_PROXY_HTTP_URL, - 'https://': SSRF_PROXY_HTTPS_URL -} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None +proxy_mounts = ( + { + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL), + "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL), + } + if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL + else None +) BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") if "follow_redirects" not in kwargs: kwargs["follow_redirects"] = allow_redirects - + + if "timeout" not in kwargs: + kwargs["timeout"] = httpx.Timeout( + timeout=dify_config.SSRF_DEFAULT_TIME_OUT, + connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT, + read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, + write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, + ) + retries = 0 while retries <= max_retries: try: - if SSRF_PROXY_ALL_URL: - response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs) - elif proxies: - response = httpx.request(method=method, url=url, proxies=proxies, **kwargs) + if dify_config.SSRF_PROXY_ALL_URL: + with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client: + response = client.request(method=method, url=url, **kwargs) + elif proxy_mounts: + with httpx.Client(mounts=proxy_mounts) as client: + response = client.request(method=method, url=url, **kwargs) else: - response = httpx.request(method=method, url=url, **kwargs) + with httpx.Client() as client: + response = client.request(method=method, url=url, **kwargs) if response.status_code not in STATUS_FORCELIST: return response @@ -52,24 +67,24 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('GET', url, max_retries=max_retries, **kwargs) + return make_request("GET", url, max_retries=max_retries, **kwargs) def post(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('POST', url, max_retries=max_retries, **kwargs) + return make_request("POST", url, max_retries=max_retries, **kwargs) def put(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return
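The rewritten make_request above replaces one-shot httpx.request calls with short-lived clients and injects a default timeout when the caller passes none. A condensed sketch of that selection logic, with placeholder constants where the real module reads dify_config (the values and the fetch name are invented so the snippet runs standalone):

import httpx

SSRF_PROXY_ALL_URL = ""  # e.g. "http://ssrf-proxy:3128"; empty means no catch-all proxy
PROXY_MOUNTS = None      # or per-scheme httpx.HTTPTransport mounts, as built above
DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=2.0, read=5.0, write=5.0)

def fetch(method: str, url: str, **kwargs) -> httpx.Response:
    kwargs.setdefault("timeout", DEFAULT_TIMEOUT)  # same effect as the 'if "timeout" not in kwargs' guard
    if SSRF_PROXY_ALL_URL:
        client = httpx.Client(proxy=SSRF_PROXY_ALL_URL)
    elif PROXY_MOUNTS:
        client = httpx.Client(mounts=PROXY_MOUNTS)
    else:
        client = httpx.Client()
    with client:
        return client.request(method=method, url=url, **kwargs)

A client per call keeps the change drop-in compatible with the old module-level functions; a long-lived client would add connection reuse but also shared state to manage.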
make_request('PUT', url, max_retries=max_retries, **kwargs) + return make_request("PUT", url, max_retries=max_retries, **kwargs) def patch(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('PATCH', url, max_retries=max_retries, **kwargs) + return make_request("PATCH", url, max_retries=max_retries, **kwargs) def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('DELETE', url, max_retries=max_retries, **kwargs) + return make_request("DELETE", url, max_retries=max_retries, **kwargs) def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): - return make_request('HEAD', url, max_retries=max_retries, **kwargs) + return make_request("HEAD", url, max_retries=max_retries, **kwargs) diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index a6f486e81de006..e848b46c5633ab 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -9,15 +9,15 @@ class ToolParameterCacheType(Enum): PARAMETER = "tool_parameter" + class ToolParameterCache: - def __init__(self, - tenant_id: str, - provider: str, - tool_name: str, - cache_type: ToolParameterCacheType, - identity_id: str - ): - self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}" + def __init__( + self, tenant_id: str, provider: str, tool_name: str, cache_type: ToolParameterCacheType, identity_id: str + ): + self.cache_key = ( + f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + f":identity_id:{identity_id}" + ) def get(self) -> Optional[dict]: """ @@ -28,7 +28,7 @@ def get(self) -> Optional[dict]: cached_tool_parameter = redis_client.get(self.cache_key) if cached_tool_parameter: try: - cached_tool_parameter = cached_tool_parameter.decode('utf-8') + cached_tool_parameter = cached_tool_parameter.decode("utf-8") cached_tool_parameter = json.loads(cached_tool_parameter) except JSONDecodeError: return None @@ -52,4 +52,4 @@ def delete(self) -> None: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 6c5d3b8fb6880c..94b02cf98578b1 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -9,6 +9,7 @@ class ToolProviderCredentialsCacheType(Enum): PROVIDER = "tool_provider" + class ToolProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" @@ -22,7 +23,7 @@ def get(self) -> Optional[dict]: cached_provider_credentials = redis_client.get(self.cache_key) if cached_provider_credentials: try: - cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = cached_provider_credentials.decode("utf-8") cached_provider_credentials = json.loads(cached_provider_credentials) except JSONDecodeError: return None @@ -46,4 +47,4 @@ def delete(self) -> None: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 5f7fec58337ee9..b47ba67f2fa64f 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,8 +1,9 @@ from typing 
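Both cache classes above share one read path: fetch bytes from Redis, decode, parse JSON, and treat anything unparsable as a miss. A generic sketch of that pattern (the helper name is ours; redis_client is any redis.Redis-compatible client, and the value is assumed to have been stored as a JSON string):

import json
from json import JSONDecodeError
from typing import Optional

def read_json_cache(redis_client, cache_key: str) -> Optional[dict]:
    cached = redis_client.get(cache_key)
    if not cached:
        return None
    try:
        return json.loads(cached.decode("utf-8"))
    except JSONDecodeError:
        return None  # a corrupt payload behaves like a cache miss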
import Optional -from flask import Config, Flask +from flask import Flask from pydantic import BaseModel +from configs import dify_config from core.entities.provider_entities import QuotaUnit, RestrictModel from core.model_runtime.entities.model_entities import ModelType from models.provider import ProviderQuotaType @@ -44,31 +45,30 @@ class HostingConfiguration: moderation_config: HostedModerationConfig = None def init_app(self, app: Flask) -> None: - config = app.config - - if config.get('EDITION') != 'CLOUD': + if dify_config.EDITION != "CLOUD": return - self.provider_map["azure_openai"] = self.init_azure_openai(config) - self.provider_map["openai"] = self.init_openai(config) - self.provider_map["anthropic"] = self.init_anthropic(config) - self.provider_map["minimax"] = self.init_minimax(config) - self.provider_map["spark"] = self.init_spark(config) - self.provider_map["zhipuai"] = self.init_zhipuai(config) + self.provider_map["azure_openai"] = self.init_azure_openai() + self.provider_map["openai"] = self.init_openai() + self.provider_map["anthropic"] = self.init_anthropic() + self.provider_map["minimax"] = self.init_minimax() + self.provider_map["spark"] = self.init_spark() + self.provider_map["zhipuai"] = self.init_zhipuai() - self.moderation_config = self.init_moderation_config(config) + self.moderation_config = self.init_moderation_config() - def init_azure_openai(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_azure_openai() -> HostingProvider: quota_unit = QuotaUnit.TIMES - if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"): + if dify_config.HOSTED_AZURE_OPENAI_ENABLED: credentials = { - "openai_api_key": app_config.get("HOSTED_AZURE_OPENAI_API_KEY"), - "openai_api_base": app_config.get("HOSTED_AZURE_OPENAI_API_BASE"), - "base_model_name": "gpt-35-turbo" + "openai_api_key": dify_config.HOSTED_AZURE_OPENAI_API_KEY, + "openai_api_base": dify_config.HOSTED_AZURE_OPENAI_API_BASE, + "base_model_name": "gpt-35-turbo", } quotas = [] - hosted_quota_limit = int(app_config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT", "1000")) + hosted_quota_limit = dify_config.HOSTED_AZURE_OPENAI_QUOTA_LIMIT trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, restrict_models=[ @@ -76,120 +76,124 @@ def init_azure_openai(self, app_config: Config) -> HostingProvider: RestrictModel(model="gpt-4o", base_model_name="gpt-4o", model_type=ModelType.LLM), RestrictModel(model="gpt-4o-mini", base_model_name="gpt-4o-mini", model_type=ModelType.LLM), RestrictModel(model="gpt-4-32k", base_model_name="gpt-4-32k", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM), + RestrictModel( + model="gpt-4-1106-preview", base_model_name="gpt-4-1106-preview", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-4-vision-preview", base_model_name="gpt-4-vision-preview", model_type=ModelType.LLM + ), RestrictModel(model="gpt-35-turbo", base_model_name="gpt-35-turbo", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM), - RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), - RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", 
model_type=ModelType.LLM), - RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING), - RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING), - ] + RestrictModel( + model="gpt-35-turbo-1106", base_model_name="gpt-35-turbo-1106", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-instruct", base_model_name="gpt-35-turbo-instruct", model_type=ModelType.LLM + ), + RestrictModel( + model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM + ), + RestrictModel( + model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM + ), + RestrictModel( + model="text-embedding-ada-002", + base_model_name="text-embedding-ada-002", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-small", + base_model_name="text-embedding-3-small", + model_type=ModelType.TEXT_EMBEDDING, + ), + RestrictModel( + model="text-embedding-3-large", + base_model_name="text-embedding-3-large", + model_type=ModelType.TEXT_EMBEDDING, + ), + ], ) quotas.append(trial_quota) - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, quota_unit=quota_unit, ) - def init_openai(self, app_config: Config) -> HostingProvider: + def init_openai(self) -> HostingProvider: quota_unit = QuotaUnit.CREDITS quotas = [] - if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): - hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) - trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit, - restrict_models=trial_models - ) + if dify_config.HOSTED_OPENAI_TRIAL_ENABLED: + hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT + trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS") + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models) quotas.append(trial_quota) - if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): - paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") - paid_quota = PaidHostingQuota( - restrict_models=paid_models - ) + if dify_config.HOSTED_OPENAI_PAID_ENABLED: + paid_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_PAID_MODELS") + paid_quota = PaidHostingQuota(restrict_models=paid_models) quotas.append(paid_quota) if len(quotas) > 0: credentials = { - "openai_api_key": app_config.get("HOSTED_OPENAI_API_KEY"), + "openai_api_key": dify_config.HOSTED_OPENAI_API_KEY, } - if app_config.get("HOSTED_OPENAI_API_BASE"): - credentials["openai_api_base"] = app_config.get("HOSTED_OPENAI_API_BASE") + if dify_config.HOSTED_OPENAI_API_BASE: + credentials["openai_api_base"] = dify_config.HOSTED_OPENAI_API_BASE - if app_config.get("HOSTED_OPENAI_API_ORGANIZATION"): - credentials["openai_organization"] = app_config.get("HOSTED_OPENAI_API_ORGANIZATION") + if dify_config.HOSTED_OPENAI_API_ORGANIZATION: + credentials["openai_organization"] = dify_config.HOSTED_OPENAI_API_ORGANIZATION - return HostingProvider( - enabled=True, - credentials=credentials, - 
quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, quota_unit=quota_unit, ) - def init_anthropic(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_anthropic() -> HostingProvider: quota_unit = QuotaUnit.TOKENS quotas = [] - if app_config.get("HOSTED_ANTHROPIC_TRIAL_ENABLED"): - hosted_quota_limit = int(app_config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) - trial_quota = TrialHostingQuota( - quota_limit=hosted_quota_limit - ) + if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED: + hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT + trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit) quotas.append(trial_quota) - if app_config.get("HOSTED_ANTHROPIC_PAID_ENABLED"): + if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED: paid_quota = PaidHostingQuota() quotas.append(paid_quota) if len(quotas) > 0: credentials = { - "anthropic_api_key": app_config.get("HOSTED_ANTHROPIC_API_KEY"), + "anthropic_api_key": dify_config.HOSTED_ANTHROPIC_API_KEY, } - if app_config.get("HOSTED_ANTHROPIC_API_BASE"): - credentials["anthropic_api_url"] = app_config.get("HOSTED_ANTHROPIC_API_BASE") + if dify_config.HOSTED_ANTHROPIC_API_BASE: + credentials["anthropic_api_url"] = dify_config.HOSTED_ANTHROPIC_API_BASE - return HostingProvider( - enabled=True, - credentials=credentials, - quota_unit=quota_unit, - quotas=quotas - ) + return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas) return HostingProvider( enabled=False, quota_unit=quota_unit, ) - def init_minimax(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_minimax() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_MINIMAX_ENABLED"): + if dify_config.HOSTED_MINIMAX_ENABLED: quotas = [FreeHostingQuota()] return HostingProvider( enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -197,16 +201,17 @@ def init_minimax(self, app_config: Config) -> HostingProvider: quota_unit=quota_unit, ) - def init_spark(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_spark() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_SPARK_ENABLED"): + if dify_config.HOSTED_SPARK_ENABLED: quotas = [FreeHostingQuota()] return HostingProvider( enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -214,16 +219,17 @@ def init_spark(self, app_config: Config) -> HostingProvider: quota_unit=quota_unit, ) - def init_zhipuai(self, app_config: Config) -> HostingProvider: + @staticmethod + def init_zhipuai() -> HostingProvider: quota_unit = QuotaUnit.TOKENS - if app_config.get("HOSTED_ZHIPUAI_ENABLED"): + if dify_config.HOSTED_ZHIPUAI_ENABLED: quotas = [FreeHostingQuota()] return HostingProvider( enabled=True, credentials=None, # use credentials from the provider quota_unit=quota_unit, - quotas=quotas + quotas=quotas, ) return HostingProvider( @@ -231,22 +237,19 @@ def init_zhipuai(self, app_config: Config) -> HostingProvider: quota_unit=quota_unit, ) - def init_moderation_config(self, app_config: Config) -> HostedModerationConfig: - if app_config.get("HOSTED_MODERATION_ENABLED") \ - and app_config.get("HOSTED_MODERATION_PROVIDERS"): - return HostedModerationConfig( - enabled=True, - 
providers=app_config.get("HOSTED_MODERATION_PROVIDERS").split(',') - ) + @staticmethod + def init_moderation_config() -> HostedModerationConfig: + if dify_config.HOSTED_MODERATION_ENABLED and dify_config.HOSTED_MODERATION_PROVIDERS: + return HostedModerationConfig(enabled=True, providers=dify_config.HOSTED_MODERATION_PROVIDERS.split(",")) - return HostedModerationConfig( - enabled=False - ) + return HostedModerationConfig(enabled=False) @staticmethod - def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: - models_str = app_config.get(env_var) + def parse_restrict_models_from_env(env_var: str) -> list[RestrictModel]: + models_str = dify_config.model_dump().get(env_var) models_list = models_str.split(",") if models_str else [] - return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if - model_name.strip()] - + return [ + RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) + for model_name in models_list + if model_name.strip() + ] diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index b20c6ed187f20b..e2a94073cf4c14 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -16,9 +16,8 @@ from core.errors.error import ProviderTokenNotInitError from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType, PriceType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.entities.model_entities import ModelType +from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -41,7 +40,6 @@ class IndexingRunner: - def __init__(self): self.storage = storage self.model_manager = ModelManager() @@ -51,25 +49,26 @@ def run(self, dataset_documents: list[DatasetDocument]): for dataset_document in dataset_documents: try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). 
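parse_restrict_models_from_env above now pulls the raw string from dify_config.model_dump() instead of the Flask config, but the parsing itself is unchanged: split on commas, strip whitespace, drop empties. Isolated for clarity (the helper name is ours; RestrictModel construction is omitted):

from typing import Optional

def parse_model_names(models_str: Optional[str]) -> list[str]:
    models_list = models_str.split(",") if models_str else []
    return [name.strip() for name in models_list if name.strip()]

assert parse_model_names("gpt-4, ,gpt-4o,") == ["gpt-4", "gpt-4o"]
assert parse_model_names(None) == []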
\ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() # extract text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) @@ -78,20 +77,20 @@ def run(self, dataset_documents: list[DatasetDocument]): index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, - documents=documents + documents=documents, ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: - logging.warning('Document deleted, document id: {}'.format(dataset_document.id)) + logging.warning("Document deleted, document id: {}".format(dataset_document.id)) except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -100,26 +99,25 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is splitting.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() for document_segment in document_segments: db.session.delete(document_segment) db.session.commit() # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). 
\ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -127,28 +125,26 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict()) # transform - documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language, - processing_rule.to_dict()) + documents = self._transform( + index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict() + ) # save segment self._load_segments(dataset, dataset_document, documents) # load self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -157,17 +153,14 @@ def run_in_indexing_status(self, dataset_document: DatasetDocument): """Run the indexing process when the index_status is indexing.""" try: # get dataset - dataset = Dataset.query.filter_by( - id=dataset_document.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, - document_id=dataset_document.id + dataset_id=dataset.id, document_id=dataset_document.id ).all() documents = [] @@ -182,42 +175,48 @@ def run_in_indexing_status(self, dataset_document: DatasetDocument): "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) documents.append(document) # build index # get the process rule - processing_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). 
\ - first() + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) + .first() + ) index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() self._load( - index_processor=index_processor, - dataset=dataset, - dataset_document=dataset_document, - documents=documents + index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents ) - except DocumentIsPausedException: - raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id)) + except DocumentIsPausedError: + raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id)) except ProviderTokenNotInitError as e: - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e.description) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") - dataset_document.indexing_status = 'error' + dataset_document.indexing_status = "error" dataset_document.error = str(e) dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict, - doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, - indexing_technique: str = 'economy') -> dict: + def indexing_estimate( + self, + tenant_id: str, + extract_settings: list[ExtractSetting], + tmp_processing_rule: dict, + doc_form: Optional[str] = None, + doc_language: str = "English", + dataset_id: Optional[str] = None, + indexing_technique: str = "economy", + ) -> dict: """ Estimate the indexing for the document. 
""" @@ -231,18 +230,16 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin embedding_model_instance = None if dataset_id: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise ValueError('Dataset not found.') - if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': + raise ValueError("Dataset not found.") + if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -250,16 +247,13 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin model_type=ModelType.TEXT_EMBEDDING, ) else: - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - tokens = 0 preview_texts = [] total_segments = 0 - total_price = 0 - currency = 'USD' index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() all_text_docs = [] @@ -268,8 +262,7 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( - mode=tmp_processing_rule["mode"], - rules=json.dumps(tmp_processing_rule["rules"]) + mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) # get splitter @@ -277,150 +270,118 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin # split to documents documents = self._split_to_documents_for_estimate( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule + text_docs=text_docs, splitter=splitter, processing_rule=processing_rule ) total_segments += len(documents) for document in documents: if len(preview_texts) < 5: preview_texts.append(document.page_content) - if indexing_technique == 'high_quality' or embedding_model_instance: - tokens += embedding_model_instance.get_text_embedding_num_tokens( - texts=[self.filter_string(document.page_content)] - ) - - if doc_form and doc_form == 'qa_model': - model_instance = self.model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM - ) - - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) + if doc_form and doc_form == "qa_model": if len(preview_texts) > 0: # qa model document - response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], - doc_language) - document_qa_list = self.format_split_text(response) - price_info = model_type_instance.get_price( - model=model_instance.model, - credentials=model_instance.credentials, - price_type=PriceType.INPUT, - tokens=total_segments * 2000, + response = LLMGenerator.generate_qa_document( + current_user.current_tenant_id, preview_texts[0], doc_language ) - return { - "total_segments": total_segments * 20, - "tokens": total_segments * 2000, - "total_price": 
'{:f}'.format(price_info.total_amount), - "currency": price_info.currency, - "qa_preview": document_qa_list, - "preview": preview_texts - } - if embedding_model_instance: - embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance) - embedding_price_info = embedding_model_type_instance.get_price( - model=embedding_model_instance.model, - credentials=embedding_model_instance.credentials, - price_type=PriceType.INPUT, - tokens=tokens - ) - total_price = '{:f}'.format(embedding_price_info.total_amount) - currency = embedding_price_info.currency - return { - "total_segments": total_segments, - "tokens": tokens, - "total_price": total_price, - "currency": currency, - "preview": preview_texts - } - - def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ - -> list[Document]: + document_qa_list = self.format_split_text(response) + + return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} + return {"total_segments": total_segments, "preview": preview_texts} + + def _extract( + self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict + ) -> list[Document]: # load file - if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: + if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: return [] data_source_info = dataset_document.data_source_info_dict text_docs = [] - if dataset_document.data_source_type == 'upload_file': - if not data_source_info or 'upload_file_id' not in data_source_info: + if dataset_document.data_source_type == "upload_file": + if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info['upload_file_id']). 
\ - one_or_none() + file_detail = ( + db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() + ) if file_detail: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=dataset_document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'notion_import': - if (not data_source_info or 'notion_workspace_id' not in data_source_info - or 'notion_page_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_workspace_id" not in data_source_info + or "notion_page_id" not in data_source_info + ): raise ValueError("no notion import info found") extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], "document": dataset_document, - "tenant_id": dataset_document.tenant_id + "tenant_id": dataset_document.tenant_id, }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) - elif dataset_document.data_source_type == 'website_crawl': - if (not data_source_info or 'provider' not in data_source_info - or 'url' not in data_source_info or 'job_id' not in data_source_info): + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) + elif dataset_document.data_source_type == "website_crawl": + if ( + not data_source_info + or "provider" not in data_source_info + or "url" not in data_source_info + or "job_id" not in data_source_info + ): raise ValueError("no website import info found") extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], "tenant_id": dataset_document.tenant_id, - "url": data_source_info['url'], - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=dataset_document.doc_form + document_model=dataset_document.doc_form, ) - text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) + text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DatasetDocument.parsing_completed_at: 
datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) # replace doc id to document model id text_docs = cast(list[Document], text_docs) for text_doc in text_docs: - text_doc.metadata['document_id'] = dataset_document.id - text_doc.metadata['dataset_id'] = dataset_document.dataset_id + text_doc.metadata["document_id"] = dataset_document.id + text_doc.metadata["dataset_id"] = dataset_document.dataset_id return text_docs - def filter_string(self, text): - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + @staticmethod + def filter_string(text): + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) return text - def _get_splitter(self, processing_rule: DatasetProcessRule, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + @staticmethod + def _get_splitter( + processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ @@ -434,10 +395,10 @@ def _get_splitter(self, processing_rule: DatasetProcessRule, separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") - if segmentation.get('chunk_overlap'): - chunk_overlap = segmentation['chunk_overlap'] + if segmentation.get("chunk_overlap"): + chunk_overlap = segmentation["chunk_overlap"] else: chunk_overlap = 0 @@ -446,22 +407,27 @@ def _get_splitter(self, processing_rule: DatasetProcessRule, chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter - def _step_split(self, text_docs: list[Document], splitter: TextSplitter, - dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ - -> list[Document]: + def _step_split( + self, + text_docs: list[Document], + splitter: TextSplitter, + dataset: Dataset, + dataset_document: DatasetDocument, + processing_rule: DatasetProcessRule, + ) -> list[Document]: """ Split the text documents into documents and save them to the document segment. 
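filter_string above (now a @staticmethod) scrubs special token delimiters and control characters out of extracted text. Echoed verbatim with a made-up sample so the effect is concrete:

import re

def filter_string(text):
    text = re.sub(r"<\|", "<", text)
    text = re.sub(r"\|>", ">", text)
    text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text)
    # Unicode U+FFFE
    text = re.sub("\ufffe", "", text)
    return text

assert filter_string("<|assistant|>\x00 hi") == "<assistant> hi"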
""" @@ -471,14 +437,12 @@ def _step_split(self, text_docs: list[Document], splitter: TextSplitter, processing_rule=processing_rule, tenant_id=dataset.tenant_id, document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language + document_language=dataset_document.doc_language, ) # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -492,7 +456,7 @@ def _step_split(self, text_docs: list[Document], splitter: TextSplitter, extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -500,15 +464,21 @@ def _step_split(self, text_docs: list[Document], splitter: TextSplitter, dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) return documents - def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule, tenant_id: str, - document_form: str, document_language: str) -> list[Document]: + def _split_to_documents( + self, + text_docs: list[Document], + splitter: TextSplitter, + processing_rule: DatasetProcessRule, + tenant_id: str, + document_form: str, + document_language: str, + ) -> list[Document]: """ Split the text documents into nodes. """ @@ -523,13 +493,12 @@ def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, documents = splitter.split_documents([text_doc]) split_documents = [] for document_node in documents: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash - # delete Spliter character + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): page_content = page_content[1:] @@ -541,15 +510,21 @@ def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, split_documents.append(document_node) all_documents.extend(split_documents) # processing qa document - if document_form == 'qa_model': + if document_form == "qa_model": for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents, - 'document_language': document_language}) + document_format_thread = threading.Thread( + target=self.format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": tenant_id, + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": document_language, + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -568,12 +543,14 @@ def format_qa_document(self, 
flask_app: Flask, tenant_id: str, document_node, al document_qa_list = self.format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.model_copy()) + qa_document = Document( + page_content=result["question"], metadata=document_node.metadata.model_copy() + ) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -581,8 +558,9 @@ def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, al all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> list[Document]: + def _split_to_documents_for_estimate( + self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule + ) -> list[Document]: """ Split the text documents into nodes. """ @@ -602,8 +580,8 @@ def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document.page_content) - document.metadata['doc_id'] = doc_id - document.metadata['doc_hash'] = hash + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash split_documents.append(document) @@ -611,7 +589,8 @@ def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: return all_documents - def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: + @staticmethod + def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str: """ Clean the document text according to the processing rules. 
""" @@ -619,52 +598,35 @@ def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str rules = DatasetProcessRule.AUTOMATIC_RULES else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} + document_text = CleanProcessor.clean(text, {"rules": rules}) - if 'pre_processing_rules' in rules: - pre_processing_rules = rules["pre_processing_rules"] - for pre_processing_rule in pre_processing_rules: - if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: - # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) - elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: - # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) - - # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) - - return text + return document_text - def format_split_text(self, text): + @staticmethod + def format_split_text(text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] - def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, - dataset_document: DatasetDocument, documents: list[Document]) -> None: + def _load( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + dataset_document: DatasetDocument, + documents: list[Document], + ) -> None: """ insert index and update document/segment status to completed """ embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # chunk nodes by chunk size @@ -673,18 +635,27 @@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, chunk_size = 10 # create keyword index - create_keyword_thread = threading.Thread(target=self._process_keyword_index, - args=(current_app._get_current_object(), - dataset.id, dataset_document.id, documents)) + create_keyword_thread = threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + ) create_keyword_thread.start() - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [] for i in range(0, len(documents), chunk_size): - chunk_documents = documents[i:i + chunk_size] - futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor, - chunk_documents, dataset, - dataset_document, embedding_model_instance)) + chunk_documents = documents[i : i + chunk_size] + futures.append( + executor.submit( + self._process_chunk, + current_app._get_current_object(), + index_processor, + chunk_documents, + dataset, + dataset_document, + embedding_model_instance, + ) + ) for future in futures: tokens += future.result() @@ -700,32 +671,38 
@@ def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset, DatasetDocument.tokens: tokens, DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, - } + DatasetDocument.error: None, + }, ) - def _process_keyword_index(self, flask_app, dataset_id, document_id, documents): + @staticmethod + def _process_keyword_index(flask_app, dataset_id, document_id, documents): with flask_app.app_context(): dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: raise ValueError("no dataset found") keyword = Keyword(dataset) keyword.create(documents) - if dataset.indexing_technique != 'high_quality': - document_ids = [document.metadata['doc_id'] for document in documents] + if dataset.indexing_technique != "high_quality": + document_ids = [document.metadata["doc_id"] for document in documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() - def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document, - embedding_model_instance): + def _process_chunk( + self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance + ): with flask_app.app_context(): # check document is paused self._check_document_paused_status(dataset_document.id) @@ -733,51 +710,53 @@ def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, d tokens = 0 if embedding_model_instance: tokens += sum( - embedding_model_instance.get_text_embedding_num_tokens( - [document.page_content] - ) + embedding_model_instance.get_text_embedding_num_tokens([document.page_content]) for document in chunk_documents ) # load index index_processor.load(dataset, chunk_documents, with_keywords=False) - document_ids = [document.metadata['doc_id'] for document in chunk_documents] + document_ids = [document.metadata["doc_id"] for document in chunk_documents] db.session.query(DocumentSegment).filter( DocumentSegment.document_id == dataset_document.id, + DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing" - ).update({ - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - }) + DocumentSegment.status == "indexing", + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + } + ) db.session.commit() return tokens - def _check_document_paused_status(self, document_id: str): - indexing_cache_key = 'document_{}_is_paused'.format(document_id) + @staticmethod + def _check_document_paused_status(document_id: str): + indexing_cache_key = "document_{}_is_paused".format(document_id) result = 
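_process_keyword_index and _process_chunk above both close out segments with a single bulk UPDATE, now additionally scoped by dataset_id. The statement, condensed into a helper (the function name is ours; import paths are assumed to follow the api package layout, and error handling around the call is omitted):

import datetime

from extensions.ext_database import db
from models.dataset import DocumentSegment

def mark_segments_completed(dataset_id: str, document_id: str, node_ids: list[str]) -> None:
    db.session.query(DocumentSegment).filter(
        DocumentSegment.document_id == document_id,
        DocumentSegment.dataset_id == dataset_id,  # the extra guard this diff adds
        DocumentSegment.index_node_id.in_(node_ids),
        DocumentSegment.status == "indexing",
    ).update(
        {
            DocumentSegment.status: "completed",
            DocumentSegment.enabled: True,
            DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
        }
    )
    db.session.commit()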
redis_client.get(indexing_cache_key) if result: - raise DocumentIsPausedException() + raise DocumentIsPausedError() - def _update_document_index_status(self, document_id: str, after_indexing_status: str, - extra_update_params: Optional[dict] = None) -> None: + @staticmethod + def _update_document_index_status( + document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None + ) -> None: """ Update the document indexing status. """ count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() if count > 0: - raise DocumentIsPausedException() + raise DocumentIsPausedError() document = DatasetDocument.query.filter_by(id=document_id).first() if not document: - raise DocumentIsDeletedPausedException() + raise DocumentIsDeletedPausedError() - update_params = { - DatasetDocument.indexing_status: after_indexing_status - } + update_params = {DatasetDocument.indexing_status: after_indexing_status} if extra_update_params: update_params.update(extra_update_params) @@ -785,14 +764,16 @@ def _update_document_index_status(self, document_id: str, after_indexing_status: DatasetDocument.query.filter_by(id=document_id).update(update_params) db.session.commit() - def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None: + @staticmethod + def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: """ Update the document segment by document id. """ DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): + @staticmethod + def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): """ Batch add segments index processing """ @@ -805,7 +786,7 @@ def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index @@ -813,17 +794,23 @@ def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) - def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, - text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]: + def _transform( + self, + index_processor: BaseIndexProcessor, + dataset: Dataset, + text_docs: list[Document], + doc_language: str, + process_rule: dict, + ) -> list[Document]: # get embedding model instance embedding_model_instance = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": if dataset.embedding_model_provider: embedding_model_instance = self.model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) else: embedding_model_instance = self.model_manager.get_default_model_instance( @@ -831,18 +818,20 @@ def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset, model_type=ModelType.TEXT_EMBEDDING, ) - documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance, - process_rule=process_rule, tenant_id=dataset.tenant_id, - doc_language=doc_language) + documents = index_processor.transform( + text_docs, + 
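The pause check above is a plain Redis flag probe keyed by document id. Restated as a standalone helper (assuming the shared client lives at extensions.ext_redis, as elsewhere in the api package):

from extensions.ext_redis import redis_client

class DocumentIsPausedError(Exception):
    pass

def check_document_paused(document_id: str) -> None:
    indexing_cache_key = "document_{}_is_paused".format(document_id)
    if redis_client.get(indexing_cache_key):
        raise DocumentIsPausedError()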
embedding_model_instance=embedding_model_instance, + process_rule=process_rule, + tenant_id=dataset.tenant_id, + doc_language=doc_language, + ) return documents def _load_segments(self, dataset, dataset_document, documents): # save node to document segment doc_store = DatasetDocumentStore( - dataset=dataset, - user_id=dataset_document.created_by, - document_id=dataset_document.id + dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id ) # add document segments @@ -856,7 +845,7 @@ def _load_segments(self, dataset, dataset_document, documents): extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, - } + }, ) # update segment status to indexing @@ -864,15 +853,15 @@ def _load_segments(self, dataset, dataset_document, documents): dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - } + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + }, ) pass -class DocumentIsPausedException(Exception): +class DocumentIsPausedError(Exception): pass -class DocumentIsDeletedPausedException(Exception): +class DocumentIsDeletedPausedError(Exception): pass diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 8c13b4a45cbe6c..9cf9ed75c04101 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -8,6 +8,8 @@ from core.llm_generator.prompts import ( CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT, + JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, + PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager @@ -43,21 +45,18 @@ def generate_conversation_name( with measure_time() as timer: response = model_instance.invoke_llm( - prompt_messages=prompts, - model_parameters={ - "max_tokens": 100, - "temperature": 1 - }, - stream=False + prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False ) answer = response.message.content - cleaned_answer = re.sub(r'^.*(\{.*\}).*$', r'\1', answer, flags=re.DOTALL) + cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) + if cleaned_answer is None: + return "" result_dict = json.loads(cleaned_answer) - answer = result_dict['Your Output'] + answer = result_dict["Your Output"] name = answer.strip() if len(name) > 75: - name = name[:75] + '...' + name = name[:75] + "..." 
# get tracing instance trace_manager = TraceQueueManager(app_id=app_id) @@ -79,14 +78,9 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - template="{{histories}}\n{{format_instructions}}\nquestions:\n" - ) + prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n") - prompt = prompt_template.format({ - "histories": histories, - "format_instructions": format_instructions - }) + prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: model_manager = ModelManager() @@ -101,12 +95,7 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - "max_tokens": 256, - "temperature": 0 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"max_tokens": 256, "temperature": 0}, stream=False ) questions = output_parser.parse(response.message.content) @@ -119,32 +108,24 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st return questions @classmethod - def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512) -> dict: + def generate_rule_config( + cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 + ) -> dict: output_parser = RuleConfigGeneratorOutputParser() error = "" error_step = "" - rule_config = { - "prompt": "", - "variables": [], - "opening_statement": "", - "error": "" - } - model_parameters = { - "max_tokens": rule_config_max_tokens, - "temperature": 0.01 - } + rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} + model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} if no_variable: - prompt_template = PromptTemplateParser( - WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE - ) + prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_generate = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate)] @@ -158,13 +139,11 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di try: response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) rule_config["prompt"] = response.message.content - + except InvokeError as e: error = str(e) error_step = "generate rule config" @@ -179,24 +158,18 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di # get rule config prompt, parameter and statement prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions() - prompt_template = PromptTemplateParser( - prompt_generate - ) + prompt_template = PromptTemplateParser(prompt_generate) - parameter_template = PromptTemplateParser( - parameter_generate - ) + parameter_template = PromptTemplateParser(parameter_generate) - statement_template = PromptTemplateParser( - statement_generate - ) + statement_template = 
PromptTemplateParser(statement_generate) # format the prompt_generate_prompt prompt_generate_prompt = prompt_template.format( inputs={ "TASK_DESCRIPTION": instruction, }, - remove_template_variables=False + remove_template_variables=False, ) prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] @@ -213,9 +186,7 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di try: # the first step to generate the task prompt prompt_content = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False ) except InvokeError as e: error = str(e) @@ -230,7 +201,7 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di inputs={ "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)] @@ -240,15 +211,13 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di "TASK_DESCRIPTION": instruction, "INPUT_TEXT": prompt_content.message.content, }, - remove_template_variables=False + remove_template_variables=False, ) statement_messages = [UserPromptMessage(content=statement_generate_prompt)] try: parameter_content = model_instance.invoke_llm( - prompt_messages=parameter_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False ) rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content) except InvokeError as e: @@ -257,9 +226,7 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di try: statement_content = model_instance.invoke_llm( - prompt_messages=statement_messages, - model_parameters=model_parameters, - stream=False + prompt_messages=statement_messages, model_parameters=model_parameters, stream=False ) rule_config["opening_statement"] = statement_content.message.content except InvokeError as e: @@ -274,6 +241,54 @@ def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: di return rule_config + @classmethod + def generate_code( + cls, + tenant_id: str, + instruction: str, + model_config: dict, + code_language: str = "javascript", + max_tokens: int = 1000, + ) -> dict: + if code_language == "python": + prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) + else: + prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE) + + prompt = prompt_template.format( + inputs={ + "INSTRUCTION": instruction, + "CODE_LANGUAGE": code_language, + }, + remove_template_variables=False, + ) + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider") if model_config else None, + model=model_config.get("name") if model_config else None, + ) + + prompt_messages = [UserPromptMessage(content=prompt)] + model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} + + try: + response = model_instance.invoke_llm( + prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False + ) + + generated_code = response.message.content + return {"code": generated_code, "language": code_language, "error": ""} + + except InvokeError as e: + error = str(e) + return {"code": "", "language": code_language, "error": 
f"Failed to generate code. Error: {error}"} + except Exception as e: + logging.exception(e) + return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} + @classmethod def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt = GENERATOR_QA_PROMPT.format(language=document_language) @@ -284,18 +299,10 @@ def generate_qa_document(cls, tenant_id: str, query, document_language: str): model_type=ModelType.LLM, ) - prompt_messages = [ - SystemPromptMessage(content=prompt), - UserPromptMessage(content=query) - ] + prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)] response = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters={ - 'temperature': 0.01, - "max_tokens": 2000 - }, - stream=False + prompt_messages=prompt_messages, model_parameters={"temperature": 0.01, "max_tokens": 2000}, stream=False ) answer = response.message.content diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py index 6a60f8de803726..1e743f1757473e 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserException(Exception): +class OutputParserError(Exception): pass diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py index 8856f0c6856952..0c7683b16d373e 100644 --- a/api/core/llm_generator/output_parser/rule_config_generator.py +++ b/api/core/llm_generator/output_parser/rule_config_generator.py @@ -1,6 +1,6 @@ from typing import Any -from core.llm_generator.output_parser.errors import OutputParserException +from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import ( RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, @@ -10,9 +10,12 @@ class RuleConfigGeneratorOutputParser: - def get_format_instructions(self) -> tuple[str, str, str]: - return RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE + return ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) def parse(self, text: str) -> Any: try: @@ -21,16 +24,9 @@ def parse(self, text: str) -> Any: if not isinstance(parsed["prompt"], str): raise ValueError("Expected 'prompt' to be a string.") if not isinstance(parsed["variables"], list): - raise ValueError( - "Expected 'variables' to be a list." - ) + raise ValueError("Expected 'variables' to be a list.") if not isinstance(parsed["opening_statement"], str): - raise ValueError( - "Expected 'opening_statement' to be a str." 
- ) + raise ValueError("Expected 'opening_statement' to be a str.") return parsed except Exception as e: - raise OutputParserException( - f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}" - ) - + raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index 3f046c68fceaf0..182aeed98fd7ff 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -6,7 +6,6 @@ class SuggestedQuestionsAfterAnswerOutputParser: - def get_format_instructions(self) -> str: return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT @@ -15,7 +14,7 @@ def parse(self, text: str) -> Any: if action_match is not None: json_obj = json.loads(action_match.group(0).strip()) else: - json_obj= [] + json_obj = [] print(f"Could not parse LLM output: {text}") return json_obj diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index 87361b385ab771..7c0f24705275f3 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -59,26 +59,95 @@ } User Input: -""" +""" # noqa: E501 + +PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE = ( + "You are an expert programmer. Generate code based on the following instructions:\n\n" + "Instructions: {{INSTRUCTION}}\n\n" + "Write the code in {{CODE_LANGUAGE}}.\n\n" + "Please ensure that you meet the following requirements:\n" + "1. Define a function named 'main'.\n" + "2. The 'main' function must return a dictionary (dict).\n" + "3. You may modify the arguments of the 'main' function, but include appropriate type hints.\n" + "4. The returned dictionary should contain at least one key-value pair.\n\n" + "5. You may ONLY use the following libraries in your code: \n" + "- json\n" + "- datetime\n" + "- math\n" + "- random\n" + "- re\n" + "- string\n" + "- sys\n" + "- time\n" + "- traceback\n" + "- uuid\n" + "- os\n" + "- base64\n" + "- hashlib\n" + "- hmac\n" + "- binascii\n" + "- collections\n" + "- functools\n" + "- operator\n" + "- itertools\n\n" + "Example:\n" + "def main(arg1: str, arg2: int) -> dict:\n" + " return {\n" + ' "result": arg1 * arg2,\n' + " }\n\n" + "IMPORTANT:\n" + "- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n" + "- DO NOT use markdown code blocks (``` or ``` python). Return the raw code directly.\n" + "- The code should start immediately after this instruction, without any preceding newlines or spaces.\n" + "- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n" + "- Always use the format return {'result': ...} for the output.\n\n" + "Generated Code:\n" +) +JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = ( + "You are an expert programmer. Generate code based on the following instructions:\n\n" + "Instructions: {{INSTRUCTION}}\n\n" + "Write the code in {{CODE_LANGUAGE}}.\n\n" + "Please ensure that you meet the following requirements:\n" + "1. Define a function named 'main'.\n" + "2. The 'main' function must return an object.\n" + "3. You may modify the arguments of the 'main' function, but include appropriate JSDoc annotations.\n" + "4. The returned object should contain at least one key-value pair.\n\n" + "5. 
The returned object should always be in the format: {result: ...}\n\n" + "Example:\n" + "function main(arg1, arg2) {\n" + " return {\n" + " result: arg1 * arg2\n" + " };\n" + "}\n\n" + "IMPORTANT:\n" + "- Provide ONLY the code without any additional explanations, comments, or markdown formatting.\n" + "- DO NOT use markdown code blocks (``` or ``` javascript). Return the raw code directly.\n" + "- The code should start immediately after this instruction, without any preceding newlines or spaces.\n" + "- The code should be complete, functional, and follow best practices for {{CODE_LANGUAGE}}.\n\n" + "Generated Code:\n" +) + SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "Please help me predict the three most likely questions that human would ask, " "and keeping each question under 20 characters.\n" - "MAKE SURE your output is the SAME language as the Assistant's latest response(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n" + "MAKE SURE your output is the SAME language as the Assistant's latest response" + "(if the main response is written in Chinese, then the language of your output must be using Chinese.)!\n" "The output must be an array in JSON format following the specified schema:\n" - "[\"question1\",\"question2\",\"question3\"]\n" + '["question1","question2","question3"]\n' ) GENERATOR_QA_PROMPT = ( - ' The user will send a long text. Generate a Question and Answer pairs only using the knowledge in the long text. Please think step by step.' - 'Step 1: Understand and summarize the main content of this text.\n' - 'Step 2: What key information or concepts are mentioned in this text?\n' - 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' - 'Step 4: Generate questions and answers based on these key information and concepts.\n' - ' The questions should be clear and detailed, and the answers should be detailed and complete. ' - 'You must answer in {language}, in a style that is clear and detailed in {language}. No language other than {language} should be used. \n' - ' Use the following format: Q1:\nA1:\nQ2:\nA2:...\n' - '' + " The user will send a long text. Generate a Question and Answer pairs only using the knowledge" + " in the long text. Please think step by step." + "Step 1: Understand and summarize the main content of this text.\n" + "Step 2: What key information or concepts are mentioned in this text?\n" + "Step 3: Decompose or combine multiple pieces of information and concepts.\n" + "Step 4: Generate questions and answers based on these key information and concepts.\n" + " The questions should be clear and detailed, and the answers should be detailed and complete. " + "You must answer in {language}, in a style that is clear and detailed in {language}." + " No language other than {language} should be used. \n" + " Use the following format: Q1:\nA1:\nQ2:\nA2:...\n" + "" ) WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ @@ -87,14 +156,14 @@ {{TASK_DESCRIPTION}} Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include: -- Do not inlcude or section and variables in the prompt, assume user will add them at their own will. +- Do not include or section and variables in the prompt, assume user will add them at their own will. - Clear instructions for the AI that will be using this prompt, demarcated with tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. 
Also Specifies in the instructions that the output should not contain any xml tag. - Relevant examples if needed to clarify the task further, demarcated with tags. Do not include variables in the prompt. Give three pairs of input and output examples. - Include other relevant sections demarcated with appropriate XML tags like , . - Use the same language as task description. - Output in ``` xml ``` and start with Please generate the full prompt template with at least 300 words and output only the prompt template. -""" +""" # noqa: E501 RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """ Here is a task description for which I would like you to create a high-quality prompt template for: @@ -109,7 +178,7 @@ - Use the same language as task description. - Output in ``` xml ``` and start with Please generate the full prompt template and output only the prompt template. -""" +""" # noqa: E501 RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """ I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. @@ -134,7 +203,7 @@ ### Answer I should always output a valid list. Output nothing other than the list of variable_name. Output an empty list if there is no variable name in input text. -""" +""" # noqa: E501 RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """ @@ -150,4 +219,4 @@ Here is the task description: {{INPUT_TEXT}} You just need to generate the output -""" +""" # noqa: E501 diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index b33d4dd7cb342c..688fb4776a86e1 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,17 +1,21 @@ from typing import Optional from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.file.message_file_parser import MessageFileParser +from core.file import file_manager +from core.file.models import FileType from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, + PromptMessageContent, PromptMessageRole, TextPromptMessageContent, UserPromptMessage, ) +from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db +from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile from models.workflow import WorkflowRun @@ -21,8 +25,9 @@ def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> self.conversation = conversation self.model_instance = model_instance - def get_history_prompt_messages(self, max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> list[PromptMessage]: + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> list[PromptMessage]: """ Get history prompt messages. 
:param max_token_limit: max token limit @@ -31,63 +36,81 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, app_record = self.conversation.app # fetch limited messages, and return reversed - query = db.session.query( - Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id - ).filter( - Message.conversation_id == self.conversation.id, - Message.answer != '' - ).order_by(Message.created_at.desc()) + query = ( + db.session.query( + Message.id, + Message.query, + Message.answer, + Message.created_at, + Message.workflow_run_id, + Message.parent_message_id, + ) + .filter( + Message.conversation_id == self.conversation.id, + ) + .order_by(Message.created_at.desc()) + ) if message_limit and message_limit > 0: - message_limit = message_limit if message_limit <= 500 else 500 + message_limit = min(message_limit, 500) else: message_limit = 500 messages = query.limit(message_limit).all() - messages = list(reversed(messages)) - message_file_parser = MessageFileParser( - tenant_id=app_record.tenant_id, - app_id=app_record.id - ) + # instead of all messages from the conversation, we only need to extract messages + # that belong to the thread of last message + thread_messages = extract_thread_messages(messages) + + # for newly created message, its answer is temporarily empty, we don't need to add it to memory + if thread_messages and not thread_messages[0].answer: + thread_messages.pop(0) + + messages = list(reversed(thread_messages)) + prompt_messages = [] for message in messages: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: - workflow_run = (db.session.query(WorkflowRun) - .filter(WorkflowRun.id == message.workflow_run_id).first()) + workflow_run = ( + db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() + ) - if workflow_run: + if workflow_run and workflow_run.workflow: file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, - is_vision=False + workflow_run.workflow.features_dict, is_vision=False ) - if file_extra_config: - file_objs = message_file_parser.transform_message_files( - files, - file_extra_config + detail = ImagePromptMessageContent.DETAIL.LOW + if file_extra_config and app_record: + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config ) + if file_extra_config.image_config and file_extra_config.image_config.detail: + detail = file_extra_config.image_config.detail else: file_objs = [] if not file_objs: prompt_messages.append(UserPromptMessage(content=message.query)) else: - prompt_message_contents = [TextPromptMessageContent(data=message.query)] - for file_obj in file_objs: - prompt_message_contents.append(file_obj.prompt_message_content) + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + for file in file_objs: + if file.type in {FileType.IMAGE, FileType.AUDIO}: + prompt_message = file_manager.to_prompt_message_content( + file, + image_detail_config=detail, + ) + prompt_message_contents.append(prompt_message) 
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: prompt_messages.append(UserPromptMessage(content=message.query)) @@ -97,24 +120,23 @@ def get_history_prompt_messages(self, max_token_limit: int = 2000, return [] # prune the chat message if it exceeds the max token limit - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) if curr_message_tokens > max_token_limit: pruned_memory = [] - while curr_message_tokens > max_token_limit and len(prompt_messages)>1: + while curr_message_tokens > max_token_limit and len(prompt_messages) > 1: pruned_memory.append(prompt_messages.pop(0)) - curr_message_tokens = self.model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages) return prompt_messages - def get_history_prompt_text(self, human_prefix: str = "Human", - ai_prefix: str = "Assistant", - max_token_limit: int = 2000, - message_limit: Optional[int] = None) -> str: + def get_history_prompt_text( + self, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", + max_token_limit: int = 2000, + message_limit: Optional[int] = None, + ) -> str: """ Get history prompt text. :param human_prefix: human prefix @@ -123,10 +145,7 @@ def get_history_prompt_text(self, human_prefix: str = "Human", :param message_limit: message limit :return: """ - prompt_messages = self.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit - ) + prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit) string_messages = [] for m in prompt_messages: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 7c23e14297d3cf..059ba6c3d1f26e 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,8 +1,9 @@ import logging -import os -from collections.abc import Callable, Generator -from typing import IO, Optional, Union, cast +from collections.abc import Callable, Generator, Iterable, Sequence +from typing import IO, Any, Optional, Union, cast +from configs import dify_config +from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError @@ -41,10 +42,11 @@ def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> No configuration=provider_model_bundle.configuration, model_type=provider_model_bundle.model_type_instance.model_type, model=model, - credentials=self.credentials + credentials=self.credentials, ) - def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: + @staticmethod + def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: """ Fetch credentials from provider model bundle :param provider_model_bundle: provider model bundle @@ -53,20 +55,17 @@ def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBun """ configuration = provider_model_bundle.configuration model_type = provider_model_bundle.model_type_instance.model_type - credentials = configuration.get_current_credentials( - model_type=model_type, - model=model - ) + credentials = configuration.get_current_credentials(model_type=model_type, model=model) if credentials 
is None: raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.") return credentials - def _get_load_balancing_manager(self, configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict) -> Optional["LBModelManager"]: + @staticmethod + def _get_load_balancing_manager( + configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict + ) -> Optional["LBModelManager"]: """ Get load balancing model credentials :param configuration: provider configuration @@ -79,8 +78,7 @@ def _get_load_balancing_manager(self, configuration: ProviderConfiguration, current_model_setting = None # check if model is disabled by admin for model_setting in configuration.model_settings: - if (model_setting.model_type == model_type - and model_setting.model == model): + if model_setting.model_type == model_type and model_setting.model == model: current_model_setting = model_setting break @@ -93,17 +91,23 @@ def _get_load_balancing_manager(self, configuration: ProviderConfiguration, model_type=model_type, model=model, load_balancing_configs=current_model_setting.load_balancing_configs, - managed_credentials=credentials if configuration.custom_configuration.provider else None + managed_credentials=credentials if configuration.custom_configuration.provider else None, ) return lb_model_manager return None - def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke_llm( + self, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Sequence[PromptMessageTool] | None = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -130,11 +134,12 @@ def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Opt stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) - def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_llm_num_tokens( + self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Get number of tokens for llm @@ -151,16 +156,18 @@ def get_llm_num_tokens(self, prompt_messages: list[PromptMessage], model=self.model, credentials=self.credentials, prompt_messages=prompt_messages, - tools=tools + tools=tools, ) - def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke_text_embedding( + self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + ) -> TextEmbeddingResult: """ Invoke large language model :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ if not isinstance(self.model_type_instance, TextEmbeddingModel): @@ -172,7 +179,8 @@ def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ model=self.model, credentials=self.credentials, texts=texts, - user=user + user=user, + input_type=input_type, ) def get_text_embedding_num_tokens(self, texts: list[str]) -> int: @@ -190,13 +198,17 @@ def 
get_text_embedding_num_tokens(self, texts: list[str]) -> int: function=self.model_type_instance.get_num_tokens, model=self.model, credentials=self.credentials, - texts=texts + texts=texts, ) - def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke_rerank( + self, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -219,11 +231,10 @@ def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[f docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user + user=user, ) - def invoke_moderation(self, text: str, user: Optional[str] = None) \ - -> bool: + def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -240,11 +251,10 @@ def invoke_moderation(self, text: str, user: Optional[str] = None) \ model=self.model, credentials=self.credentials, text=text, - user=user + user=user, ) - def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -261,11 +271,10 @@ def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ model=self.model, credentials=self.credentials, file=file, - user=user + user=user, ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \ - -> str: + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]: """ Invoke large language tts model @@ -286,10 +295,10 @@ def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Option content_text=content_text, user=user, tenant_id=tenant_id, - voice=voice + voice=voice, ) - def _round_robin_invoke(self, function: Callable, *args, **kwargs): + def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs): """ Round-robin invoke :param function: function to invoke @@ -310,8 +319,8 @@ def _round_robin_invoke(self, function: Callable, *args, **kwargs): raise last_exception try: - if 'credentials' in kwargs: - del kwargs['credentials'] + if "credentials" in kwargs: + del kwargs["credentials"] return function(*args, **kwargs, credentials=lb_config.credentials) except InvokeRateLimitError as e: # expire in 60 seconds @@ -338,9 +347,7 @@ def get_tts_voices(self, language: Optional[str] = None) -> list: self.model_type_instance = cast(TTSModel, self.model_type_instance) return self.model_type_instance.get_tts_model_voices( - model=self.model, - credentials=self.credentials, - language=language + model=self.model, credentials=self.credentials, language=language ) @@ -361,13 +368,20 @@ def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelTyp return self.get_default_model_instance(tenant_id, model_type) provider_model_bundle = self._provider_manager.get_provider_model_bundle( - tenant_id=tenant_id, - provider=provider, - model_type=model_type + tenant_id=tenant_id, provider=provider, model_type=model_type ) return ModelInstance(provider_model_bundle, model) + def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Return first provider and the first model in the provider + :param tenant_id: tenant id + :param model_type: model type + 
:return: provider name, model name + """ + return self._provider_manager.get_first_provider_first_model(tenant_id, model_type) + def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: """ Get default model instance @@ -375,10 +389,7 @@ def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> M :param model_type: model type :return: """ - default_model_entity = self._provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type - ) + default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type) if not default_model_entity: raise ProviderTokenNotInitError(f"Default model not found for {model_type}") @@ -387,17 +398,20 @@ def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> M tenant_id=tenant_id, provider=default_model_entity.provider.provider, model_type=model_type, - model=default_model_entity.model + model=default_model_entity.model, ) class LBModelManager: - def __init__(self, tenant_id: str, - provider: str, - model_type: ModelType, - model: str, - load_balancing_configs: list[ModelLoadBalancingConfiguration], - managed_credentials: Optional[dict] = None) -> None: + def __init__( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + load_balancing_configs: list[ModelLoadBalancingConfiguration], + managed_credentials: Optional[dict] = None, + ) -> None: """ Load balancing model manager :param tenant_id: tenant_id @@ -428,10 +442,7 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: :return: """ cache_key = "model_lb_index:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model + self._tenant_id, self._provider, self._model_type.value, self._model ) cooldown_load_balancing_configs = [] @@ -462,10 +473,12 @@ def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]: continue - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): - logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n" - f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" - f"model_type: {self._model_type.value}\nmodel: {self._model}") + if dify_config.DEBUG: + logger.info( + f"Model LB\nid: {config.id}\nname:{config.name}\n" + f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n" + f"model_type: {self._model_type.value}\nmodel: {self._model}" + ) return config @@ -479,14 +492,10 @@ def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model, - config.id + self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - redis_client.setex(cooldown_cache_key, expire, 'true') + redis_client.setex(cooldown_cache_key, expire, "true") def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: """ @@ -495,24 +504,17 @@ def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool: :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - self._tenant_id, - self._provider, - self._model_type.value, - self._model, - config.id + self._tenant_id, self._provider, self._model_type.value, self._model, config.id ) - res = redis_client.exists(cooldown_cache_key) res = cast(bool, res) return res - @classmethod - def get_config_in_cooldown_and_ttl(cls, tenant_id: str, - provider: str, - model_type: ModelType, - 
model: str, - config_id: str) -> tuple[bool, int]: + @staticmethod + def get_config_in_cooldown_and_ttl( + tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str + ) -> tuple[bool, int]: """ Get model load balancing config is in cooldown and ttl :param tenant_id: workspace id @@ -523,11 +525,7 @@ def get_config_in_cooldown_and_ttl(cls, tenant_id: str, :return: """ cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format( - tenant_id, - provider, - model_type.value, - model, - config_id + tenant_id, provider, model_type.value, model, config_id ) ttl = redis_client.ttl(cooldown_cache_key) diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index bba004a32a21d6..6bd9325785a2da 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -13,17 +14,27 @@ } -class Callback: +class Callback(ABC): """ Base class for callbacks. Only for LLM. """ + raise_error: bool = False - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + @abstractmethod + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -39,10 +50,20 @@ def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, """ raise NotImplementedError() - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + @abstractmethod + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -59,10 +80,20 @@ def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, """ raise NotImplementedError() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + @abstractmethod + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -79,10 +110,20 @@ def on_after_invoke(self, llm_instance: AIModel, result: 
LLMResult, model: str, """ raise NotImplementedError() - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + @abstractmethod + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -99,9 +140,7 @@ def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, cred """ raise NotImplementedError() - def print_text( - self, text: str, color: Optional[str] = None, end: str = "" - ) -> None: + def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: """Print text with highlighting and no end characters.""" text_to_print = self._get_colored_text(text, color) if color else text print(text_to_print, end=end) diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 0406853b88b9c9..3b6b825244dfdc 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -10,11 +10,20 @@ logger = logging.getLogger(__name__) + class LoggingCallback(Callback): - def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_before_invoke( + self, + llm_instance: AIModel, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Before invoke callback @@ -28,40 +37,49 @@ def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_before_invoke]\n", color='blue') - self.print_text(f"Model: {model}\n", color='blue') - self.print_text("Parameters:\n", color='blue') + self.print_text("\n[on_llm_before_invoke]\n", color="blue") + self.print_text(f"Model: {model}\n", color="blue") + self.print_text("Parameters:\n", color="blue") for key, value in model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color='blue') + self.print_text(f"\t{key}: {value}\n", color="blue") if stop: - self.print_text(f"\tstop: {stop}\n", color='blue') + self.print_text(f"\tstop: {stop}\n", color="blue") if tools: - self.print_text("\tTools:\n", color='blue') + self.print_text("\tTools:\n", color="blue") for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color='blue') + self.print_text(f"\t\t{tool.name}\n", color="blue") - self.print_text(f"Stream: {stream}\n", color='blue') + self.print_text(f"Stream: {stream}\n", color="blue") if user: - self.print_text(f"User: {user}\n", color='blue') + self.print_text(f"User: {user}\n", color="blue") - self.print_text("Prompt messages:\n", color='blue') + self.print_text("Prompt messages:\n", color="blue") for prompt_message in 
prompt_messages: if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color='blue') + self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - self.print_text(f"\trole: {prompt_message.role.value}\n", color='blue') - self.print_text(f"\tcontent: {prompt_message.content}\n", color='blue') + self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") + self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") if stream: self.print_text("\n[on_llm_new_chunk]") - def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None): + def on_new_chunk( + self, + llm_instance: AIModel, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ): """ On new chunk callback @@ -79,10 +97,19 @@ def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, sys.stdout.write(chunk.delta.message.content) sys.stdout.flush() - def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + def on_after_invoke( + self, + llm_instance: AIModel, + result: LLMResult, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ After invoke callback @@ -97,24 +124,33 @@ def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_after_invoke]\n", color='yellow') - self.print_text(f"Content: {result.message.content}\n", color='yellow') + self.print_text("\n[on_llm_after_invoke]\n", color="yellow") + self.print_text(f"Content: {result.message.content}\n", color="yellow") if result.message.tool_calls: - self.print_text("Tool calls:\n", color='yellow') + self.print_text("Tool calls:\n", color="yellow") for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color='yellow') - self.print_text(f"\t{tool_call.function.name}\n", color='yellow') - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color='yellow') - - self.print_text(f"Model: {result.model}\n", color='yellow') - self.print_text(f"Usage: {result.usage}\n", color='yellow') - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color='yellow') - - def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> None: + self.print_text(f"\t{tool_call.id}\n", color="yellow") + self.print_text(f"\t{tool_call.function.name}\n", color="yellow") + self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", 
color="yellow") + + self.print_text(f"Model: {result.model}\n", color="yellow") + self.print_text(f"Usage: {result.usage}\n", color="yellow") + self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") + + def on_invoke_error( + self, + llm_instance: AIModel, + ex: Exception, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> None: """ Invoke error callback @@ -129,5 +165,5 @@ def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, cred :param stream: is stream response :param user: unique user id """ - self.print_text("\n[on_llm_invoke_error]\n", color='red') + self.print_text("\n[on_llm_invoke_error]\n", color="red") logger.exception(ex) diff --git a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md new file mode 100644 index 00000000000000..f050919d81b767 --- /dev/null +++ b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md @@ -0,0 +1,310 @@ +## Custom Integration of Pre-defined Models + +### Introduction + +After completing the vendors integration, the next step is to connect the vendor's models. To illustrate the entire connection process, we will use Xinference as an example to demonstrate a complete vendor integration. + +It is important to note that for custom models, each model connection requires a complete vendor credential. + +Unlike pre-defined models, a custom vendor integration always includes the following two parameters, which do not need to be defined in the vendor YAML file. + +![](images/index/image-3.png) + +As mentioned earlier, vendors do not need to implement validate_provider_credential. The runtime will automatically call the corresponding model layer's validate_credentials to validate the credentials based on the model type and name selected by the user. + +### Writing the Vendor YAML + +First, we need to identify the types of models supported by the vendor we are integrating. + +Currently supported model types are as follows: + +- `llm` Text Generation Models + +- `text_embedding` Text Embedding Models + +- `rerank` Rerank Models + +- `speech2text` Speech-to-Text + +- `tts` Text-to-Speech + +- `moderation` Moderation + +Xinference supports LLM, Text Embedding, and Rerank. So we will start by writing xinference.yaml. + +```yaml +provider: xinference #Define the vendor identifier +label: # Vendor display name, supports both en_US (English) and zh_Hans (Simplified Chinese). If zh_Hans is not set, it will use en_US by default. + en_US: Xorbits Inference +icon_small: # Small icon, refer to other vendors' icons stored in the _assets directory within the vendor implementation directory; follows the same language policy as the label + en_US: icon_s_en.svg +icon_large: # Large icon + en_US: icon_l_en.svg +help: # Help information + title: + en_US: How to deploy Xinference + zh_Hans: 如何部署 Xinference + url: + en_US: https://github.com/xorbitsai/inference +supported_model_types: # Supported model types. Xinference supports LLM, Text Embedding, and Rerank +- llm +- text-embedding +- rerank +configurate_methods: # Since Xinference is a locally deployed vendor with no predefined models, users need to deploy whatever models they need according to Xinference documentation. Thus, it only supports custom models. 
+- customizable-model +provider_credential_schema: + credential_form_schemas: +``` + + +Then, we need to determine what credentials are required to define a model in Xinference. + +- Since Xinference supports three different types of models, we need a `model_type` credential to denote the model type. Here is how we can define it: + +```yaml +provider_credential_schema: + credential_form_schemas: + - variable: model_type + type: select + label: + en_US: Model type + zh_Hans: 模型类型 + required: true + options: + - value: text-generation + label: + en_US: Language Model + zh_Hans: 语言模型 + - value: embeddings + label: + en_US: Text Embedding + - value: reranking + label: + en_US: Rerank +``` + +- Next, each model has its own model_name, so we need to define that here: + +```yaml + - variable: model_name + type: text-input + label: + en_US: Model name + zh_Hans: 模型名称 + required: true + placeholder: + zh_Hans: 填写模型名称 + en_US: Input model name +``` + +- Specify the Xinference local deployment address: + +```yaml + - variable: server_url + label: + zh_Hans: 服务器URL + en_US: Server url + type: text-input + required: true + placeholder: + zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx + en_US: Enter the url of your Xinference, for example https://example.com/xxx +``` + +- Each model has a unique model_uid, so we also need to define that here: + +```yaml + - variable: model_uid + label: + zh_Hans: 模型UID + en_US: Model uid + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的Model UID + en_US: Enter the model uid +``` + +Now, we have completed the basic definition of the vendor. + +### Writing the Model Code + +Next, let's take the `llm` type as an example and write `xinference.llm.llm.py`. + +In `llm.py`, create a Xinference LLM class. We name it `XinferenceAILargeLanguageModel` (the name is arbitrary); it inherits from the `__base.large_language_model.LargeLanguageModel` base class and implements the following methods: + +- LLM Invocation + +Implement the core method for LLM invocation, supporting both stream and synchronous responses. + +```python +def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool usage + :param stop: stop words + :param stream: is the response a stream + :param user: unique user id + :return: full response or stream response chunk generator result + """ +``` + +When implementing, make sure to use two separate functions to return the data for synchronous and streaming responses. This is important because Python treats functions containing the `yield` keyword as generator functions, mandating them to return `Generator` types. 
Here's an example (note that it uses simplified parameters; a real implementation must use the full parameter list defined above):

```python
def _invoke(self, stream: bool, **kwargs) \
        -> Union[LLMResult, Generator]:
    if stream:
        return self._handle_stream_response(**kwargs)
    return self._handle_sync_response(**kwargs)

def _handle_stream_response(self, **kwargs) -> Generator:
    # `response` stands in for the vendor client's return value in this sketch
    for chunk in response:
        yield chunk

def _handle_sync_response(self, **kwargs) -> LLMResult:
    return LLMResult(**response)
```

- Pre-compute Input Tokens

If the model does not provide an interface for pre-computing tokens, you can return 0 directly.

```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Get number of tokens for given prompt messages

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param tools: tools for tool usage
    :return: token count
    """
```

Sometimes you may not want to return 0 directly. In that case, you can use `self._get_num_tokens_by_gpt2(text: str)` to obtain an estimated token count. The method is provided by the `AIModel` base class and uses the GPT-2 tokenizer for the calculation. Note that it is only a substitute and may not be fully accurate.
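For example, a minimal sketch of that fallback might look like this (it assumes every message carries plain-text `content` and ignores `tools` in the estimate):

```python
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   tools: Optional[list[PromptMessageTool]] = None) -> int:
    # Rough estimate via the GPT-2 tokenizer shipped with the AIModel base
    # class; assumes plain-text contents and does not account for tools.
    text = " ".join(str(message.content) for message in prompt_messages)
    return self._get_num_tokens_by_gpt2(text)
```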
- Model Credentials Validation

Similar to vendor credential validation, this method validates the credentials of an individual model.

```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials

    :param model: model name
    :param credentials: model credentials
    :return: None
    """
```

- Model Parameter Schema

Unlike predefined models, no YAML file declares which parameters a model supports, so we need to generate the model parameter schema dynamically.

For instance, Xinference supports the `max_tokens`, `temperature`, and `top_p` parameters.

However, some vendors support different parameters for different models. For example, the vendor `OpenLLM` supports `top_k`, but not every model it serves does. Let's say model A supports `top_k` while model B does not. In such cases, we need to generate the model parameter schema dynamically, as illustrated below:

```python
    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Used to define the customizable model schema
        """
        rules = [
            ParameterRule(
                name='temperature', type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度', en_US='Temperature'
                )
            ),
            ParameterRule(
                name='top_p', type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P', en_US='Top P'
                )
            ),
            ParameterRule(
                name='max_tokens', type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度', en_US='Max Tokens'
                )
            )
        ]

        # if model is A, add top_k to rules
        if model == 'A':
            rules.append(
                ParameterRule(
                    name='top_k', type=ParameterType.INT,
                    use_template='top_k',
                    min=1,
                    default=50,
                    label=I18nObject(
                        zh_Hans='Top K', en_US='Top K'
                    )
                )
            )

        """
        some NOT IMPORTANT code here
        """

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            model_properties={
                ModelPropertyKey.MODE: LLMMode.CHAT.value,
            },
            parameter_rules=rules
        )

        return entity
```

- Exception Error Mapping

When a model invocation error occurs, it should be mapped to the runtime's unified `InvokeError` types so that Dify can handle different errors appropriately.

Runtime Errors:

- `InvokeConnectionError` Connection error during invocation
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failure
- `InvokeBadRequestError` Invalid request parameters

```python
    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
```
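As a concrete illustration, here is a minimal sketch that assumes the model is reached over HTTP via the `requests` library; substitute the exception types of whatever client or SDK your implementation actually uses:

```python
import requests

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    # Sketch only: assumes `requests` is the transport; replace the values
    # with the exceptions your own client code can actually raise.
    return {
        InvokeConnectionError: [requests.exceptions.ConnectionError, requests.exceptions.Timeout],
        InvokeServerUnavailableError: [requests.exceptions.HTTPError],
        InvokeBadRequestError: [ValueError, KeyError],
    }
```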
For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
\ No newline at end of file diff --git a/api/core/model_runtime/docs/en_US/images/index/image-1.png b/api/core/model_runtime/docs/en_US/images/index/image-1.png new file mode 100644 index 00000000000000..b158d44b29dcc2 Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-1.png differ diff --git a/api/core/model_runtime/docs/en_US/images/index/image-2.png b/api/core/model_runtime/docs/en_US/images/index/image-2.png new file mode 100644 index 00000000000000..c70cd3da5eea19 Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-2.png differ diff --git a/api/core/model_runtime/docs/en_US/images/index/image-3.png b/api/core/model_runtime/docs/en_US/images/index/image-3.png new file mode 100644 index 00000000000000..bf0b9a7f47fddf Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image-3.png differ diff --git a/api/core/model_runtime/docs/en_US/images/index/image.png b/api/core/model_runtime/docs/en_US/images/index/image.png new file mode 100644 index 00000000000000..eb63d107e1c385 Binary files /dev/null and b/api/core/model_runtime/docs/en_US/images/index/image.png differ diff --git a/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md new file mode 100644 index 00000000000000..3e16257452c7a0 --- /dev/null +++ b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md @@ -0,0 +1,173 @@

## Predefined Model Integration

After completing the vendor integration, the next step is to integrate the vendor's models.

First, determine the type of the model to be integrated and create the corresponding model type `module` under the vendor's directory.

Currently supported model types:

- `llm` Text Generation Model
- `text_embedding` Text Embedding Model
- `rerank` Rerank Model
- `speech2text` Speech-to-Text
- `tts` Text-to-Speech
- `moderation` Moderation

Continuing with `Anthropic` as an example: `Anthropic` only supports LLM, so create a `module` named `llm` under `model_providers.anthropic`.

For predefined models, first create a YAML file named after the model under the `llm` `module`, for example `claude-2.1.yaml`.

### Prepare Model YAML

```yaml
model: claude-2.1 # Model identifier
# Display name of the model; en_US (English) and zh_Hans (Simplified Chinese) are supported. If zh_Hans is not set, en_US is used by default.
# It can also be omitted entirely, in which case the model identifier is used as the label
label:
  en_US: claude-2.1
model_type: llm # Model type; claude-2.1 is an LLM
features: # Supported features: agent-thought enables Agent reasoning, vision enables image understanding
- agent-thought
model_properties: # Model properties
  mode: chat # LLM mode: complete for text completion models, chat for conversation models
  context_size: 200000 # Maximum context size
parameter_rules: # Parameter rules for the model call; only LLM requires this
- name: temperature # Parameter variable name
  # Five default configuration templates are provided: temperature/top_p/max_tokens/presence_penalty/frequency_penalty
  # The template variable name can be set directly in use_template, which will use the default configuration in entities.defaults.PARAMETER_RULE_TEMPLATE
  # Any additional configuration set here overrides the default configuration
  use_template: temperature
- name: top_p
  use_template: top_p
- name: top_k
  label: # Display name of the parameter
    zh_Hans: 取样数量
    en_US: Top k
  type: int # Parameter type, supports float/int/string/boolean
  help: # Help information describing the parameter's purpose
    zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
    en_US: Only sample from the top K options for each subsequent token.
  required: false # Whether the parameter is mandatory; can be omitted
- name: max_tokens_to_sample
  use_template: max_tokens
  default: 4096 # Default value of the parameter
  min: 1 # Minimum value of the parameter, applicable to float/int only
  max: 4096 # Maximum value of the parameter, applicable to float/int only
pricing: # Pricing information
  input: '8.00' # Input unit price, i.e., prompt price
  output: '24.00' # Output unit price, i.e., response content price
  unit: '0.000001' # Price unit; with a unit of 0.000001, the prices above are per 1M tokens
  currency: USD # Price currency
```
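To make the `unit` field concrete: the runtime computes a charge as tokens multiplied by the unit price and the unit (see the `get_price` logic elsewhere in this changeset), so with `unit: '0.000001'` the listed prices are effectively per one million tokens. A quick sketch of the arithmetic:

```python
from decimal import Decimal

# Charge = tokens * unit_price * unit; values taken from the YAML above.
tokens = 1000
unit = Decimal("0.000001")
prompt_charge = tokens * Decimal("8.00") * unit  # input (prompt) price
print(prompt_charge)  # prints 0.00800000, i.e. USD 0.008 for 1,000 prompt tokens
```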
It is recommended to prepare all model configurations before starting to implement the model code.

You can also refer to the YAML configuration information under the corresponding model type directories of other vendors in the `model_providers` directory. For the complete YAML rules, refer to: [Schema](schema.md#aimodelentity).

### Implement the Model Call Code

Next, create a Python file named `llm.py` under the `llm` `module` for the implementation code.

Create an Anthropic LLM class named `AnthropicLargeLanguageModel` (the name is arbitrary), inheriting from the `__base.large_language_model.LargeLanguageModel` base class, and implement the following methods:

- LLM Call

Implement the core method for calling the LLM, supporting both streaming and synchronous responses.

```python
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
```

Be sure to use two separate functions for returning data, one for synchronous returns and one for streaming returns. Python treats any function containing the `yield` keyword as a generator function and fixes its return type to `Generator`, so the two paths must be implemented separately, as shown below (the example uses simplified parameters; a real implementation uses the full parameter list above):

```python
    def _invoke(self, stream: bool, **kwargs) \
            -> Union[LLMResult, Generator]:
        if stream:
            return self._handle_stream_response(**kwargs)
        return self._handle_sync_response(**kwargs)

    def _handle_stream_response(self, **kwargs) -> Generator:
        # `response` stands in for the vendor client's return value in this sketch
        for chunk in response:
            yield chunk

    def _handle_sync_response(self, **kwargs) -> LLMResult:
        return LLMResult(**response)
```

- Pre-compute Input Tokens

If the model does not provide an interface to pre-compute tokens, return 0 directly.

```python
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: token count
        """
```

- Validate Model Credentials

Similar to vendor credential validation, but specific to a single model.

```python
    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return: None
        """
```
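One common pattern, shown here as a sketch rather than the required approach, is to probe the credentials with a minimal non-streaming invocation and convert any failure into a `CredentialsValidateFailedError` (the `max_tokens_to_sample` parameter name follows the YAML above):

```python
    def validate_credentials(self, model: str, credentials: dict) -> None:
        # Sketch: issue a tiny non-streaming call; any failure is surfaced
        # as CredentialsValidateFailedError for the caller to handle.
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                prompt_messages=[UserPromptMessage(content="ping")],
                model_parameters={"max_tokens_to_sample": 5},
                stream=False,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
```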
- Map Invoke Errors

When a model call fails, map it to the runtime's unified `InvokeError` types so that Dify can handle different errors accordingly.

Runtime Errors:

- `InvokeConnectionError` Connection error
- `InvokeServerUnavailableError` Service provider unavailable
- `InvokeRateLimitError` Rate limit reached
- `InvokeAuthorizationError` Authorization failed
- `InvokeBadRequestError` Invalid request parameters

```python
    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
```

For interface method explanations, see: [Interfaces](./interfaces.md). For a detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
\ No newline at end of file diff --git a/api/core/model_runtime/docs/en_US/provider_scale_out.md b/api/core/model_runtime/docs/en_US/provider_scale_out.md index ba356c5cab63d0..07be5811d30137 100644 --- a/api/core/model_runtime/docs/en_US/provider_scale_out.md +++ b/api/core/model_runtime/docs/en_US/provider_scale_out.md @@ -58,7 +58,7 @@ provider_credential_schema: # Provider credential rules, as Anthropic only supp en_US: Enter your API URL ``` -You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#Provider). +You can also refer to the YAML configuration information under other provider directories in `model_providers`. The complete YAML rules are available at: [Schema](schema.md#provider). ### Implementing Provider Code diff --git a/api/core/model_runtime/docs/en_US/schema.md b/api/core/model_runtime/docs/en_US/schema.md index 67f4e0879dc01c..f819a4dbdcf0ad 100644 --- a/api/core/model_runtime/docs/en_US/schema.md +++ b/api/core/model_runtime/docs/en_US/schema.md @@ -52,7 +52,7 @@ - `mode` (string) voice model.(available for model type `tts`) - `name` (string) voice model display name.(available for model type `tts`) - `language` (string) the voice model supports languages.(available for model type `tts`) - - `word_limit` (int) Single conversion word limit, paragraphwise by default(available for model type `tts`) + - `word_limit` (int) Single conversion word limit, paragraph-wise by default(available for model type `tts`) - `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`) - `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`) - `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`) @@ -150,7 +150,7 @@ - `input` (float) Input price, i.e., Prompt price - `output` (float) Output price, i.e., returned content price -- `unit` (float) Pricing unit, e.g., if the price is meausred in 1M tokens, the corresponding token amount for the unit price is `0.000001`. +- `unit` (float) Pricing unit, e.g., if the price is measured in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
- `currency` (string) Currency unit ### ProviderCredentialSchema diff --git a/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md b/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md index 7b3a8edba3703d..240f65802b8cb2 100644 --- a/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md @@ -205,7 +205,7 @@ provider_credential_schema: 但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例A模型支持`top_k`,B模型不支持`top_k`,那么我们需要在这里动态生成模型参数的Schema,如下所示: ```python - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ used to define customizable model schema """ diff --git a/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md b/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md index 56f379a92fdd32..17fc088a63a92e 100644 --- a/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md @@ -62,7 +62,7 @@ pricing: # 价格信息 建议将所有模型配置都准备完毕后再开始模型代码的实现。 -同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#AIModel)。 +同样,也可以参考 `model_providers` 目录下其他供应商对应模型类型目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#aimodelentity)。 ### 实现模型调用代码 diff --git a/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md b/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md index b34544c789fa76..78aad8876f4b84 100644 --- a/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/provider_scale_out.md @@ -117,7 +117,7 @@ model_credential_schema: en_US: Enter your API Base ``` -也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#Provider)。 +也可以参考 `model_providers` 目录下其他供应商目录下的 YAML 配置信息,完整的 YAML 规则见:[Schema](schema.md#provider)。 #### 实现供应商代码 diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index e69de29bb2d1d6..f5d4427e3e7a72 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -0,0 +1,40 @@ +from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from .message_entities import ( + AssistantPromptMessage, + AudioPromptMessageContent, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageContentType, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, + VideoPromptMessageContent, +) +from .model_entities import ModelPropertyKey + +__all__ = [ + "ImagePromptMessageContent", + "VideoPromptMessageContent", + "PromptMessage", + "PromptMessageRole", + "LLMUsage", + "ModelPropertyKey", + "AssistantPromptMessage", + "PromptMessage", + "PromptMessageContent", + "PromptMessageRole", + "SystemPromptMessage", + "TextPromptMessageContent", + "UserPromptMessage", + "PromptMessageTool", + "ToolPromptMessage", + "PromptMessageContentType", + "LLMResult", + "LLMResultChunk", + "LLMResultChunkDelta", + "AudioPromptMessageContent", +] diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index 175c13cfdcc04c..659ad59bd67f91 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ 
b/api/core/model_runtime/entities/common_entities.py @@ -7,6 +7,7 @@ class I18nObject(BaseModel): """ Model class for i18n object. """ + zh_Hans: Optional[str] = None en_US: str diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index d2076bf74a3cde..4d0c9aa08f7337 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -2,107 +2,129 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { - 'label': { - 'en_US': 'Temperature', - 'zh_Hans': '温度', - }, - 'type': 'float', - 'help': { - 'en_US': 'Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.', - 'zh_Hans': '温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。', - }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "label": { + "en_US": "Temperature", + "zh_Hans": "温度", + }, + "type": "float", + "help": { + "en_US": "Controls randomness. Lower temperature results in less random completions." + " As the temperature approaches zero, the model will become deterministic and repetitive." + " Higher temperature results in more random completions.", + "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。" + "较高的温度会导致更多的随机完成。", + }, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.TOP_P: { - 'label': { - 'en_US': 'Top P', - 'zh_Hans': 'Top P', - }, - 'type': 'float', - 'help': { - 'en_US': 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.', - 'zh_Hans': '通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。', - }, - 'required': False, - 'default': 1.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "label": { + "en_US": "Top P", + "zh_Hans": "Top P", + }, + "type": "float", + "help": { + "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options" + " are considered.", + "zh_Hans": "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。", + }, + "required": False, + "default": 1.0, + "min": 0.0, + "max": 1.0, + "precision": 2, + }, + DefaultParameterName.TOP_K: { + "label": { + "en_US": "Top K", + "zh_Hans": "Top K", + }, + "type": "int", + "help": { + "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", + "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", + }, + "required": False, + "default": 50, + "min": 1, + "max": 100, + "precision": 0, }, DefaultParameterName.PRESENCE_PENALTY: { - 'label': { - 'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', - }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens already in the text.', - 'zh_Hans': '对文本中已有的标记的对数概率施加惩罚。', - }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "label": { + "en_US": "Presence Penalty", + "zh_Hans": "存在惩罚", + }, + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens already in the text.", + "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", + }, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.FREQUENCY_PENALTY: { - 'label': { - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', - }, - 'type': 'float', - 'help': { - 'en_US': 'Applies a penalty to the log-probability of tokens that appear in 
the text.', - 'zh_Hans': '对文本中出现的标记的对数概率施加惩罚。', - }, - 'required': False, - 'default': 0.0, - 'min': 0.0, - 'max': 1.0, - 'precision': 2, + "label": { + "en_US": "Frequency Penalty", + "zh_Hans": "频率惩罚", + }, + "type": "float", + "help": { + "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", + "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", + }, + "required": False, + "default": 0.0, + "min": 0.0, + "max": 1.0, + "precision": 2, }, DefaultParameterName.MAX_TOKENS: { - 'label': { - 'en_US': 'Max Tokens', - 'zh_Hans': '最大标记', - }, - 'type': 'int', - 'help': { - 'en_US': 'Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.', - 'zh_Hans': '指定生成结果长度的上限。如果生成结果截断,可以调大该参数。', - }, - 'required': False, - 'default': 64, - 'min': 1, - 'max': 2048, - 'precision': 0, + "label": { + "en_US": "Max Tokens", + "zh_Hans": "最大标记", + }, + "type": "int", + "help": { + "en_US": "Specifies the upper limit on the length of generated results." + " If the generated results are truncated, you can increase this parameter.", + "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", + }, + "required": False, + "default": 64, + "min": 1, + "max": 2048, + "precision": 0, }, DefaultParameterName.RESPONSE_FORMAT: { - 'label': { - 'en_US': 'Response Format', - 'zh_Hans': '回复格式', + "label": { + "en_US": "Response Format", + "zh_Hans": "回复格式", }, - 'type': 'string', - 'help': { - 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.', - 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等', + "type": "string", + "help": { + "en_US": "Set a response format, ensure the output from llm is a valid code block as possible," + " such as JSON, XML, etc.", + "zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等", }, - 'required': False, - 'options': ['JSON', 'XML'], + "required": False, + "options": ["JSON", "XML"], }, DefaultParameterName.JSON_SCHEMA: { - 'label': { - 'en_US': 'JSON Schema', + "label": { + "en_US": "JSON Schema", }, - 'type': 'text', - 'help': { - 'en_US': 'Set a response json schema will ensure LLM to adhere it.', - 'zh_Hans': '设置返回的json schema,llm将按照它返回', + "type": "text", + "help": { + "en_US": "Set a response json schema will ensure LLM to adhere it.", + "zh_Hans": "设置返回的json schema,llm将按照它返回", }, - 'required': False, + "required": False, }, } diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index b5bd9e267a0573..88531d8ae00037 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -12,11 +12,12 @@ class LLMMode(Enum): """ Enum class for large language model mode. """ + COMPLETION = "completion" CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'LLMMode': + def value_of(cls, value: str) -> "LLMMode": """ Get value of given mode. @@ -26,13 +27,14 @@ def value_of(cls, value: str) -> 'LLMMode': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") class LLMUsage(ModelUsage): """ Model class for llm usage. 
""" + prompt_tokens: int prompt_unit_price: Decimal prompt_price_unit: Decimal @@ -50,24 +52,60 @@ class LLMUsage(ModelUsage): def empty_usage(cls): return cls( prompt_tokens=0, - prompt_unit_price=Decimal('0.0'), - prompt_price_unit=Decimal('0.0'), - prompt_price=Decimal('0.0'), + prompt_unit_price=Decimal("0.0"), + prompt_price_unit=Decimal("0.0"), + prompt_price=Decimal("0.0"), completion_tokens=0, - completion_unit_price=Decimal('0.0'), - completion_price_unit=Decimal('0.0'), - completion_price=Decimal('0.0'), + completion_unit_price=Decimal("0.0"), + completion_price_unit=Decimal("0.0"), + completion_price=Decimal("0.0"), total_tokens=0, - total_price=Decimal('0.0'), - currency='USD', - latency=0.0 + total_price=Decimal("0.0"), + currency="USD", + latency=0.0, ) + def plus(self, other: "LLMUsage") -> "LLMUsage": + """ + Add two LLMUsage instances together. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + if self.total_tokens == 0: + return other + else: + return LLMUsage( + prompt_tokens=self.prompt_tokens + other.prompt_tokens, + prompt_unit_price=other.prompt_unit_price, + prompt_price_unit=other.prompt_price_unit, + prompt_price=self.prompt_price + other.prompt_price, + completion_tokens=self.completion_tokens + other.completion_tokens, + completion_unit_price=other.completion_unit_price, + completion_price_unit=other.completion_price_unit, + completion_price=self.completion_price + other.completion_price, + total_tokens=self.total_tokens + other.total_tokens, + total_price=self.total_price + other.total_price, + currency=other.currency, + latency=self.latency + other.latency, + ) + + def __add__(self, other: "LLMUsage") -> "LLMUsage": + """ + Overload the + operator to add two LLMUsage instances. + + :param other: Another LLMUsage instance to add + :return: A new LLMUsage instance with summed values + """ + return self.plus(other) + class LLMResult(BaseModel): """ Model class for llm result. """ + + id: Optional[str] = None model: str prompt_messages: list[PromptMessage] message: AssistantPromptMessage @@ -79,6 +117,7 @@ class LLMResultChunkDelta(BaseModel): """ Model class for llm result chunk delta. """ + index: int message: AssistantPromptMessage usage: Optional[LLMUsage] = None @@ -89,6 +128,7 @@ class LLMResultChunk(BaseModel): """ Model class for llm result chunk. """ + model: str prompt_messages: list[PromptMessage] system_fingerprint: Optional[str] = None @@ -99,4 +139,5 @@ class NumTokensResult(PriceInfo): """ Model class for number of tokens result. """ + tokens: int diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index e8e6963b56d7a7..3c244d368ef78b 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -2,20 +2,21 @@ from enum import Enum from typing import Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator class PromptMessageRole(Enum): """ Enum class for prompt message. """ + SYSTEM = "system" USER = "user" ASSISTANT = "assistant" TOOL = "tool" @classmethod - def value_of(cls, value: str) -> 'PromptMessageRole': + def value_of(cls, value: str) -> "PromptMessageRole": """ Get value of given mode. 
@@ -25,13 +26,14 @@ def value_of(cls, value: str) -> 'PromptMessageRole': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid prompt message type value {value}') + raise ValueError(f"invalid prompt message type value {value}") class PromptMessageTool(BaseModel): """ Model class for prompt message tool. """ + name: str description: str parameters: dict @@ -41,7 +43,8 @@ class PromptMessageFunction(BaseModel): """ Model class for prompt message function. """ - type: str = 'function' + + type: str = "function" function: PromptMessageTool @@ -49,14 +52,18 @@ class PromptMessageContentType(Enum): """ Enum class for prompt message content type. """ - TEXT = 'text' - IMAGE = 'image' + + TEXT = "text" + IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" class PromptMessageContent(BaseModel): """ Model class for prompt message content. """ + type: PromptMessageContentType data: str @@ -65,16 +72,30 @@ class TextPromptMessageContent(PromptMessageContent): """ Model class for text prompt message content. """ + type: PromptMessageContentType = PromptMessageContentType.TEXT +class VideoPromptMessageContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.VIDEO + data: str = Field(..., description="Base64 encoded video data") + format: str = Field(..., description="Video format") + + +class AudioPromptMessageContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.AUDIO + data: str = Field(..., description="Base64 encoded audio data") + format: str = Field(..., description="Audio format") + + class ImagePromptMessageContent(PromptMessageContent): """ Model class for image prompt message content. """ - class DETAIL(Enum): - LOW = 'low' - HIGH = 'high' + + class DETAIL(str, Enum): + LOW = "low" + HIGH = "high" type: PromptMessageContentType = PromptMessageContentType.IMAGE detail: DETAIL = DETAIL.LOW @@ -84,6 +105,7 @@ class PromptMessage(ABC, BaseModel): """ Model class for prompt message. """ + role: PromptMessageRole content: Optional[str | list[PromptMessageContent]] = None name: Optional[str] = None @@ -101,6 +123,7 @@ class UserPromptMessage(PromptMessage): """ Model class for user prompt message. """ + role: PromptMessageRole = PromptMessageRole.USER @@ -108,14 +131,17 @@ class AssistantPromptMessage(PromptMessage): """ Model class for assistant prompt message. """ + class ToolCall(BaseModel): """ Model class for assistant prompt message tool call. """ + class ToolCallFunction(BaseModel): """ Model class for assistant prompt message tool call function. """ + name: str arguments: str @@ -123,7 +149,7 @@ class ToolCallFunction(BaseModel): type: str function: ToolCallFunction - @field_validator('id', mode='before') + @field_validator("id", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -145,10 +171,12 @@ def is_empty(self) -> bool: return True + class SystemPromptMessage(PromptMessage): """ Model class for system prompt message. """ + role: PromptMessageRole = PromptMessageRole.SYSTEM @@ -156,6 +184,7 @@ class ToolPromptMessage(PromptMessage): """ Model class for tool prompt message. 
""" + role: PromptMessageRole = PromptMessageRole.TOOL tool_call_id: str diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index c257ce63d27926..52ea787c3ad572 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -11,6 +11,7 @@ class ModelType(Enum): """ Enum class for model type. """ + LLM = "llm" TEXT_EMBEDDING = "text-embedding" RERANK = "rerank" @@ -26,22 +27,22 @@ def value_of(cls, origin_model_type: str) -> "ModelType": :return: model type """ - if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value: + if origin_model_type in {"text-generation", cls.LLM.value}: return cls.LLM - elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: return cls.TEXT_EMBEDDING - elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value: + elif origin_model_type in {"reranking", cls.RERANK.value}: return cls.RERANK - elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: return cls.SPEECH2TEXT - elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: + elif origin_model_type in {"tts", cls.TTS.value}: return cls.TTS - elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type in {"text2img", cls.TEXT2IMG.value}: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION else: - raise ValueError(f'invalid origin model type {origin_model_type}') + raise ValueError(f"invalid origin model type {origin_model_type}") def to_origin_model_type(self) -> str: """ @@ -50,26 +51,28 @@ def to_origin_model_type(self) -> str: :return: origin model type """ if self == self.LLM: - return 'text-generation' + return "text-generation" elif self == self.TEXT_EMBEDDING: - return 'embeddings' + return "embeddings" elif self == self.RERANK: - return 'reranking' + return "reranking" elif self == self.SPEECH2TEXT: - return 'speech2text' + return "speech2text" elif self == self.TTS: - return 'tts' + return "tts" elif self == self.MODERATION: - return 'moderation' + return "moderation" elif self == self.TEXT2IMG: - return 'text2img' + return "text2img" else: - raise ValueError(f'invalid model type {self}') + raise ValueError(f"invalid model type {self}") + class FetchFrom(Enum): """ Enum class for fetch from. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -78,6 +81,7 @@ class ModelFeature(Enum): """ Enum class for llm feature. """ + TOOL_CALL = "tool-call" MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" @@ -85,12 +89,14 @@ class ModelFeature(Enum): STREAM_TOOL_CALL = "stream-tool-call" -class DefaultParameterName(Enum): +class DefaultParameterName(str, Enum): """ Enum class for parameter template variable. """ + TEMPERATURE = "temperature" TOP_P = "top_p" + TOP_K = "top_k" PRESENCE_PENALTY = "presence_penalty" FREQUENCY_PENALTY = "frequency_penalty" MAX_TOKENS = "max_tokens" @@ -98,7 +104,7 @@ class DefaultParameterName(Enum): JSON_SCHEMA = "json_schema" @classmethod - def value_of(cls, value: Any) -> 'DefaultParameterName': + def value_of(cls, value: Any) -> "DefaultParameterName": """ Get parameter name from value. 
@@ -108,13 +114,14 @@ def value_of(cls, value: Any) -> 'DefaultParameterName': for name in cls: if name.value == value: return name - raise ValueError(f'invalid parameter name {value}') + raise ValueError(f"invalid parameter name {value}") class ParameterType(Enum): """ Enum class for parameter type. """ + FLOAT = "float" INT = "int" STRING = "string" @@ -126,6 +133,7 @@ class ModelPropertyKey(Enum): """ Enum class for model property key. """ + MODE = "mode" CONTEXT_SIZE = "context_size" MAX_CHUNKS = "max_chunks" @@ -143,6 +151,7 @@ class ProviderModel(BaseModel): """ Model class for provider model. """ + model: str label: I18nObject model_type: ModelType @@ -157,6 +166,7 @@ class ParameterRule(BaseModel): """ Model class for parameter rule. """ + name: str use_template: Optional[str] = None label: I18nObject @@ -174,6 +184,7 @@ class PriceConfig(BaseModel): """ Model class for pricing info. """ + input: Decimal output: Optional[Decimal] = None unit: Decimal @@ -184,6 +195,7 @@ class AIModelEntity(ProviderModel): """ Model class for AI model. """ + parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None @@ -196,6 +208,7 @@ class PriceType(Enum): """ Enum class for price type. """ + INPUT = "input" OUTPUT = "output" @@ -204,6 +217,7 @@ class PriceInfo(BaseModel): """ Model class for price info. """ + unit_price: Decimal unit: Decimal total_amount: Decimal diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index f88f89d5886332..bfe861a97ffbf8 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -12,6 +12,7 @@ class ConfigurateMethod(Enum): """ Enum class for configurate method of provider model. """ + PREDEFINED_MODEL = "predefined-model" CUSTOMIZABLE_MODEL = "customizable-model" @@ -20,6 +21,7 @@ class FormType(Enum): """ Enum class for form type. """ + TEXT_INPUT = "text-input" SECRET_INPUT = "secret-input" SELECT = "select" @@ -31,6 +33,7 @@ class FormShowOnObject(BaseModel): """ Model class for form show on. """ + variable: str value: str @@ -39,6 +42,7 @@ class FormOption(BaseModel): """ Model class for form option. """ + label: I18nObject value: str show_on: list[FormShowOnObject] = [] @@ -46,15 +50,14 @@ class FormOption(BaseModel): def __init__(self, **data): super().__init__(**data) if not self.label: - self.label = I18nObject( - en_US=self.value - ) + self.label = I18nObject(en_US=self.value) class CredentialFormSchema(BaseModel): """ Model class for credential form schema. """ + variable: str label: I18nObject type: FormType @@ -70,6 +73,7 @@ class ProviderCredentialSchema(BaseModel): """ Model class for provider credential schema. """ + credential_form_schemas: list[CredentialFormSchema] @@ -82,6 +86,7 @@ class ModelCredentialSchema(BaseModel): """ Model class for model credential schema. """ + model: FieldModelSchema credential_form_schemas: list[CredentialFormSchema] @@ -90,6 +95,7 @@ class SimpleProviderEntity(BaseModel): """ Simple model class for provider. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -102,6 +108,7 @@ class ProviderHelpEntity(BaseModel): """ Model class for provider help. """ + title: I18nObject url: I18nObject @@ -110,6 +117,7 @@ class ProviderEntity(BaseModel): """ Model class for provider. 
""" + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -138,7 +146,7 @@ def to_simple_provider(self) -> SimpleProviderEntity: icon_small=self.icon_small, icon_large=self.icon_large, supported_model_types=self.supported_model_types, - models=self.models + models=self.models, ) @@ -146,5 +154,6 @@ class ProviderConfig(BaseModel): """ Model class for provider config. """ + provider: str credentials: dict diff --git a/api/core/model_runtime/entities/rerank_entities.py b/api/core/model_runtime/entities/rerank_entities.py index d51efd2b3be133..99709e1bcd2127 100644 --- a/api/core/model_runtime/entities/rerank_entities.py +++ b/api/core/model_runtime/entities/rerank_entities.py @@ -5,6 +5,7 @@ class RerankDocument(BaseModel): """ Model class for rerank document. """ + index: int text: str score: float @@ -14,5 +15,6 @@ class RerankResult(BaseModel): """ Model class for rerank result. """ + model: str docs: list[RerankDocument] diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 7be3def3791333..846b89d6580b18 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -9,6 +9,7 @@ class EmbeddingUsage(ModelUsage): """ Model class for embedding usage. """ + tokens: int total_tokens: int unit_price: Decimal @@ -22,7 +23,7 @@ class TextEmbeddingResult(BaseModel): """ Model class for text embedding result. """ + model: str embeddings: list[list[float]] usage: EmbeddingUsage - diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index 0513cfaf67b216..edfb19c7d07d4c 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -3,6 +3,7 @@ class InvokeError(Exception): """Base class for all LLM exceptions.""" + description: Optional[str] = None def __init__(self, description: Optional[str] = None) -> None: @@ -14,24 +15,29 @@ def __str__(self): class InvokeConnectionError(InvokeError): """Raised when the Invoke returns connection error.""" + description = "Connection Error" class InvokeServerUnavailableError(InvokeError): """Raised when the Invoke returns server unavailable error.""" + description = "Server Unavailable Error" class InvokeRateLimitError(InvokeError): """Raised when the Invoke returns rate limit error.""" + description = "Rate Limit Error" class InvokeAuthorizationError(InvokeError): """Raised when the Invoke returns authorization error.""" + description = "Incorrect model credentials provided, please check and try again. 
" class InvokeBadRequestError(InvokeError): """Raised when the Invoke returns bad request.""" + description = "Bad Request Error" diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 8db79a52bb612a..7fcd2133f9f8d1 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -2,4 +2,5 @@ class CredentialsValidateFailedError(Exception): """ Credentials validate failed error """ + pass diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 0de216bf896fc2..79a1d28ebe637e 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -66,12 +66,16 @@ def _transform_invoke_error(self, error: Exception) -> InvokeError: :param error: model invoke error :return: unified error """ - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] for invoke_error, model_errors in self._invoke_error_mapping.items(): if isinstance(error, tuple(model_errors)): if invoke_error == InvokeAuthorizationError: - return invoke_error(description=f"[{provider_name}] Incorrect model credentials provided, please check and try again. ") + return invoke_error( + description=( + f"[{provider_name}] Incorrect model credentials provided, please check and try again." + ) + ) return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}") @@ -115,7 +119,7 @@ def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens if not price_config: raise ValueError(f"Price config not found for model {model}") total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) return PriceInfo( unit_price=unit_price, @@ -136,24 +140,26 @@ def predefined_models(self) -> list[AIModelEntity]: model_schemas = [] # get module name - model_type = self.__class__.__module__.split('.')[-1] + model_type = self.__class__.__module__.split(".")[-1] # get provider name - provider_name = self.__class__.__module__.split('.')[-3] + provider_name = self.__class__.__module__.split(".")[-3] # get the path of current classes current_path = os.path.abspath(__file__) # get parent path of the current path - provider_model_type_path = os.path.join(os.path.dirname(os.path.dirname(current_path)), provider_name, model_type) + provider_model_type_path = os.path.join( + os.path.dirname(os.path.dirname(current_path)), provider_name, model_type + ) # get all yaml files path under provider_model_type_path that do not start with __ model_schema_yaml_paths = [ os.path.join(provider_model_type_path, model_schema_yaml) for model_schema_yaml in os.listdir(provider_model_type_path) - if not model_schema_yaml.startswith('__') - and not model_schema_yaml.startswith('_') - and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) - and model_schema_yaml.endswith('.yaml') + if not model_schema_yaml.startswith("__") + and not model_schema_yaml.startswith("_") + and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml)) + and model_schema_yaml.endswith(".yaml") ] # get _position.yaml file path @@ -165,10 +171,10 @@ def predefined_models(self) -> list[AIModelEntity]: yaml_data = 
load_yaml_file(model_schema_yaml_path) new_parameter_rules = [] - for parameter_rule in yaml_data.get('parameter_rules', []): - if 'use_template' in parameter_rule: + for parameter_rule in yaml_data.get("parameter_rules", []): + if "use_template" in parameter_rule: try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule['use_template']) + default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"]) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) copy_default_parameter_rule = default_parameter_rule.copy() copy_default_parameter_rule.update(parameter_rule) @@ -176,31 +182,26 @@ def predefined_models(self) -> list[AIModelEntity]: except ValueError: pass - if 'label' not in parameter_rule: - parameter_rule['label'] = { - 'zh_Hans': parameter_rule['name'], - 'en_US': parameter_rule['name'] - } + if "label" not in parameter_rule: + parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]} new_parameter_rules.append(parameter_rule) - yaml_data['parameter_rules'] = new_parameter_rules + yaml_data["parameter_rules"] = new_parameter_rules - if 'label' not in yaml_data: - yaml_data['label'] = { - 'zh_Hans': yaml_data['model'], - 'en_US': yaml_data['model'] - } + if "label" not in yaml_data: + yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]} - yaml_data['fetch_from'] = FetchFrom.PREDEFINED_MODEL.value + yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value try: # yaml_data to entity model_schema = AIModelEntity(**yaml_data) except Exception as e: model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).rstrip(".yaml") - raise Exception(f'Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}:' - f' {str(e)}') + raise Exception( + f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}: {str(e)}" + ) # cache model schema model_schemas.append(model_schema) @@ -235,7 +236,9 @@ def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> return None - def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: + def get_customizable_model_schema_from_credentials( + self, model: str, credentials: Mapping + ) -> Optional[AIModelEntity]: """ Get customizable model schema from credentials @@ -261,19 +264,19 @@ def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Op try: default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and 'max' in default_parameter_rule: - parameter_rule.max = default_parameter_rule['max'] - if not parameter_rule.min and 'min' in default_parameter_rule: - parameter_rule.min = default_parameter_rule['min'] - if not parameter_rule.default and 'default' in default_parameter_rule: - parameter_rule.default = default_parameter_rule['default'] - if not parameter_rule.precision and 'precision' in default_parameter_rule: - parameter_rule.precision = default_parameter_rule['precision'] - if not parameter_rule.required and 'required' in default_parameter_rule: - parameter_rule.required = default_parameter_rule['required'] - if not parameter_rule.help and 'help' in default_parameter_rule: + if not parameter_rule.max and "max" in default_parameter_rule: + parameter_rule.max = default_parameter_rule["max"] + if 
not parameter_rule.min and "min" in default_parameter_rule: + parameter_rule.min = default_parameter_rule["min"] + if not parameter_rule.default and "default" in default_parameter_rule: + parameter_rule.default = default_parameter_rule["default"] + if not parameter_rule.precision and "precision" in default_parameter_rule: + parameter_rule.precision = default_parameter_rule["precision"] + if not parameter_rule.required and "required" in default_parameter_rule: + parameter_rule.required = default_parameter_rule["required"] + if not parameter_rule.help and "help" in default_parameter_rule: parameter_rule.help = I18nObject( - en_US=default_parameter_rule['help']['en_US'], + en_US=default_parameter_rule["help"]["en_US"], ) if ( parameter_rule.help diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 02ba0c9410f937..5b6f96129bde25 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -1,5 +1,4 @@ import logging -import os import re import time from abc import abstractmethod @@ -8,6 +7,7 @@ from pydantic import ConfigDict +from configs import dify_config from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -35,16 +35,24 @@ class LargeLanguageModel(AIModel): """ Model class for large language model. """ + model_type: ModelType = ModelType.LLM # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ - -> Union[LLMResult, Generator]: + def invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: Optional[dict] = None, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -69,7 +77,7 @@ def invoke(self, model: str, credentials: dict, callbacks = callbacks or [] - if bool(os.environ.get("DEBUG", 'False').lower() == 'true'): + if dify_config.DEBUG: callbacks.append(LoggingCallback()) # trigger before invoke callbacks @@ -82,11 +90,11 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) try: - if "response_format" in model_parameters: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: result = self._code_block_mode_wrapper( model=model, credentials=credentials, @@ -96,10 +104,19 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) else: - result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + result = self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, 
+ stream=stream, + user=user, + ) except Exception as e: self._trigger_invoke_error_callbacks( model=model, @@ -111,7 +128,7 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) raise self._transform_invoke_error(e) @@ -127,7 +144,7 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) elif isinstance(result, LLMResult): self._trigger_after_invoke_callbacks( @@ -140,15 +157,23 @@ def invoke(self, model: str, credentials: dict, stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) return result - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper, ensure the response is a code block with output markdown quote @@ -171,7 +196,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message {{instructions}} -""" +""" # noqa: E501 code_block = model_parameters.get("response_format", "") if not code_block: @@ -183,9 +208,9 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - + model_parameters.pop("response_format") stop = stop or [] stop.extend(["\n```", "```\n"]) @@ -195,15 +220,16 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", str(prompt_messages[0].content)) + content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content)) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} object.") - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.") + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last text message @@ -216,9 +242,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message break else: # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) + prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n")) response = self._invoke( model=model, @@ -228,33 +252,30 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if isinstance(response, Generator): first_chunk = next(response) + def new_generator(): yield first_chunk yield from response if first_chunk.delta.message.content and 
first_chunk.delta.message.content.startswith("`"): return self._code_block_mode_stream_processor_with_backtick( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) else: return self._code_block_mode_stream_processor( - model=model, - prompt_messages=prompt_messages, - input_generator=new_generator() + model=model, prompt_messages=prompt_messages, input_generator=new_generator() ) - + return response - def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], - input_generator: Generator[LLMResultChunk, None, None] - ) -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor( + self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote @@ -303,16 +324,13 @@ def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[Pr prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, - input_generator: Generator[LLMResultChunk, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _code_block_mode_stream_processor_with_backtick( + self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: """ Code block mode stream processor, ensure the response is a code block with output markdown quote. This version skips the language identifier that follows the opening triple backticks. 
@@ -378,18 +396,23 @@ def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_mes prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=new_piece, - tool_calls=[] - ), - ) + message=AssistantPromptMessage(content=new_piece, tool_calls=[]), + ), ) - def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator: + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: """ Invoke result generator @@ -397,9 +420,7 @@ def _invoke_result_generator(self, model: str, result: Generator, credentials: d :return: result generator """ callbacks = callbacks or [] - prompt_message = AssistantPromptMessage( - content="" - ) + prompt_message = AssistantPromptMessage(content="") usage = None system_fingerprint = None real_model = model @@ -418,7 +439,7 @@ def _invoke_result_generator(self, model: str, result: Generator, credentials: d stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) prompt_message.content += chunk.delta.message.content @@ -437,8 +458,8 @@ def _invoke_result_generator(self, model: str, result: Generator, credentials: d model=real_model, prompt_messages=prompt_messages, message=prompt_message, - usage=usage if usage else LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint + usage=usage or LLMUsage.empty_usage(), + system_fingerprint=system_fingerprint, ), credentials=credentials, prompt_messages=prompt_messages, @@ -447,15 +468,21 @@ def _invoke_result_generator(self, model: str, result: Generator, credentials: d stop=stop, stream=stream, user=user, - callbacks=callbacks + callbacks=callbacks, ) @abstractmethod - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -470,10 +497,15 @@ def _invoke(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ raise NotImplementedError - + @abstractmethod - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -519,7 +551,9 @@ def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> L return mode - def 
_calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -539,10 +573,7 @@ def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int # get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -558,16 +589,23 @@ def _calc_response_usage(self, model: str, credentials: dict, prompt_tokens: int total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_before_invoke_callbacks( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger before invoke callbacks @@ -593,7 +631,7 @@ def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -601,11 +639,19 @@ def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, else: logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}") - def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_new_chunk_callbacks( + self, + chunk: LLMResultChunk, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger new chunk callbacks @@ -632,7 +678,7 @@ def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, creden tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -640,11 +686,19 @@ def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, creden else: logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}") - def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: 
Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_after_invoke_callbacks( + self, + model: str, + result: LLMResult, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger after invoke callbacks @@ -672,7 +726,7 @@ def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credent tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -680,11 +734,19 @@ def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credent else: logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}") - def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: + def _trigger_invoke_error_callbacks( + self, + model: str, + ex: Exception, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> None: """ Trigger invoke error callbacks @@ -712,7 +774,7 @@ def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials tools=tools, stop=stop, stream=stream, - user=user + user=user, ) except Exception as e: if callback.raise_error: @@ -758,11 +820,13 @@ def _validate_and_filter_model_parameters(self, model: str, model_parameters: di # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." + ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.FLOAT: if not isinstance(parameter_value, float | int): raise ValueError(f"Model Parameter {parameter_name} should be float.") @@ -775,16 +839,20 @@ def _validate_and_filter_model_parameters(self, model: str, model_parameters: di else: if parameter_value != round(parameter_value, parameter_rule.precision): raise ValueError( - f"Model Parameter {parameter_name} should be round to {parameter_rule.precision} decimal places.") + f"Model Parameter {parameter_name} should be round to {parameter_rule.precision}" + f" decimal places." + ) # validate parameter value range if parameter_rule.min is not None and parameter_value < parameter_rule.min: raise ValueError( - f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}.") + f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}." 
+ ) if parameter_rule.max is not None and parameter_value > parameter_rule.max: raise ValueError( - f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") + f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}." + ) elif parameter_rule.type == ParameterType.BOOLEAN: if not isinstance(parameter_value, bool): raise ValueError(f"Model Parameter {parameter_name} should be bool.") @@ -792,6 +860,13 @@ def _validate_and_filter_model_parameters(self, model: str, model_parameters: di if not isinstance(parameter_value, str): raise ValueError(f"Model Parameter {parameter_name} should be string.") + # validate options + if parameter_rule.options and parameter_value not in parameter_rule.options: + raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") + elif parameter_rule.type == ParameterType.TEXT: + if not isinstance(parameter_value, str): + raise ValueError(f"Model Parameter {parameter_name} should be text.") + # validate options if parameter_rule.options and parameter_value not in parameter_rule.options: raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.") diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index 780460a3f738fe..4374093de4ab38 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -29,32 +29,32 @@ def validate_provider_credentials(self, credentials: dict) -> None: def get_provider_schema(self) -> ProviderEntity: """ Get provider schema - + :return: provider schema """ if self.provider_schema: return self.provider_schema - + # get dirname of the current path - provider_name = self.__class__.__module__.split('.')[-1] + provider_name = self.__class__.__module__.split(".")[-1] # get the path of the model_provider classes base_path = os.path.abspath(__file__) current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) - + # read provider schema from yaml file - yaml_path = os.path.join(current_path, f'{provider_name}.yaml') + yaml_path = os.path.join(current_path, f"{provider_name}.yaml") yaml_data = load_yaml_file(yaml_path) - + try: # yaml_data to entity provider_schema = ProviderEntity(**yaml_data) except Exception as e: - raise Exception(f'Invalid provider schema for {provider_name}: {str(e)}') + raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}") # cache schema self.provider_schema = provider_schema - + return provider_schema def models(self, model_type: ModelType) -> list[AIModelEntity]: @@ -92,15 +92,15 @@ def get_model_instance(self, model_type: ModelType) -> AIModel: # get the path of the model type classes base_path = os.path.abspath(__file__) - model_type_name = model_type.value.replace('-', '_') + model_type_name = model_type.value.replace("-", "_") model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name) - model_type_py_path = os.path.join(model_type_path, f'{model_type_name}.py') + model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py") if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path): - raise Exception(f'Invalid model type {model_type} for provider {provider_name}') + raise Exception(f"Invalid model type {model_type} for provider {provider_name}") # Dynamic loading 
{model_type_name}.py file and find the subclass of AIModel - parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) + parent_module = ".".join(self.__class__.__module__.split(".")[:-1]) mod = import_module_from_source( module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path ) diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index 2b17f292c5db00..d04414ccb87a63 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -12,14 +12,13 @@ class ModerationModel(AIModel): """ Model class for moderation model. """ + model_type: ModelType = ModelType.MODERATION # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -37,9 +36,7 @@ def invoke(self, model: str, credentials: dict, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke large language model @@ -50,4 +47,3 @@ def _invoke(self, model: str, credentials: dict, :return: false if text is safe, true otherwise """ raise NotImplementedError - diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index 2c86f25180eab8..5fb96047425592 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -11,12 +11,19 @@ class RerankModel(AIModel): """ Base Model class for rerank model. """ + model_type: ModelType = ModelType.RERANK - def invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -37,10 +44,16 @@ def invoke(self, model: str, credentials: dict, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index 4fb11025fe07fd..b6b0b737436d9c 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -12,14 +12,13 @@ class Speech2TextModel(AIModel): """ Model class for speech2text model. 
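Editor's note: the moderation and rerank base classes above (and the speech-to-text class just below) all share one template: a public `invoke` that records the start time, delegates to the abstract `_invoke`, and converts any provider exception into a unified error. A stripped-down sketch of that shape, with simplified names; the real classes also return typed results and route through `_transform_invoke_error`:

```python
import time
from abc import ABC, abstractmethod


class InvokeError(Exception):
    """Stand-in for the unified error hierarchy."""


class BaseProviderModel(ABC):
    def invoke(self, model: str, credentials: dict, text: str) -> bool:
        self.started_at = time.perf_counter()
        try:
            return self._invoke(model, credentials, text)
        except Exception as e:
            # the real classes call self._transform_invoke_error(e)
            raise InvokeError(str(e)) from e

    @abstractmethod
    def _invoke(self, model: str, credentials: dict, text: str) -> bool:
        raise NotImplementedError


# Usage with a toy provider implementation:
class KeywordModeration(BaseProviderModel):
    def _invoke(self, model: str, credentials: dict, text: str) -> bool:
        return "forbidden" in text.lower()


assert KeywordModeration().invoke("m", {}, "all good") is False
```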
""" + model_type: ModelType = ModelType.SPEECH2TEXT # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -35,9 +34,7 @@ def invoke(self, model: str, credentials: dict, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke large language model @@ -59,4 +56,4 @@ def _get_demo_file_path(self) -> str: current_dir = os.path.dirname(os.path.abspath(__file__)) # Construct the path to the audio file - return os.path.join(current_dir, 'audio.mp3') + return os.path.join(current_dir, "audio.mp3") diff --git a/api/core/model_runtime/model_providers/__base/text2img_model.py b/api/core/model_runtime/model_providers/__base/text2img_model.py index e0f1adb1c47f23..a5810e2f0e4b09 100644 --- a/api/core/model_runtime/model_providers/__base/text2img_model.py +++ b/api/core/model_runtime/model_providers/__base/text2img_model.py @@ -11,14 +11,15 @@ class Text2ImageModel(AIModel): """ Model class for text2img model. """ + model_type: ModelType = ModelType.TEXT2IMG # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model @@ -36,9 +37,9 @@ def invoke(self, model: str, credentials: dict, prompt: str, raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, prompt: str, - model_parameters: dict, user: Optional[str] = None) \ - -> list[IO[bytes]]: + def _invoke( + self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None + ) -> list[IO[bytes]]: """ Invoke Text2Image model diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index 381d2f6cd19ed4..2d38fba955fb86 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -4,6 +4,7 @@ from pydantic import ConfigDict +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -13,41 +14,54 @@ class TextEmbeddingModel(AIModel): """ Model class for text embedding model. 
""" + model_type: ModelType = ModelType.TEXT_EMBEDDING # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ - Invoke large language model + Invoke text embedding model :param model: model name :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ self.started_at = time.perf_counter() try: - return self._invoke(model, credentials, texts, user) + return self._invoke(model, credentials, texts, user, input_type) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ - Invoke large language model + Invoke text embedding model :param model: model name :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ raise NotImplementedError diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 6059b3f5619685..5fe6dda6ad5d79 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -7,27 +7,28 @@ _tokenizer = None _lock = Lock() + class GPT2Tokenizer: @staticmethod def _get_num_tokens_by_gpt2(text: str) -> int: """ - use gpt2 tokenizer to get num tokens + use gpt2 tokenizer to get num tokens """ _tokenizer = GPT2Tokenizer.get_encoder() tokens = _tokenizer.encode(text, verbose=False) return len(tokens) - + @staticmethod def get_num_tokens(text: str) -> int: return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - + @staticmethod def get_encoder() -> Any: global _tokenizer, _lock with _lock: if _tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'gpt2') + gpt2_tokenizer_path = join(dirname(base_path), "gpt2") _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - return _tokenizer \ No newline at end of file + return _tokenizer diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 64e85d2c119ee8..b394ea4e9d22fe 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -1,7 +1,8 @@ import logging import re from abc import abstractmethod -from typing import Optional +from collections.abc import Iterable +from typing import Any, Optional from pydantic import ConfigDict @@ -13,15 +14,23 @@ class TTSModel(AIModel): """ - Model class for ttstext model. + Model class for TTS model. 
""" + model_type: ModelType = ModelType.TTS # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + user: Optional[str] = None, + ) -> Iterable[bytes]: """ Invoke large language model @@ -35,14 +44,27 @@ def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: st :return: translated audio file """ try: - return self._invoke(model=model, credentials=credentials, user=user, - content_text=content_text, voice=voice, tenant_id=tenant_id) + return self._invoke( + model=model, + credentials=credentials, + user=user, + content_text=content_text, + voice=voice, + tenant_id=tenant_id, + ) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + user: Optional[str] = None, + ) -> Iterable[bytes]: """ Invoke large language model @@ -59,24 +81,27 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: s def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: """ - Get voice for given tts model voices + Retrieves the list of voices supported by a given text-to-speech (TTS) model. - :param language: tts language - :param model: model name - :param credentials: model credentials - :return: voices lists + :param language: The language for which the voices are requested. + :param model: The name of the TTS model. + :param credentials: The credentials required to access the TTS model. + :return: A list of voices supported by the TTS model. 
""" model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: - voices = model_schema.model_properties[ModelPropertyKey.VOICES] - if language: - return [{'name': d['name'], 'value': d['mode']} for d in voices if - language and language in d.get('language')] - else: - return [{'name': d['name'], 'value': d['mode']} for d in voices] + if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties: + raise ValueError("this model does not support voice") + + voices = model_schema.model_properties[ModelPropertyKey.VOICES] + if language: + return [ + {"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language") + ] + else: + return [{"name": d["name"], "value": d["mode"]} for d in voices] - def _get_model_default_voice(self, model: str, credentials: dict) -> any: + def _get_model_default_voice(self, model: str, credentials: dict) -> Any: """ Get voice for given tts model @@ -99,8 +124,10 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str: """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] + if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties: + raise ValueError("this model does not support audio type") + + return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE] def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ @@ -109,8 +136,10 @@ def _get_model_word_limit(self, model: str, credentials: dict) -> int: """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] + if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties: + raise ValueError("this model does not support word limit") + + return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT] def _get_model_workers_limit(self, model: str, credentials: dict) -> int: """ @@ -119,27 +148,29 @@ def _get_model_workers_limit(self, model: str, credentials: dict) -> int: """ model_schema = self.get_model_schema(model, credentials) - if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties: - return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] + if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties: + raise ValueError("this model does not support max workers") + + return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] @staticmethod - def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'): + def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"): match = re.compile(pattern) tx = match.finditer(org_text) start = 0 result = [] - one_sentence = '' + one_sentence = "" for i in tx: end = i.regs[0][1] tmp = org_text[start:end] if len(one_sentence + tmp) > max_length: result.append(one_sentence) - one_sentence = '' + one_sentence = "" one_sentence += tmp start = end last_sens = org_text[start:] if last_sens: one_sentence += last_sens - if one_sentence != '': + if one_sentence != "": result.append(one_sentence) return result diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 
d10314ba039e63..89fccef6598fdd 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -37,3 +37,7 @@ - siliconflow - perfxcloud - zhinao +- fireworks +- mixedbread +- nomic +- voyage diff --git a/api/core/model_runtime/model_providers/anthropic/anthropic.py b/api/core/model_runtime/model_providers/anthropic/anthropic.py index 00a6bbce3b563a..5b12f04a3e59b8 100644 --- a/api/core/model_runtime/model_providers/anthropic/anthropic.py +++ b/api/core/model_runtime/model_providers/anthropic/anthropic.py @@ -19,13 +19,10 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - # Use `claude-instant-1` model for validate, - model_instance.validate_credentials( - model='claude-instant-1.2', - credentials=credentials - ) + # Use `claude-3-opus-20240229` model for validate, + model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml index 8394c4276a786e..b7b28a70d46afe 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml @@ -1,3 +1,5 @@ +- claude-3-5-haiku-20241022 +- claude-3-5-sonnet-20241022 - claude-3-5-sonnet-20240620 - claude-3-haiku-20240307 - claude-3-opus-20240229 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml new file mode 100644 index 00000000000000..892146f6a57fe9 --- /dev/null +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml @@ -0,0 +1,38 @@ +model: claude-3-5-haiku-20241022 +label: + en_US: claude-3-5-haiku-20241022 +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '1.00' + output: '5.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml new file mode 100644 index 00000000000000..e20b8c4960734c --- /dev/null +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml @@ -0,0 +1,39 @@ +model: claude-3-5-sonnet-20241022 +label: + en_US: claude-3-5-sonnet-20241022 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '3.00' + output: '15.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml index 929a7f8725ec3a..ac69bbf4d293d3 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml @@ -33,3 +33,4 @@ pricing: output: '5.51' unit: '0.000001' currency: USD +deprecated: true diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 19ce401999c50f..3a5a42ba05b44b 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,6 +1,6 @@ import base64 +import io import json -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast @@ -18,6 +18,7 @@ ) from anthropic.types.beta.tools import ToolsBetaMessage from httpx import Timeout +from PIL import Image from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -50,15 +51,21 @@ {{instructions}} -""" +""" # noqa: E501 class AnthropicLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -75,10 +82,17 @@ def _invoke(self, model: str, credentials: dict, # invoke model return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def 
_chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -95,41 +109,39 @@ def _chat_generate(self, model: str, credentials: dict, credentials_kwargs = self._to_credential_kwargs(credentials) # transform model parameters from completion api of anthropic to chat api - if 'max_tokens_to_sample' in model_parameters: - model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample') + if "max_tokens_to_sample" in model_parameters: + model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample") # init model client client = Anthropic(**credentials_kwargs) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop if user: - extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user) + extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user) system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system # Add the new header for claude-3-5-sonnet-20240620 model extra_headers = {} if model == "claude-3-5-sonnet-20240620": - if model_parameters.get('max_tokens') > 4096: + if model_parameters.get("max_tokens") > 4096: extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" if tools: - extra_model_kwargs['tools'] = [ - self._transform_tool_prompt(tool) for tool in tools - ] + extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] response = client.beta.tools.messages.create( model=model, messages=prompt_message_dicts, stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: # chat model @@ -139,22 +151,30 @@ def _chat_generate(self, model: str, credentials: dict, stream=stream, extra_headers=extra_headers, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if model_parameters.get('response_format'): + if 
model_parameters.get("response_format"): stop = stop or [] # chat model self._transform_chat_json_prompts( @@ -166,24 +186,27 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict: - return { - 'name': tool.name, - 'description': tool.description, - 'input_schema': tool.parameters - } - - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters} + + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -196,22 +219,30 @@ def _transform_chat_json_prompts(self, model: str, credentials: dict, if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ANTHROPIC_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." 
+ ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -227,9 +258,9 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr tokens = client.count_tokens(prompt) tool_call_inner_prompts_tokens_map = { - 'claude-3-opus-20240229': 395, - 'claude-3-haiku-20240307': 264, - 'claude-3-sonnet-20240229': 159 + "claude-3-opus-20240229": 395, + "claude-3-haiku-20240307": 264, + "claude-3-sonnet-20240229": 159, } if model in tool_call_inner_prompts_tokens_map and tools: @@ -256,13 +287,18 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "temperature": 0, "max_tokens": 20, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: Union[Message, ToolsBetaMessage], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: Union[Message, ToolsBetaMessage], + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm chat response @@ -273,22 +309,18 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content='', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[]) for content in response.content: - if content.type == 'text': + if content.type == "text": assistant_prompt_message.content += content.text - elif content.type == 'tool_use': + elif content.type == "tool_use": tool_call = AssistantPromptMessage.ToolCall( id=content.id, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content.name, - arguments=json.dumps(content.input) - ) + name=content.name, arguments=json.dumps(content.input) + ), ) assistant_prompt_message.tool_calls.append(tool_call) @@ -307,17 +339,14 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm chat stream response @@ -326,7 +355,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -337,24 +366,23 @@ def 
_handle_chat_generate_stream_response(self, model: str, credentials: dict, for chunk in response: if isinstance(chunk, MessageStartEvent): - if hasattr(chunk, 'content_block'): + if hasattr(chunk, "content_block"): content_block = chunk.content_block if isinstance(content_block, dict): - if content_block.get('type') == 'tool_use': + if content_block.get("type") == "tool_use": tool_call = AssistantPromptMessage.ToolCall( - id=content_block.get('id'), - type='function', + id=content_block.get("id"), + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=content_block.get('name'), - arguments='' - ) + name=content_block.get("name"), arguments="" + ), ) tool_calls.append(tool_call) - elif hasattr(chunk, 'delta'): + elif hasattr(chunk, "delta"): delta = chunk.delta if isinstance(delta, dict) and len(tool_calls) > 0: - if delta.get('type') == 'input_json_delta': - tool_calls[-1].function.arguments += delta.get('partial_json', '') + if delta.get("type") == "input_json_delta": + tool_calls[-1].function.arguments += delta.get("partial_json", "") elif chunk.message: return_model = chunk.message.model input_tokens = chunk.message.usage.input_tokens @@ -368,29 +396,24 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, # transform empty tool call arguments to {} for tool_call in tool_calls: if not tool_call.function.arguments: - tool_call.function.arguments = '{}' + tool_call.function.arguments = "{}" yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text or "" full_assistant_content += chunk_text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=chunk_text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk_text) index = chunk.index @@ -400,7 +423,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, delta=LLMResultChunkDelta( index=chunk.index, message=assistant_prompt_message, - ) + ), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -411,14 +434,14 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['anthropic_api_key'], + "api_key": credentials["anthropic_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } - if credentials.get('anthropic_api_url'): - credentials['anthropic_api_url'] = credentials['anthropic_api_url'].rstrip('/') - credentials_kwargs['base_url'] = credentials['anthropic_api_url'] + if credentials.get("anthropic_api_url"): + credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/") + credentials_kwargs["base_url"] = credentials["anthropic_api_url"] return credentials_kwargs @@ -451,10 +474,7 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": 
message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -462,26 +482,27 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) - base64_data = base64.b64encode(image_content).decode('utf-8') + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: - raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") + raise ValueError( + f"Failed to fetch image data from url {message_content.data}, {ex}" + ) else: data_split = message_content.data.split(";base64,") mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) prompt_message_dicts.append({"role": "user", "content": sub_messages}) @@ -490,34 +511,28 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl content = [] if message.tool_calls: for tool_call in message.tool_calls: - content.append({ - "type": "tool_use", - "id": tool_call.id, - "name": tool_call.function.name, - "input": json.loads(tool_call.function.arguments) - }) + content.append( + { + "type": "tool_use", + "id": tool_call.id, + "name": tool_call.function.name, + "input": json.loads(tool_call.function.arguments), + } + ) if message.content: - content.append({ - "type": "text", - "text": message.content - }) - + content.append({"type": "text", "text": message.content}) + if prompt_message_dicts[-1]["role"] == "assistant": prompt_message_dicts[-1]["content"].extend(content) else: - prompt_message_dicts.append({ - "role": "assistant", - "content": content - }) + prompt_message_dicts.append({"role": "assistant", "content": content}) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [ + {"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content} + ], } prompt_message_dicts.append(message_dict) else: @@ -574,16 +589,13 @@ def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) - :return: Combined string with necessary human_prompt and ai_prompt tags. 
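Editor's note: the hunk above replaces URL-based `mimetypes.guess_type` with decoding the fetched bytes via Pillow, so the media type reflects the actual image rather than the URL suffix. The detection step in isolation:

```python
import base64
import io

from PIL import Image


def image_source_from_bytes(image_bytes: bytes) -> dict:
    """Derive the media type from the decoded image, as in the
    updated _convert_prompt_messages, then build the base64 source."""
    with Image.open(io.BytesIO(image_bytes)) as img:
        mime_type = f"image/{img.format.lower()}"
    if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
        raise ValueError(
            f"Unsupported image type {mime_type}, "
            f"only support image/jpeg, image/png, image/gif, and image/webp"
        )
    return {
        "type": "base64",
        "media_type": mime_type,
        "data": base64.b64encode(image_bytes).decode("utf-8"),
    }
```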
""" if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -599,24 +611,14 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - anthropic.APIConnectionError, - anthropic.APITimeoutError - ], - InvokeServerUnavailableError: [ - anthropic.InternalServerError - ], - InvokeRateLimitError: [ - anthropic.RateLimitError - ], - InvokeAuthorizationError: [ - anthropic.AuthenticationError, - anthropic.PermissionDeniedError - ], + InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError], + InvokeServerUnavailableError: [anthropic.InternalServerError], + InvokeRateLimitError: [anthropic.RateLimitError], + InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError], InvokeBadRequestError: [ anthropic.BadRequestError, anthropic.NotFoundError, anthropic.UnprocessableEntityError, - anthropic.APIError - ] + anthropic.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py rename to api/core/model_runtime/model_providers/azure_ai_studio/__init__.py diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png new file mode 100644 index 00000000000000..4b941654a78c15 Binary files /dev/null and b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png new file mode 100644 index 00000000000000..ca3043dc8dcb19 Binary files /dev/null and b/api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py new file mode 100644 index 00000000000000..75d21d1ce9875e --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py @@ -0,0 +1,17 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class AzureAIStudioProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
+ """ + pass diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml new file mode 100644 index 00000000000000..9e17ba088480db --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml @@ -0,0 +1,65 @@ +provider: azure_ai_studio +label: + zh_Hans: Azure AI Studio + en_US: Azure AI Studio +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +description: + en_US: Azure AI Studio + zh_Hans: Azure AI Studio +background: "#93c5fd" +help: + title: + en_US: How to deploy customized model on Azure AI Studio + zh_Hans: 如何在Azure AI Studio上的私有化部署的模型 + url: + en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models + zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models +supported_model_types: + - llm + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: endpoint + label: + en_US: Azure AI Studio Endpoint + type: text-input + required: true + placeholder: + zh_Hans: 请输入你的Azure AI Studio推理端点 + en_US: 'Enter your API Endpoint, eg: https://example.com' + - variable: api_key + required: true + label: + en_US: API Key + zh_Hans: API Key + type: secret-input + placeholder: + en_US: Enter your Azure AI Studio API Key + zh_Hans: 在此输入您的 Azure AI Studio API Key + show_on: + - variable: __model_type + value: llm + - variable: jwt_token + required: true + label: + en_US: JWT Token + zh_Hans: JWT令牌 + type: secret-input + placeholder: + en_US: Enter your Azure AI Studio JWT Token + zh_Hans: 在此输入您的 Azure AI Studio 推理 API Key + show_on: + - variable: __model_type + value: rerank diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py rename to api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py new file mode 100644 index 00000000000000..53030bad8340b4 --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py @@ -0,0 +1,334 @@ +import logging +from collections.abc import Generator +from typing import Any, Optional, Union + +from azure.ai.inference import ChatCompletionsClient +from azure.ai.inference.models import StreamingChatCompletionsUpdate +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ( + ClientAuthenticationError, + DecodeError, + DeserializationError, + HttpResponseError, + ResourceExistsError, + ResourceModifiedError, + ResourceNotFoundError, + ResourceNotModifiedError, + SerializationError, + ServiceRequestError, + ServiceResponseError, +) + +from core.model_runtime.callbacks.base_callback import Callback +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + 
I18nObject, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + + +class AzureAIStudioLargeLanguageModel(LargeLanguageModel): + """ + Model class for Azure AI Studio large language model. + """ + + client: Any = None + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + + if not self.client: + endpoint = credentials.get("endpoint") + api_key = credentials.get("api_key") + self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key)) + + messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages] + + payload = { + "messages": messages, + "max_tokens": model_parameters.get("max_tokens", 4096), + "temperature": model_parameters.get("temperature", 0), + "top_p": model_parameters.get("top_p", 1), + "stream": stream, + } + + if stop: + payload["stop"] = stop + + if tools: + payload["tools"] = [tool.model_dump() for tool in tools] + + try: + response = self.client.complete(**payload) + + if stream: + return self._handle_stream_response(response, model, prompt_messages) + else: + return self._handle_non_stream_response(response, model, prompt_messages, credentials) + except Exception as e: + raise self._transform_invoke_error(e) + + def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator: + for chunk in response: + if isinstance(chunk, StreamingChatCompletionsUpdate): + if chunk.choices: + delta = chunk.choices[0].delta + if delta.content: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=delta.content, tool_calls=[]), + ), + ) + + def _handle_non_stream_response( + self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict + ) -> LLMResult: + assistant_text = response.choices[0].message.content + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) + usage = self._calc_response_usage( + model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens + ) + result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage) + + if hasattr(response, "system_fingerprint"): + result.system_fingerprint = response.system_fingerprint + + return result + + def _invoke_result_generator( + self, + model: str, + result: Generator, + credentials: dict, + prompt_messages: list[PromptMessage], +
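# (reviewer note) mirrors _invoke's signature so the callback hooks below receive the full call context +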
model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Generator: + """ + Invoke result generator + + :param result: result generator + :return: result generator + """ + callbacks = callbacks or [] + prompt_message = AssistantPromptMessage(content="") + usage = None + system_fingerprint = None + real_model = model + + try: + for chunk in result: + if isinstance(chunk, dict): + content = chunk["choices"][0]["message"]["content"] + usage = chunk["usage"] + chunk = LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content, tool_calls=[]), + ), + system_fingerprint=chunk.get("system_fingerprint"), + ) + + yield chunk + + self._trigger_new_chunk_callbacks( + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ) + + prompt_message.content += chunk.delta.message.content + real_model = chunk.model + if hasattr(chunk.delta, "usage"): + usage = chunk.delta.usage + + if chunk.system_fingerprint: + system_fingerprint = chunk.system_fingerprint + except Exception as e: + raise self._transform_invoke_error(e) + + self._trigger_after_invoke_callbacks( + model=model, + result=LLMResult( + model=real_model, + prompt_messages=prompt_messages, + message=prompt_message, + usage=usage or LLMUsage.empty_usage(), + system_fingerprint=system_fingerprint, + ), + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks, + ) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + # Implement token counting logic here + # Might need to use a tokenizer specific to the Azure AI Studio model + return 0 + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + endpoint = credentials.get("endpoint") + api_key = credentials.get("api_key") + client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key)) + client.get_model_info() + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
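+ + (Reviewer note, illustrative.) This mapping is consumed by the base class's error translation, which the "raise self._transform_invoke_error(e)" calls above rely on, so each azure.core.exceptions type listed below surfaces to callers as the matching unified InvokeError subclass.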
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + ServiceRequestError, + ], + InvokeServerUnavailableError: [ + ServiceResponseError, + ], + InvokeAuthorizationError: [ + ClientAuthenticationError, + ], + InvokeBadRequestError: [ + HttpResponseError, + DecodeError, + ResourceExistsError, + ResourceNotFoundError, + ResourceModifiedError, + ResourceNotModifiedError, + SerializationError, + DeserializationError, + ], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + """ + Used to define customizable model schema + """ + rules = [ + ParameterRule( + name="temperature", + type=ParameterType.FLOAT, + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), + ), + ParameterRule( + name="top_p", + type=ParameterType.FLOAT, + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), + ), + ParameterRule( + name="max_tokens", + type=ParameterType.INT, + use_template="max_tokens", + min=1, + default=512, + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), + ] + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + features=[], + model_properties={}, + parameter_rules=rules, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py rename to api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py diff --git a/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py new file mode 100644 index 00000000000000..84672520e07e57 --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py @@ -0,0 +1,164 @@ +import json +import logging +import os +import ssl +import urllib.request +from typing import Optional + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + +logger = logging.getLogger(__name__) + + +class AzureRerankModel(RerankModel): + """ + Model class for Azure AI Studio rerank model. 
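+ + (Reviewer note, illustrative.) The deployed scoring endpoint is assumed to accept a JSON body of the form {"inputs": <query>, "docs": [<doc>, ...]} with a Bearer JWT, and to return a list of {"score": <float>} objects, one per doc and in input order -- see _azure_rerank and _invoke below.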
+ """ + + def _allow_self_signed_https(self, allowed): + # bypass the server certificate verification on client side + if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None): + ssl._create_default_https_context = ssl._create_unverified_context + + def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str): + # self._allow_self_signed_https(True) # Enable if using self-signed certificate + + data = {"inputs": query_input, "docs": docs} + + body = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + + req = urllib.request.Request(endpoint, body, headers) + + try: + with urllib.request.urlopen(req) as response: + result = response.read() + return json.loads(result) + except urllib.error.HTTPError as error: + logger.exception(f"The request failed with status code: {error.code}") + logger.exception(error.info()) + logger.exception(error.read().decode("utf8", "ignore")) + raise + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + try: + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + endpoint = credentials.get("endpoint") + api_key = credentials.get("jwt_token") + + if not endpoint or not api_key: + raise ValueError("Azure endpoint and API key must be provided in credentials") + + result = self._azure_rerank(query, docs, endpoint, api_key) + logger.info(f"Azure rerank result: {result}") + + rerank_documents = [] + for idx, (doc, score_dict) in enumerate(zip(docs, result)): + score = score_dict["score"] + rerank_document = RerankDocument(index=idx, text=doc, score=score) + + if score_threshold is None or score >= score_threshold: + rerank_documents.append(rerank_document) + + rerank_documents.sort(key=lambda x: x.score, reverse=True) + + if top_n: + rerank_documents = rerank_documents[:top_n] + + return RerankResult(model=model, docs=rerank_documents) + + except Exception as e: + logger.exception(f"Exception in Azure rerank: {e}") + raise + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [urllib.error.URLError], + InvokeServerUnavailableError: [urllib.error.HTTPError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index 31c788d226db34..32a0269af49314 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -15,10 +15,10 @@ class _CommonAzureOpenAI: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: - api_version = credentials.get('openai_api_version', AZURE_OPENAI_API_VERSION) + api_version = credentials.get("openai_api_version", AZURE_OPENAI_API_VERSION) credentials_kwargs = { - "api_key": credentials['openai_api_key'], - "azure_endpoint": credentials['openai_api_base'], + "api_key": credentials["openai_api_key"], + "azure_endpoint": credentials["openai_api_base"], "api_version": api_version, "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, @@ -29,24 +29,14 @@ def _to_credential_kwargs(credentials: dict) -> dict: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - openai.APIConnectionError, - openai.APITimeoutError - ], - InvokeServerUnavailableError: [ - openai.InternalServerError - ], - InvokeRateLimitError: [ - openai.RateLimitError - ], - InvokeAuthorizationError: [ - openai.AuthenticationError, - openai.PermissionDeniedError - ], + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], + InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], InvokeBadRequestError: [ openai.BadRequestError, openai.NotFoundError, openai.UnprocessableEntityError, - openai.APIError - ] + openai.APIError, + ], } diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 984cca3744dbff..e61a9e0474b101 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -14,11 +14,32 @@ PriceConfig, ) -AZURE_OPENAI_API_VERSION = '2024-02-15-preview' +AZURE_OPENAI_API_VERSION = "2024-02-15-preview" + +AZURE_DEFAULT_PARAM_SEED_HELP = I18nObject( + 
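# shared "seed" help text, hoisted out of the duplicated per-model rules below +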
zh_Hans="如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性," + "您应该参考 system_fingerprint 响应参数来监视变化。", + en_US="If specified, model will make a best effort to sample deterministically," + " such that repeated requests with the same seed and parameters should return the same result." + " Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter" + " to monitor changes in the backend.", +) + def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: rule = ParameterRule( - name='max_tokens', + name="max_tokens", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS], + ) + rule.default = default + rule.min = min_val + rule.max = max_val + return rule + + +def _get_o1_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: + rule = ParameterRule( + name="max_completion_tokens", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS], ) rule.default = default @@ -34,11 +55,11 @@ class AzureBaseModel(BaseModel): LLM_BASE_MODELS = [ AzureBaseModel( - base_model_name='gpt-35-turbo', + base_model_name="gpt-35-turbo", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -53,51 +74,47 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-16k', + base_model_name="gpt-35-turbo-16k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -112,37 +129,37 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=16385) + 
_get_max_tokens(default=512, min_val=1, max_val=16385), ], pricing=PriceConfig( input=0.003, output=0.004, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-0125', + base_model_name="gpt-35-turbo-0125", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -157,51 +174,47 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.0005, output=0.0015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4', + base_model_name="gpt-4", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -216,67 +229,57 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=8192), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.03, output=0.06, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-32k', + base_model_name="gpt-4-32k", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -291,67 +294,57 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=32768), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.06, output=0.12, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-0125-preview', + base_model_name="gpt-4-0125-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -366,67 +359,57 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-1106-preview', + base_model_name="gpt-4-1106-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -441,67 +424,57 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini', + base_model_name="gpt-4o-mini", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -517,67 +490,57 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-mini-2024-07-18', + base_model_name="gpt-4o-mini-2024-07-18", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -593,67 +556,67 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", + help=I18nObject( + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), - type='string', + required=False, + options=["text", "json_object", "json_schema"], + ), + ParameterRule( + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", ), required=False, - options=['text', 'json_object'] ), ], pricing=PriceConfig( input=0.150, output=0.600, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o', + base_model_name="gpt-4o", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -669,67 +632,57 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4o-2024-05-13', + base_model_name="gpt-4o-2024-05-13", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -745,67 +698,57 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=5.00, output=15.00, unit=0.000001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo', + base_model_name="gpt-4o-2024-08-06", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -821,67 +764,133 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, + required=False, + precision=2, + min=0, + max=1, + ), + ParameterRule( + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", + help=I18nObject( + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), - type='int', + required=False, + options=["text", "json_object", "json_schema"], + ), + ParameterRule( + name="json_schema", + label=I18nObject(en_US="JSON Schema"), + type="text", help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
+ zh_Hans="设置返回的json schema,llm将按照它返回", + en_US="Set a response json schema will ensure LLM to adhere it.", ), required=False, + ), + ], + pricing=PriceConfig( + input=5.00, + output=15.00, + unit=0.000001, + currency="USD", + ), + ), + ), + AzureBaseModel( + base_model_name="gpt-4-turbo", + entity=AIModelEntity( + model="fake-deployment-name", + label=I18nObject( + en_US="fake-deployment-name-label", + ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ModelFeature.VISION, + ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], + ), + ParameterRule( + name="top_p", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], + ), + ParameterRule( + name="presence_penalty", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], + ), + ParameterRule( + name="frequency_penalty", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], + ), + _get_max_tokens(default=512, min_val=1, max_val=4096), + ParameterRule( + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, + required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-turbo-2024-04-09', + base_model_name="gpt-4-turbo-2024-04-09", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, features=[ @@ -897,72 +906,60 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-4-vision-preview', + base_model_name="gpt-4-vision-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, - features=[ - ModelFeature.VISION - ], + features=[ModelFeature.VISION], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: LLMMode.CHAT.value, @@ -970,67 +967,57 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), ParameterRule( - name='seed', - label=I18nObject( - zh_Hans='种子', - en_US='Seed' - ), - type='int', - help=I18nObject( - zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。', - en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.' 
- ), + name="seed", + label=I18nObject(zh_Hans="种子", en_US="Seed"), + type="int", + help=AZURE_DEFAULT_PARAM_SEED_HELP, required=False, precision=2, min=0, max=1, ), ParameterRule( - name='response_format', - label=I18nObject( - zh_Hans='回复格式', - en_US='response_format' - ), - type='string', + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", help=I18nObject( - zh_Hans='指定模型必须输出的格式', - en_US='specifying the format that the model must output' + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" ), required=False, - options=['text', 'json_object'] + options=["text", "json_object"], ), ], pricing=PriceConfig( input=0.01, output=0.03, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='gpt-35-turbo-instruct', + base_model_name="gpt-35-turbo-instruct", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1040,19 +1027,19 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1061,16 +1048,16 @@ class AzureBaseModel(BaseModel): input=0.0015, output=0.002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-davinci-003', + base_model_name="text-davinci-003", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label', + en_US="fake-deployment-name-label", ), model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, @@ -1080,19 +1067,19 @@ class AzureBaseModel(BaseModel): }, parameter_rules=[ ParameterRule( - name='temperature', + name="temperature", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], ), ParameterRule( - name='top_p', + name="top_p", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], ), ParameterRule( - name='presence_penalty', + name="presence_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY], ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), _get_max_tokens(default=512, min_val=1, max_val=4096), @@ -1101,20 +1088,91 @@ class AzureBaseModel(BaseModel): input=0.02, output=0.02, unit=0.001, - currency='USD', - ) - ) - ) -] - -EMBEDDING_BASE_MODELS = [ + currency="USD", + ), + ), + ), AzureBaseModel( - base_model_name='text-embedding-ada-002', + base_model_name="o1-preview", entity=AIModelEntity( - model='fake-deployment-name', + model="fake-deployment-name", label=I18nObject( - en_US='fake-deployment-name-label' + en_US="fake-deployment-name-label", ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: 
LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", + help=I18nObject( + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" + ), + required=False, + options=["text", "json_object"], + ), + _get_o1_max_tokens(default=512, min_val=1, max_val=32768), + ], + pricing=PriceConfig( + input=15.00, + output=60.00, + unit=0.000001, + currency="USD", + ), + ), + ), + AzureBaseModel( + base_model_name="o1-mini", + entity=AIModelEntity( + model="fake-deployment-name", + label=I18nObject( + en_US="fake-deployment-name-label", + ), + model_type=ModelType.LLM, + features=[ + ModelFeature.AGENT_THOUGHT, + ], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.MODE: LLMMode.CHAT.value, + ModelPropertyKey.CONTEXT_SIZE: 128000, + }, + parameter_rules=[ + ParameterRule( + name="response_format", + label=I18nObject(zh_Hans="回复格式", en_US="response_format"), + type="string", + help=I18nObject( + zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output" + ), + required=False, + options=["text", "json_object"], + ), + _get_o1_max_tokens(default=512, min_val=1, max_val=65536), + ], + pricing=PriceConfig( + input=3.00, + output=12.00, + unit=0.000001, + currency="USD", + ), + ), + ), +] +EMBEDDING_BASE_MODELS = [ + AzureBaseModel( + base_model_name="text-embedding-ada-002", + entity=AIModelEntity( + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1124,17 +1182,15 @@ class AzureBaseModel(BaseModel): pricing=PriceConfig( input=0.0001, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-small', + base_model_name="text-embedding-3-small", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1144,17 +1200,15 @@ class AzureBaseModel(BaseModel): pricing=PriceConfig( input=0.00002, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='text-embedding-3-large', + base_model_name="text-embedding-3-large", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ @@ -1164,135 +1218,129 @@ class AzureBaseModel(BaseModel): pricing=PriceConfig( input=0.00013, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] SPEECH2TEXT_BASE_MODELS = [ AzureBaseModel( - base_model_name='whisper-1', + base_model_name="whisper-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={ ModelPropertyKey.FILE_UPLOAD_LIMIT: 25, - ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm' - } 
- ) + ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: "flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm", + }, + ), ) ] TTS_BASE_MODELS = [ AzureBaseModel( - base_model_name='tts-1', + base_model_name="tts-1", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.015, unit=0.001, - currency='USD', - ) - ) + currency="USD", + ), + ), ), AzureBaseModel( - base_model_name='tts-1-hd', + base_model_name="tts-1-hd", entity=AIModelEntity( - model='fake-deployment-name', - label=I18nObject( - en_US='fake-deployment-name-label' - ), + model="fake-deployment-name", + label=I18nObject(en_US="fake-deployment-name-label"), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={ - ModelPropertyKey.DEFAULT_VOICE: 'alloy', + ModelPropertyKey.DEFAULT_VOICE: "alloy", ModelPropertyKey.VOICES: [ { - 'mode': 'alloy', - 'name': 'Alloy', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "alloy", + "name": "Alloy", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'echo', - 'name': 'Echo', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "echo", + "name": "Echo", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'fable', - 'name': 'Fable', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 
'th-TH', 'id-ID'] + "mode": "fable", + "name": "Fable", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'onyx', - 'name': 'Onyx', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "onyx", + "name": "Onyx", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'nova', - 'name': 'Nova', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "nova", + "name": "Nova", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, { - 'mode': 'shimmer', - 'name': 'Shimmer', - 'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + "mode": "shimmer", + "name": "Shimmer", + "language": ["zh-Hans", "en-US", "de-DE", "fr-FR", "es-ES", "it-IT", "th-TH", "id-ID"], }, ], ModelPropertyKey.WORD_LIMIT: 120, - ModelPropertyKey.AUDIO_TYPE: 'mp3', - ModelPropertyKey.MAX_WORKERS: 5 + ModelPropertyKey.AUDIO_TYPE: "mp3", + ModelPropertyKey.MAX_WORKERS: 5, }, pricing=PriceConfig( input=0.03, unit=0.001, - currency='USD', - ) - ) - ) + currency="USD", + ), + ), + ), ] diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py index 68977b2266718d..2e3c6aab0588ec 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.py +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.py @@ -6,6 +6,5 @@ class AzureOpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index be4d4651d7b06f..1ef5e83abca6de 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -53,6 +53,18 @@ model_credential_schema: type: select required: true options: + - label: + en_US: 2024-10-01-preview + value: 2024-10-01-preview + - label: + en_US: 2024-09-01-preview + value: 2024-09-01-preview + - label: + en_US: 2024-08-01-preview + value: 2024-08-01-preview + - label: + en_US: 2024-07-01-preview + value: 2024-07-01-preview - label: en_US: 2024-05-01-preview value: 2024-05-01-preview @@ -114,6 +126,18 @@ model_credential_schema: show_on: - variable: __model_type value: llm + - label: + en_US: o1-mini + value: o1-mini + show_on: + - variable: __model_type + value: llm + - label: + en_US: o1-preview + value: o1-preview + show_on: + - variable: __model_type + value: llm - label: en_US: gpt-4o-mini value: gpt-4o-mini @@ -138,6 +162,12 @@ model_credential_schema: show_on: - variable: __model_type value: llm + - label: + en_US: gpt-4o-2024-08-06 + value: gpt-4o-2024-08-06 + show_on: + - variable: __model_type + value: llm - label: en_US: gpt-4-turbo value: gpt-4-turbo diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 1911caa952bbba..1cd4823e131401 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -1,4 +1,5 @@ import copy +import json import logging from collections.abc import Generator, Sequence from typing import Optional, Union, cast @@ -33,16 +34,18 @@ class 
AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - - base_model_name = credentials.get('base_model_name') - if not base_model_name: - raise ValueError('Base Model Name is required') + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: @@ -55,7 +58,7 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -66,7 +69,7 @@ def _invoke(self, model: str, credentials: dict, model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) def get_num_tokens( @@ -74,14 +77,12 @@ def get_num_tokens( model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ) -> int: - base_model_name = credentials.get('base_model_name') - if not base_model_name: - raise ValueError('Base Model Name is required') + base_model_name = self._get_base_model_name(credentials) model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not model_entity: - raise ValueError(f'Base Model Name {base_model_name} is invalid') + raise ValueError(f"Base Model Name {base_model_name} is invalid") model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE) if model_mode == LLMMode.CHAT.value: @@ -91,21 +92,19 @@ def get_num_tokens( # text completion model, do not support tool calling content = prompt_messages[0].content assert isinstance(content, str) - return self._num_tokens_from_string(credentials,content) + return self._num_tokens_from_string(credentials, content) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - base_model_name = credentials.get('base_model_name') - if not base_model_name: - raise CredentialsValidateFailedError('Base Model Name is required') + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) if not ai_model_entity: 
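For reference, a credentials dict that would satisfy the checks in `validate_credentials` above might look like the following. All values are placeholders; `openai_api_base`, `openai_api_key`, and `base_model_name` are the keys checked above, while the API-version key name is an assumption based on the version options added to the provider yaml earlier in this diff:

```python
credentials = {
    "openai_api_base": "https://my-resource.openai.azure.com",  # placeholder endpoint
    "openai_api_key": "<azure-api-key>",                        # placeholder secret
    "openai_api_version": "2024-10-01-preview",                 # assumed key name
    "base_model_name": "o1-mini",  # must resolve to a bundled base model entity
}
```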
@@ -114,10 +113,18 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: + if model.startswith("o1"): + client.chat.completions.create( + messages=[{"role": "user", "content": "ping"}], + model=model, + temperature=1, + max_completion_tokens=20, + stream=False, + ) + elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -126,7 +133,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -136,33 +143,33 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - base_model_name = credentials.get('base_model_name') - if not base_model_name: - raise ValueError('Base Model Name is required') + base_model_name = self._get_base_model_name(credentials) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) return ai_model_entity.entity if ai_model_entity else None - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -171,15 +178,12 @@ def _generate(self, model: str, credentials: dict, return self._handle_generate_response(model, credentials, response, prompt_messages) def _handle_generate_response( - self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] ): assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -208,24 +212,21 @@ def _handle_generate_response( return result def _handle_generate_stream_response( - self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage] + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: 
list[PromptMessage] ) -> Generator: - full_text = '' + full_text = "" for chunk in response: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text or "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -253,8 +254,8 @@ def _handle_generate_stream_response( index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( @@ -264,45 +265,66 @@ def _handle_generate_stream_response( delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: - + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) response_format = model_parameters.get("response_format") if response_format: - if response_format == "json_object": - response_format = {"type": "json_object"} + if response_format == "json_schema": + json_schema = model_parameters.get("json_schema") + if not json_schema: + raise ValueError("Must define JSON Schema when the response format is json_schema") + try: + schema = json.loads(json_schema) + except json.JSONDecodeError: + raise ValueError(f"Invalid json_schema format: {json_schema}") + model_parameters.pop("json_schema") + model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: - response_format = {"type": "text"} - - model_parameters["response_format"] = response_format + model_parameters["response_format"] = {"type": response_format} extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - # extra_model_kwargs['functions'] = [{ - # "name": tool.name, - # "description": tool.description, - # "parameters": tool.parameters - # } for tool in tools] + extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user + + # clear illegal prompt messages + prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) + + block_as_stream = False + if model.startswith("o1"): + if stream: + block_as_stream = True + stream = False + + if "stream_options" in extra_model_kwargs: + del extra_model_kwargs["stream_options"] + + if "stop" in extra_model_kwargs: + del extra_model_kwargs["stop"] # chat model - messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] response = client.chat.completions.create( - messages=messages, + 
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model, stream=stream, **model_parameters, @@ -312,12 +334,99 @@ def _chat_generate(self, model: str, credentials: dict, if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) - return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + + if block_as_stream: + return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop) + + return block_result + + def _handle_chat_block_as_stream_response( + self, + block_result: LLMResult, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[LLMResultChunk, None, None]: + """ + Convert a blocking chat response into a single-chunk stream. + + o1 models do not support streaming, + so the full blocking result is re-emitted as one chunk. + + :param block_result: blocking chat completion result + :param prompt_messages: prompt messages + :param stop: stop words + :return: llm response chunk generator + """ + text = block_result.message.content + text = cast(str, text) + + if stop: + text = self.enforce_stop_tokens(text, stop) + + yield LLMResultChunk( + model=block_result.model, + prompt_messages=prompt_messages, + system_fingerprint=block_result.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text), + finish_reason="stop", + usage=block_result.usage, + ), + ) + + def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: + """ + Clear illegal prompt messages for OpenAI API + + :param model: model name + :param prompt_messages: prompt messages + :return: cleaned prompt messages + """ + checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] + + if model in checklist: + # count how many user messages are there + user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)]) + if user_message_count > 1: + for prompt_message in prompt_messages: + if isinstance(prompt_message, UserPromptMessage): + if isinstance(prompt_message.content, list): + prompt_message.content = "\n".join( + [ + item.data + if item.type == PromptMessageContentType.TEXT + else "[IMAGE]" + if item.type == PromptMessageContentType.IMAGE + else "" + for item in prompt_message.content + ] + ) + + if model.startswith("o1"): + system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) + if system_message_count > 0: + new_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message, SystemPromptMessage): + prompt_message = UserPromptMessage( + content=prompt_message.content, + name=prompt_message.name, + ) + + new_prompt_messages.append(prompt_message) + prompt_messages = new_prompt_messages + + return prompt_messages def _handle_chat_generate_response( - self, model: str, credentials: dict, response: ChatCompletion, + self, + model: str, + credentials: dict, + response: ChatCompletion, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): assistant_message = response.choices[0].message assistant_message_tool_calls = assistant_message.tool_calls @@ -327,10 +436,7 @@ def _handle_chat_generate_response( self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls) # transform assistant 
message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -362,13 +468,13 @@ def _handle_chat_generate_stream_response( credentials: dict, response: Stream[ChatCompletionChunk], prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + tools: Optional[list[PromptMessageTool]] = None, ): index = 0 - full_assistant_content = '' + full_assistant_content = "" real_model = model system_fingerprint = None - completion = '' + completion = "" tool_calls = [] for chunk in response: if len(chunk.choices) == 0: @@ -379,7 +485,6 @@ def _handle_chat_generate_stream_response( if delta.delta is None: continue - # extract tool calls from response self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls) @@ -388,16 +493,13 @@ def _handle_chat_generate_stream_response( continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content or "" real_model = chunk.model system_fingerprint = chunk.system_fingerprint - completion += delta.delta.content if delta.delta.content else '' + completion += delta.delta.content or "" yield LLMResultChunk( model=real_model, @@ -406,17 +508,15 @@ def _handle_chat_generate_stream_response( delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) - index += 0 + index += 1 # calculate num tokens prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools) - full_assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + full_assistant_prompt_message = AssistantPromptMessage(content=completion) completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message]) # transform usage @@ -427,27 +527,24 @@ def _handle_chat_generate_stream_response( prompt_messages=prompt_messages, system_fingerprint=system_fingerprint, delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=''), - finish_reason='stop', - usage=usage - ) + index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage + ), ) @staticmethod - def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]]) -> None: + def _update_tool_calls( + tool_calls: list[AssistantPromptMessage.ToolCall], + tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]], + ) -> None: if tool_calls_response: for response_tool_call in tool_calls_response: if isinstance(response_tool_call, ChatCompletionMessageToolCall): function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + 
id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) elif isinstance(response_tool_call, ChoiceDeltaToolCall): @@ -456,8 +553,10 @@ def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_c tool_calls[index].id = response_tool_call.id or tool_calls[index].id tool_calls[index].type = response_tool_call.type or tool_calls[index].type if response_tool_call.function: - tool_calls[index].function.name = response_tool_call.function.name or tool_calls[index].function.name - tool_calls[index].function.arguments += response_tool_call.function.arguments or '' + tool_calls[index].function.name = ( + response_tool_call.function.name or tool_calls[index].function.name + ) + tool_calls[index].function.arguments += response_tool_call.function.arguments or "" else: assert response_tool_call.id is not None assert response_tool_call.type is not None @@ -466,13 +565,10 @@ def _update_tool_calls(tool_calls: list[AssistantPromptMessage.ToolCall], tool_c assert response_tool_call.function.arguments is not None function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) @@ -488,19 +584,13 @@ def _convert_prompt_message_to_dict(message: PromptMessage): for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -518,7 +608,7 @@ def _convert_prompt_message_to_dict(message: PromptMessage): "role": "tool", "name": message.name, "content": message.content, - "tool_call_id": message.tool_call_id + "tool_call_id": message.tool_call_id, } else: raise ValueError(f"Got unknown type {message}") @@ -528,10 +618,11 @@ def _convert_prompt_message_to_dict(message: PromptMessage): return message_dict - def _num_tokens_from_string(self, credentials: dict, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None + ) -> int: try: - encoding = tiktoken.encoding_for_model(credentials['base_model_name']) + encoding = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: encoding = tiktoken.get_encoding("cl100k_base") @@ -543,14 +634,13 @@ def _num_tokens_from_string(self, credentials: dict, text: str, return num_tokens def _num_tokens_from_messages( - self, credentials: dict, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None + self, credentials: dict, 
messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - model = credentials['base_model_name'] + model = credentials["base_model_name"] try: encoding = tiktoken.encoding_for_model(model) except KeyError: @@ -563,7 +653,7 @@ def _num_tokens_from_messages( tokens_per_message = 4 # if there's a name, the role is omitted tokens_per_name = -1 - elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4"): + elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"): tokens_per_message = 3 tokens_per_name = 1 else: @@ -584,10 +674,10 @@ def _num_tokens_from_messages( # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -619,40 +709,39 @@ def _num_tokens_from_messages( @staticmethod def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int: - num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) - num_tokens += len(encoding.encode(parameters['title'])) - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(parameters['type'])) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters['properties'].items(): + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) + num_tokens += len(encoding.encode(parameters["title"])) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode(parameters["type"])) + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters["properties"].items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) @@ -667,3 +756,9 @@ def 
_get_ai_model_entity(base_model_name: str, model: str): ai_model_entity_copy.entity.label.en_US = model ai_model_entity_copy.entity.label.zh_Hans = model return ai_model_entity_copy + + def _get_base_model_name(self, credentials: dict) -> str: + base_model_name = credentials.get("base_model_name") + if not base_model_name: + raise ValueError("Base Model Name is required") + return base_model_name diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py index 8aebcb90e40b6a..a2b14cf3dbe6d4 100644 --- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -15,9 +15,7 @@ class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -40,7 +38,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -65,10 +63,9 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> return response.text def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: for ai_model_entity in SPEECH2TEXT_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index e073bef0149486..c45ce87ea76838 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -7,6 +7,7 @@ import tiktoken from openai import AzureOpenAI +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import AIModelEntity, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -16,19 +17,33 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: - base_model_name = credentials['base_model_name'] + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: 
input type + :return: embeddings result + """ + base_model_name = credentials["base_model_name"] credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" context_size = self._get_context_size(model, credentials) max_chunks = self._get_max_chunks(model, credentials) @@ -44,11 +59,9 @@ def _invoke(self, model: str, credentials: dict, enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -56,10 +69,7 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -75,10 +85,7 @@ def _invoke(self, model: str, credentials: dict, _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,24 +95,16 @@ def _invoke(self, model: str, credentials: dict, embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=base_model_name - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=base_model_name) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: if len(texts) == 0: return 0 try: - enc = tiktoken.encoding_for_model(credentials['base_model_name']) + enc = tiktoken.encoding_for_model(credentials["base_model_name"]) except KeyError: enc = tiktoken.get_encoding("cl100k_base") @@ -118,57 +117,52 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return total_num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'openai_api_base' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + if "openai_api_base" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required") - if 'openai_api_key' not in credentials: - raise CredentialsValidateFailedError('Azure OpenAI API key is required') + if "openai_api_key" not in credentials: + raise CredentialsValidateFailedError("Azure OpenAI API key is required") - if 'base_model_name' not in credentials: - raise CredentialsValidateFailedError('Base Model Name is required') + if "base_model_name" not in credentials: + raise CredentialsValidateFailedError("Base Model Name is required") - if not self._get_ai_model_entity(credentials['base_model_name'], model): + if not 
self._get_ai_model_entity(credentials["base_model_name"], model): raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') try: credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity @staticmethod - def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + model: str, client: AzureOpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +173,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 3d2bac1c310277..133cc9f76e0720 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -1,6 +1,6 @@ import concurrent.futures import copy -from typing import Optional +from typing import Any, Optional from openai import AzureOpenAI @@ -17,8 +17,9 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): Model class for OpenAI Speech to text model. 
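Stepping back to `_embedding_invoke` above: with `encoding_format="base64"`, the API returns each embedding as a base64 string of packed float32 values, which is what the `np.frombuffer` branch decodes. A self-contained round-trip sketch with made-up data:

```python
import base64

import numpy as np

vec = np.array([0.1, -0.2, 0.3], dtype="float32")
encoded = base64.b64encode(vec.tobytes()).decode()  # roughly what the API would return
decoded = list(np.frombuffer(base64.b64decode(encoded), dtype="float32"))
assert np.allclose(decoded, vec)
```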
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> Any: """ _invoke text2speech model @@ -30,13 +31,12 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -50,14 +50,13 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: """ _tts_invoke_streaming text2speech model :param model: model name @@ -70,28 +69,34 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: st # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - # max font is 4096,there is 3500 limit for each request + # max length is 4096 characters, there is 3500 limit for each request max_length = 3500 if len(content_text) > max_length: sentences = self._split_text_into_sentences(content_text, max_length=max_length) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] - for index, future in enumerate(futures): - yield from future.result().__enter__().iter_bytes(1024) + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] + for future in futures: + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ 
_tts_invoke openai text2speech model api @@ -108,10 +113,9 @@ def _process_sentence(self, sentence: str, model: str, return response.read() def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + ai_model_entity = self._get_ai_model_entity(credentials["base_model_name"], model) return ai_model_entity.entity - @staticmethod def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None: for ai_model_entity in TTS_BASE_MODELS: diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.py b/api/core/model_runtime/model_providers/baichuan/baichuan.py index 71bd6b5d923ed1..626fc811cfd47b 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.py +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class BaichuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `baichuan2-turbo` model for validate, - model_instance.validate_credentials( - model='baichuan2-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="baichuan2-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.yaml b/api/core/model_runtime/model_providers/baichuan/baichuan.yaml index 792126af7fd58f..81e6e36215aa84 100644 --- a/api/core/model_runtime/model_providers/baichuan/baichuan.yaml +++ b/api/core/model_runtime/model_providers/baichuan/baichuan.yaml @@ -27,11 +27,3 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key - - variable: secret_key - label: - en_US: Secret Key - type: secret-input - required: false - placeholder: - zh_Hans: 在此输入您的 Secret Key - en_US: Enter your Secret Key diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml index 04849500dcb7f1..8360dd5faffb00 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml @@ -43,3 +43,4 @@ parameter_rules: zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 en_US: Allow the model to perform external search to enhance the generation results. required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml index c8156c152b15bd..0ce0265cfe5c6c 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo-192k.yaml @@ -43,3 +43,4 @@ parameter_rules: zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。 en_US: Allow the model to perform external search to enhance the generation results. 
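On the `secret_key` removal above: the current Baichuan API authenticates with a single Bearer token, which is why the provider schema drops the field and the refactored client later in this diff builds its headers as simply:

```python
api_key = "<baichuan-api-key>"  # placeholder
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer " + api_key,  # no secret_key or MD5 request signing anymore
}
```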
required: false +deprecated: true diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml index f91329c77aa9ec..ccb4ee8b92bc16 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml @@ -4,36 +4,32 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 192000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml index bf72e8229671f6..d9cd086e82c994 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo-128k.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 128000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 128000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml index 85882519b86741..58f9b39a438a08 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan3-turbo.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. 
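The `res_format` parameter being added to these Baichuan yamls exists because the base `LargeLanguageModel._code_block_mode_wrapper()` strips a parameter literally named `response_format` (see the comment in `_build_parameters` further down in this diff). A sketch of the renaming it triggers:

```python
parameters = {"temperature": 0.3, "res_format": "json_object"}  # as declared in the yaml
if parameters.get("res_format") == "json_object":
    # mirrors _build_parameters below: restore the name the API expects
    parameters["response_format"] = {"type": "json_object"}
```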
required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 32000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml index f8c65660818818..6a1135e165fcaf 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan4.yaml @@ -4,36 +4,44 @@ label: model_type: llm features: - agent-thought + - multi-tool-call model_properties: mode: chat context_size: 32000 parameter_rules: - name: temperature use_template: temperature + default: 0.3 - name: top_p use_template: top_p + default: 0.85 - name: top_k label: zh_Hans: 取样数量 en_US: Top k type: int + min: 0 + max: 20 + default: 5 help: zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 en_US: Only sample from the top K options for each subsequent token. required: false - name: max_tokens use_template: max_tokens - required: true - default: 8000 - min: 1 - max: 32000 - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - default: 1 - min: 1 - max: 2 + default: 2048 + - name: res_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object - name: with_search_enhance label: zh_Hans: 搜索增强 diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 7549b2fb60f71c..a7ca28d49d636e 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -4,17 +4,18 @@ class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: - return len(re.findall(r'[\u4e00-\u9fa5]', text)) + return len(re.findall(r"[\u4e00-\u9fa5]", text)) @classmethod def count_english_vocabularies(cls, text: str) -> int: # remove all non-alphanumeric characters but keep spaces and other symbols like !, ., etc. 
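A worked example of the `BaichuanTokenizer` estimate (the Chinese-character regex is above; the English-word regex and the 1.3 words-to-tokens factor follow just below):

```python
import re

text = "Dify 支持百川模型"
chinese = len(re.findall(r"[\u4e00-\u9fa5]", text))         # 6 Chinese characters
english = len(re.sub(r"[^a-zA-Z0-9\s]", "", text).split())  # 1 English word ("Dify")
print(int(chinese + english * 1.3))                          # -> 7 estimated tokens
```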
- text = re.sub(r'[^a-zA-Z0-9\s]', '', text) + text = re.sub(r"[^a-zA-Z0-9\s]", "", text) # count the number of words not characters return len(text.split()) - + @classmethod def _get_num_tokens(cls, text: str) -> int: - # tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return) + # tokens = number of Chinese characters + number of English words * 1.3 + # (for estimation only, subject to actual return) # https://platform.baichuan-ai.com/docs/text-Embedding - return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) \ No newline at end of file + return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index d7d8b7c91b6e2d..d5fda73009bba9 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -1,14 +1,13 @@ -from collections.abc import Generator -from enum import Enum -from hashlib import md5 -from json import dumps, loads -from typing import Any, Union +import json +from collections.abc import Iterator +from typing import Any, Optional, Union from requests import post +from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -16,203 +15,130 @@ ) -class BaichuanMessage: - class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - # Baichuan does not have system message - _SYSTEM = 'system' - - role: str = Role.USER.value - content: str - usage: dict[str, int] = None - stop_reason: str = '' - - def to_dict(self) -> dict[str, Any]: - return { - 'role': self.role, - 'content': self.content, - } - - def __init__(self, content: str, role: str = 'user') -> None: - self.content = content - self.role = role - class BaichuanModel: api_key: str - secret_key: str - def __init__(self, api_key: str, secret_key: str = '') -> None: + def __init__(self, api_key: str) -> None: self.api_key = api_key - self.secret_key = secret_key - def _model_mapping(self, model: str) -> str: + @property + def _model_mapping(self) -> dict: return { - 'baichuan2-turbo': 'Baichuan2-Turbo', - 'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k', - 'baichuan2-53b': 'Baichuan2-53B', - 'baichuan3-turbo': 'Baichuan3-Turbo', - 'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k', - 'baichuan4': 'Baichuan4', - }[model] - - def _handle_chat_generate_response(self, response) -> BaichuanMessage: - resp = response.json() - choices = resp.get('choices', []) - message = BaichuanMessage(content='', role='assistant') - for choice in choices: - message.content += choice['message']['content'] - message.role = choice['message']['role'] - if choice['finish_reason']: - message.stop_reason = choice['finish_reason'] - - if 'usage' in resp: - message.usage = { - 'prompt_tokens': resp['usage']['prompt_tokens'], - 'completion_tokens': resp['usage']['completion_tokens'], - 'total_tokens': resp['usage']['total_tokens'], - } + "baichuan2-turbo": "Baichuan2-Turbo", + "baichuan3-turbo": "Baichuan3-Turbo", + "baichuan3-turbo-128k": "Baichuan3-Turbo-128k", + "baichuan4": "Baichuan4", + } - return message - - def 
_handle_chat_stream_generate_response(self, response) -> Generator: - for line in response.iter_lines(): - if not line: - continue - line = line.decode('utf-8') - # remove the first `data: ` prefix - if line.startswith('data:'): - line = line[5:].strip() - try: - data = loads(line) - except Exception as e: - if line.strip() == '[DONE]': - return - choices = data.get('choices', []) - # save stop reason temporarily - stop_reason = '' - for choice in choices: - if choice.get('finish_reason'): - stop_reason = choice['finish_reason'] - - if len(choice['delta']['content']) == 0: - continue - yield BaichuanMessage(**choice['delta']) - - # if there is usage, the response is the last one, yield it and return - if 'usage' in data: - message = BaichuanMessage(content='', role='assistant') - message.usage = { - 'prompt_tokens': data['usage']['prompt_tokens'], - 'completion_tokens': data['usage']['completion_tokens'], - 'total_tokens': data['usage']['total_tokens'], - } - message.stop_reason = stop_reason - yield message - - def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage], - parameters: dict[str, Any]) \ - -> dict[str, Any]: - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - prompt_messages = [] - for message in messages: - if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value: - # check if the latest message is a user message - if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value: - prompt_messages[-1]['content'] += message.content - else: - prompt_messages.append({ - 'content': message.content, - 'role': BaichuanMessage.Role.USER.value, - }) - elif message.role == BaichuanMessage.Role.ASSISTANT.value: - prompt_messages.append({ - 'content': message.content, - 'role': message.role, - }) - # [baichuan] frequency_penalty must be between 1 and 2 - if 'frequency_penalty' in parameters: - if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2: - parameters['frequency_penalty'] = 1 + @property + def request_headers(self) -> dict[str, Any]: + return { + "Content-Type": "application/json", + "Authorization": "Bearer " + self.api_key, + } + + def _build_parameters( + self, + model: str, + stream: bool, + messages: list[dict], + parameters: dict[str, Any], + tools: Optional[list[PromptMessageTool]] = None, + ) -> dict[str, Any]: + if model in self._model_mapping: + # the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters. 
+ # we need to rename it to res_format to get its value + if parameters.get("res_format") == "json_object": + parameters["response_format"] = {"type": "json_object"} + + if tools or parameters.get("with_search_enhance") is True: + parameters["tools"] = [] + + # with_search_enhance is deprecated, use web_search instead + if parameters.get("with_search_enhance") is True: + parameters["tools"].append( + { + "type": "web_search", + "web_search": {"enable": True}, + } + ) + if tools: + for tool in tools: + parameters["tools"].append( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } + ) # turbo api accepts flat parameters return { - 'model': self._model_mapping(model), - 'stream': stream, - 'messages': prompt_messages, + "model": self._model_mapping.get(model), + "stream": stream, + "messages": messages, **parameters, } else: raise BadRequestError(f"Unknown model: {model}") - - def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]: - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - # there is no secret key for turbo api - return { - 'Content-Type': 'application/json', - 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ', - 'Authorization': 'Bearer ' + self.api_key, - } - else: - raise BadRequestError(f"Unknown model: {model}") - - def _calculate_md5(self, input_string): - return md5(input_string.encode('utf-8')).hexdigest() - - def generate(self, model: str, stream: bool, messages: list[BaichuanMessage], - parameters: dict[str, Any], timeout: int) \ - -> Union[Generator, BaichuanMessage]: - - if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b' - or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'): - api_base = 'https://api.baichuan-ai.com/v1/chat/completions' + + def generate( + self, + model: str, + stream: bool, + messages: list[dict], + parameters: dict[str, Any], + timeout: int, + tools: Optional[list[PromptMessageTool]] = None, + ) -> Union[Iterator, dict]: + if model in self._model_mapping: + api_base = "https://api.baichuan-ai.com/v1/chat/completions" else: raise BadRequestError(f"Unknown model: {model}") - - try: - data = self._build_parameters(model, stream, messages, parameters) - headers = self._build_headers(model, data) - except KeyError: - raise InternalServerError(f"Failed to build parameters for model: {model}") + + data = self._build_parameters(model, stream, messages, parameters, tools) try: response = post( url=api_base, - headers=headers, - data=dumps(data), + headers=self.request_headers, + data=json.dumps(data), timeout=timeout, - stream=stream + stream=stream, ) except Exception as e: raise InternalServerError(f"Failed to invoke model: {e}") - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["type"] + msg = resp["error"]["message"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': - raise InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': + elif err == "insufficient_quota": + raise 
InsufficientAccountBalanceError(msg) + elif err == "invalid_authentication": raise InvalidAuthenticationError(msg) - elif 'rate' in err: + elif err == "invalid_request_error": + raise BadRequestError(msg) + elif "rate" in err: raise RateLimitReachedError(msg) - elif 'internal' in err: + elif "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + if stream: - return self._handle_chat_stream_generate_response(response) + return response.iter_lines() else: - return self._handle_chat_generate_response(response) \ No newline at end of file + return response.json() diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py index 67d76b4a291c06..309b5cf413bd54 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass -class InsufficientAccountBalance(Exception): + +class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index edcd3af4203cfb..91a14bf1009006 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -1,7 +1,12 @@ -from collections.abc import Generator +import json +from collections.abc import Generator, Iterator from typing import cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -21,10 +26,10 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -32,20 +37,40 @@ ) -class BaichuanLarguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) +class 
BaichuanLanguageModel(LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stream=stream, + ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" def tokens(text: str): @@ -59,10 +84,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -84,20 +109,14 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {"role": "user", "content": message.content} + message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): - # copy from core/model_runtime/model_providers/anthropic/llm/llm.py message = cast(ToolPromptMessage, message) - message_dict = { - "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] - } + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} else: raise ValueError(f"Unknown message type {type(message)}") @@ -105,102 +124,152 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: def validate_credentials(self, model: str, credentials: dict) -> None: # ping - instance = BaichuanModel( - api_key=credentials['api_key'], - secret_key=credentials.get('secret_key', '') - ) + instance = BaichuanModel(api_key=credentials["api_key"]) try: - instance.generate(model=model, stream=False, messages=[ - BaichuanMessage(content='ping', role='user') - ], parameters={ - 'max_tokens': 1, - }, timeout=60) + instance.generate( + model=model, + stream=False, + messages=[{"content": "ping", "role": "user"}], + parameters={ + "max_tokens": 1, + }, + timeout=60, + ) except Exception as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | 
Generator: - if tools is not None and len(tools) > 0: - raise InvokeBadRequestError("Baichuan model doesn't support tools") - - instance = BaichuanModel( - api_key=credentials['api_key'], - secret_key=credentials.get('secret_key', '') - ) - - # convert prompt messages to baichuan messages - messages = [ - BaichuanMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages - ] + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stream: bool = True, + ) -> LLMResult | Generator: + instance = BaichuanModel(api_key=credentials["api_key"]) + messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, - timeout=60) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + tools=tools, + ) if stream: return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) return self._handle_chat_generate_response(model, prompt_messages, credentials, response) - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: BaichuanMessage) -> LLMResult: - # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens']) + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: dict, + ) -> LLMResult: + choices = response.get("choices", []) + assistant_message = AssistantPromptMessage(content="", tool_calls=[]) + if choices and choices[0]["finish_reason"] == "tool_calls": + for choice in choices: + for tool_call in choice["message"]["tool_calls"]: + tool = AssistantPromptMessage.ToolCall( + id=tool_call.get("id", ""), + type=tool_call.get("type", ""), + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool_call.get("function", {}).get("name", ""), + arguments=tool_call.get("function", {}).get("arguments", ""), + ), + ) + assistant_message.tool_calls.append(tool) + else: + for choice in choices: + assistant_message.content += choice["message"]["content"] + assistant_message.role = choice["message"]["role"] + + usage = response.get("usage") + if usage: + # transform usage + prompt_tokens = usage["prompt_tokens"] + completion_tokens = usage["completion_tokens"] + else: + # calculate num tokens + prompt_tokens = self._num_tokens_from_messages(prompt_messages) + completion_tokens = self._num_tokens_from_messages([assistant_message]) + + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=assistant_message, usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[BaichuanMessage, None, None]) -> Generator: - for message in response: - if message.usage: 
- usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens']) + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Iterator, + ) -> Generator: + for line in response: + if not line: + continue + line = line.decode("utf-8") + # remove the first `data: ` prefix + if line.startswith("data:"): + line = line[5:].strip() + try: + data = json.loads(line) + except Exception: + if line.strip() == "[DONE]": + return + # skip lines that are neither valid JSON nor the [DONE] sentinel, + # instead of falling through to an undefined `data` + continue + choices = data.get("choices", []) + + stop_reason = "" + for choice in choices: + if choice.get("finish_reason"): + stop_reason = choice["finish_reason"] + + if len(choice["delta"]["content"]) == 0: + continue yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content=choice["delta"]["content"], tool_calls=[]), + finish_reason=stop_reason, ), ) - else: + + # if there is usage, the response is the last one, yield it and return + if "usage" in data: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=data["usage"]["prompt_tokens"], + completion_tokens=data["usage"]["completion_tokens"], + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content="", tool_calls=[]), + usage=usage, + finish_reason=stop_reason, ), ) @@ -215,21 +284,13 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 5ae90d54b5e421..1ace68d2b99502 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -4,6 +4,7 @@ from requests import post +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( @@ -19,7 +20,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( BadRequestError, - InsufficientAccountBalance, +
InsufficientAccountBalanceError, InternalServerError, InvalidAPIKeyError, InvalidAuthenticationError, @@ -31,11 +32,17 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for BaiChuan text embedding model. """ - api_base: str = 'http://api.baichuan-ai.com/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "http://api.baichuan-ai.com/v1/embeddings" + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -43,30 +50,26 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - api_key = credentials['api_key'] - if model != 'baichuan-text-embedding': - raise ValueError('Invalid model name') + api_key = credentials["api_key"] + if model != "baichuan-text-embedding": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - + raise CredentialsValidateFailedError("api_key is required") + # split into chunks of batch size 16 chunks = [] for i in range(0, len(texts), 16): - chunks.append(texts[i:i + 16]) + chunks.append(texts[i : i + 16]) embeddings = [] token_usage = 0 for chunk in chunks: - # embeding chunk - chunk_embeddings, chunk_usage = self.embedding( - model=model, - api_key=api_key, - texts=chunk, - user=user - ) + # embedding chunk + chunk_embeddings, chunk_usage = self.embedding(model=model, api_key=api_key, texts=chunk, user=user) embeddings.extend(chunk_embeddings) token_usage += chunk_usage @@ -74,17 +77,14 @@ def _invoke(self, model: str, credentials: dict, result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - - def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ - -> tuple[list[list[float]], int]: + + def embedding( + self, model: str, api_key, texts: list[str], user: Optional[str] = None + ) -> tuple[list[list[float]], int]: """ Embed given texts @@ -95,56 +95,47 @@ def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = :return: embeddings result """ url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'Baichuan-Text-Embedding', - 'input': texts - } + data = {"model": "Baichuan-Text-Embedding", "input": texts} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() # try to parse error message - err = resp['error']['code'] - msg = resp['error']['message'] + err = resp["error"]["code"] + msg = resp["error"]["message"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - if err == 'invalid_api_key': + if err == "invalid_api_key": raise InvalidAPIKeyError(msg) - elif err == 'insufficient_quota': - raise 
InsufficientAccountBalance(msg) - elif err == 'invalid_authentication': - raise InvalidAuthenticationError(msg) - elif err and 'rate' in err: + elif err == "insufficient_quota": + raise InsufficientAccountBalanceError(msg) + elif err == "invalid_authentication": + raise InvalidAuthenticationError(msg) + elif err and "rate" in err: raise RateLimitReachedError(msg) - elif err and 'internal' in err: + elif err and "internal" in err: raise InternalServerError(msg) - elif err == 'api_key_empty': + elif err == "api_key_empty": raise InvalidAPIKeyError(msg) else: raise InternalServerError(f"Unknown error: {err} with message: {msg}") - + try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - return [ - data['embedding'] for data in embeddings - ], usage['total_tokens'] - + return [data["embedding"] for data in embeddings], usage["total_tokens"] def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -170,32 +161,24 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, - InsufficientAccountBalance, + InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -207,10 +190,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -221,7 +201,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py index e99bc52ff8258b..1cfc1d199cbf8d 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.py +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class BedrockProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,13 +20,10 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = 
self.get_model_instance(ModelType.LLM) # Use `amazon.titan-text-lite-v1` model by default for validating credentials - model_for_validation = credentials.get('model_for_validation', 'amazon.titan-text-lite-v1') - model_instance.validate_credentials( - model=model_for_validation, - credentials=credentials - ) + model_for_validation = credentials.get("model_for_validation", "amazon.titan-text-lite-v1") + model_instance.validate_credentials(model=model_for_validation, credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml index c540ee23b31672..952f968b9d0f2f 100644 --- a/api/core/model_runtime/model_providers/bedrock/bedrock.yaml +++ b/api/core/model_runtime/model_providers/bedrock/bedrock.yaml @@ -50,34 +50,62 @@ provider_credential_schema: label: en_US: US East (N. Virginia) zh_Hans: 美国东部 (弗吉尼亚北部) + - value: us-east-2 + label: + en_US: US East (Ohio) + zh_Hans: 美国东部 (俄亥俄) - value: us-west-2 label: en_US: US West (Oregon) zh_Hans: 美国西部 (俄勒冈州) + - value: ap-south-1 + label: + en_US: Asia Pacific (Mumbai) + zh_Hans: 亚太地区(孟买) - value: ap-southeast-1 label: en_US: Asia Pacific (Singapore) zh_Hans: 亚太地区 (新加坡) + - value: ap-southeast-2 + label: + en_US: Asia Pacific (Sydney) + zh_Hans: 亚太地区 (悉尼) - value: ap-northeast-1 label: en_US: Asia Pacific (Tokyo) zh_Hans: 亚太地区 (东京) + - value: ap-northeast-2 + label: + en_US: Asia Pacific (Seoul) + zh_Hans: 亚太地区(首尔) + - value: ca-central-1 + label: + en_US: Canada (Central) + zh_Hans: 加拿大(中部) - value: eu-central-1 label: en_US: Europe (Frankfurt) zh_Hans: 欧洲 (法兰克福) + - value: eu-west-1 + label: + en_US: Europe (Ireland) + zh_Hans: 欧洲(爱尔兰) - value: eu-west-2 label: - en_US: Eu west London (London) + en_US: Europe (London) zh_Hans: 欧洲西部 (伦敦) + - value: eu-west-3 + label: + en_US: Europe (Paris) + zh_Hans: 欧洲(巴黎) + - value: sa-east-1 + label: + en_US: South America (São Paulo) + zh_Hans: 南美洲(圣保罗) - value: us-gov-west-1 label: en_US: AWS GovCloud (US-West) zh_Hans: AWS GovCloud (US-West) - - value: ap-southeast-2 - label: - en_US: Asia Pacific (Sydney) - zh_Hans: 亚太地区 (悉尼) - variable: model_for_validation required: false label: diff --git a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml index 86c8061deefac8..47e2b020fd09a3 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/_position.yaml @@ -6,6 +6,8 @@ - anthropic.claude-v2:1 - anthropic.claude-3-sonnet-v1:0 - anthropic.claude-3-haiku-v1:0 +- ai21.jamba-1-5-large-v1:0 +- ai21.jamba-1-5-mini-v1:0 - cohere.command-light-text-v14 - cohere.command-text-v14 - cohere.command-r-plus-v1.0 @@ -15,6 +17,10 @@ - meta.llama3-1-405b-instruct-v1:0 - meta.llama3-8b-instruct-v1:0 - meta.llama3-70b-instruct-v1:0 +- us.meta.llama3-2-1b-instruct-v1:0 +- us.meta.llama3-2-3b-instruct-v1:0 +- us.meta.llama3-2-11b-instruct-v1:0 +- us.meta.llama3-2-90b-instruct-v1:0 - meta.llama2-13b-chat-v1 - meta.llama2-70b-chat-v1 - mistral.mistral-large-2407-v1:0 diff --git a/api/core/model_runtime/model_providers/bedrock/llm/ai21.jamba-1-5-large-v1.0.yaml
b/api/core/model_runtime/model_providers/bedrock/llm/ai21.jamba-1-5-large-v1.0.yaml new file mode 100644 index 00000000000000..276c7312cee008 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/ai21.jamba-1-5-large-v1.0.yaml @@ -0,0 +1,26 @@ +model: ai21.jamba-1-5-large-v1:0 +label: + en_US: Jamba 1.5 Large +model_type: llm +model_properties: + mode: completion + context_size: 256000 +parameter_rules: + - name: temperature + use_template: temperature + default: 1 + min: 0.0 + max: 2.0 + - name: top_p + use_template: top_p + - name: max_gen_len + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 +pricing: + input: '0.002' + output: '0.008' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/ai21.jamba-1-5-mini-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/ai21.jamba-1-5-mini-v1.0.yaml new file mode 100644 index 00000000000000..3461d8ab71329d --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/ai21.jamba-1-5-mini-v1.0.yaml @@ -0,0 +1,26 @@ +model: ai21.jamba-1-5-mini-v1:0 +label: + en_US: Jamba 1.5 Mini +model_type: llm +model_properties: + mode: completion + context_size: 256000 +parameter_rules: + - name: temperature + use_template: temperature + default: 1 + min: 0.0 + max: 2.0 + - name: top_p + use_template: top_p + - name: max_gen_len + use_template: max_tokens + required: true + default: 4096 + min: 1 + max: 4096 +pricing: + input: '0.0002' + output: '0.0004' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml new file mode 100644 index 00000000000000..9d693dcd48a4d6 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml @@ -0,0 +1,60 @@ +model: anthropic.claude-3-5-haiku-20241022-v1:0 +label: + en_US: Claude 3.5 Haiku +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. 
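+      # editorial note, not part of the original patch: per the help text above, tune either + # temperature or top_p but not both, e.g. temperature 0.7 with top_p left at its 0.999 default.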
+ - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.005' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml index 53657c08a9bb36..c2d5eb64715616 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml @@ -52,6 +52,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.00025' output: '0.00125' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml index d083d31e302889..f90fa04266187b 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml @@ -52,6 +52,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.015' output: '0.075' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml index 5302231086e79a..dad0d6b6b6c23c 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml @@ -51,6 +51,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.015' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml index 6995d2bf56c564..962def8011b157 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml @@ -51,6 +51,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
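+  # editorial note, not part of the original patch: the response_format rule added below reuses the + # shared use_template; BedrockLargeLanguageModel._code_block_mode_wrapper() (later in this diff) + # implements it by wrapping the prompt in a fenced code block with matching ``` stop sequences.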
+ - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.015' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v2.yaml new file mode 100644 index 00000000000000..b1e56983751fdb --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v2.yaml @@ -0,0 +1,60 @@ +model: anthropic.claude-3-5-sonnet-20241022-v2:0 +label: + en_US: Claude 3.5 Sonnet V2 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml index 1a3239c85eae4c..70294e4ad3ffde 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.1.yaml @@ -45,6 +45,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format pricing: input: '0.008' output: '0.024' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml index 0343e3bbecf592..0a8ea61b6df0c8 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-v2.yaml @@ -45,6 +45,8 @@ parameter_rules: help: zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format pricing: input: '0.008' output: '0.024' diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml new file mode 100644 index 00000000000000..22bd927d1d3f9b --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-haiku-v1.yaml @@ -0,0 +1,61 @@ +model: eu.anthropic.claude-3-haiku-20240307-v1:0 +label: + en_US: Claude 3 Haiku(EU.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.00025' + output: '0.00125' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml new file mode 100644 index 00000000000000..0dc596432bf162 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.5.yaml @@ -0,0 +1,60 @@ +model: eu.anthropic.claude-3-5-sonnet-20240620-v1:0 +label: + en_US: Claude 3.5 Sonnet(EU.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
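+  # editorial note, not part of the original patch: the eu. prefix in this model id selects the + # EU cross-region inference route; llm.py (later in this diff) recognizes it via the + # "eu.anthropic.claude-3" prefix entry in CONVERSE_API_ENABLED_MODEL_INFO.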
+ - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml new file mode 100644 index 00000000000000..c003fa3908c92c --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v1.yaml @@ -0,0 +1,60 @@ +model: eu.anthropic.claude-3-sonnet-20240229-v1:0 +label: + en_US: Claude 3 Sonnet(EU.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v2.yaml new file mode 100644 index 00000000000000..8d831e6fcb18ba --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/eu.anthropic.claude-3-sonnet-v2.yaml @@ -0,0 +1,60 @@ +model: eu.anthropic.claude-3-5-sonnet-20241022-v2:0 +label: + en_US: Claude 3.5 Sonnet V2(EU.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 335fa493cded9f..ff0403ee474d01 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -19,6 +19,7 @@ ) # local import +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -43,37 +44,89 @@ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel logger = logging.getLogger(__name__) +ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. 
+The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. + + +{{instructions}} + +""" # noqa: E501 -class BedrockLargeLanguageModel(LargeLanguageModel): +class BedrockLargeLanguageModel(LargeLanguageModel): # please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html # TODO There is invoke issue: context limit on Cohere Model, will add them after fixed. - CONVERSE_API_ENABLED_MODEL_INFO=[ - {'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, - {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True}, - {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False} + CONVERSE_API_ENABLED_MODEL_INFO = [ + {"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "us.meta.llama3-2", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "meta.llama", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "mistral.mistral-7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mixtral-8x7b-instruct", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "mistral.mistral-large", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "mistral.mistral-small", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True}, + {"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False}, + {"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False}, ] @staticmethod def _find_model_info(model_id): for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO: - if model_id.startswith(model['prefix']): + if model_id.startswith(model["prefix"]): return model logger.info(f"current model id: {model_id} did not support by Converse API") return None - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, +
credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if model_parameters.get("response_format"): + stop = stop or [] + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + response_format = model_parameters.pop("response_format") + format_prompt = SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) + ) + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = format_prompt + else: + prompt_messages.insert(0, format_prompt) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -87,17 +140,28 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: full response or stream response chunk generator result """ - - model_info= BedrockLargeLanguageModel._find_model_info(model) + + model_info = BedrockLargeLanguageModel._find_model_info(model) if model_info: - model_info['model'] = model + model_info["model"] = model # invoke models via boto3 converse API - return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools) + return self._generate_with_converse( + model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools + ) # invoke other models via boto3 client return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: + def _generate_with_converse( + self, + model_info: dict, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + tools: Optional[list[PromptMessageTool]] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model with converse API @@ -109,35 +173,39 @@ def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_me :param stream: is stream response :return: full response or stream response chunk generator result """ - bedrock_client = boto3.client(service_name='bedrock-runtime', - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - region_name=credentials["aws_region"]) + bedrock_client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=credentials.get("aws_access_key_id"), + 
aws_secret_access_key=credentials.get("aws_secret_access_key"), + region_name=credentials["aws_region"], + ) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) parameters = { - 'modelId': model_info['model'], - 'messages': prompt_message_dicts, - 'inferenceConfig': inference_config, - 'additionalModelRequestFields': additional_model_fields, + "modelId": model_info["model"], + "messages": prompt_message_dicts, + "inferenceConfig": inference_config, + "additionalModelRequestFields": additional_model_fields, } - if model_info['support_system_prompts'] and system and len(system) > 0: - parameters['system'] = system + if model_info["support_system_prompts"] and system and len(system) > 0: + parameters["system"] = system - if model_info['support_tool_use'] and tools: - parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools) + if model_info["support_tool_use"] and tools: + parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools) try: if stream: response = bedrock_client.converse_stream(**parameters) - return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_stream_response( + model_info["model"], credentials, response, prompt_messages + ) else: response = bedrock_client.converse(**parameters) - return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages) + return self._handle_converse_response(model_info["model"], credentials, response, prompt_messages) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: @@ -148,8 +216,10 @@ def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_me except Exception as ex: raise InvokeError(str(ex)) - def _handle_converse_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_converse_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -159,36 +229,30 @@ def _handle_converse_response(self, model: str, credentials: dict, response: dic :param prompt_messages: prompt messages :return: full response chunk generator result """ - response_content = response['output']['message']['content'] + response_content = response["output"]["message"]["content"] # transform assistant message to prompt message - if response['stopReason'] == 'tool_use': + if response["stopReason"] == "tool_use": tool_calls = [] text, tool_use = self._extract_tool_use(response_content) tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=json.dumps(tool_use['input']) - ) + name=tool_use["name"], arguments=json.dumps(tool_use["input"]) + ), ) tool_calls.append(tool_call) - assistant_prompt_message = AssistantPromptMessage( - content=text, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=text, tool_calls=tool_calls) else: 
- assistant_prompt_message = AssistantPromptMessage( - content=response_content[0]['text'] - ) + assistant_prompt_message = AssistantPromptMessage(content=response_content[0]["text"]) # calculate num tokens - if response['usage']: + if response["usage"]: # transform usage - prompt_tokens = response['usage']['inputTokens'] - completion_tokens = response['usage']['outputTokens'] + prompt_tokens = response["usage"]["inputTokens"] + completion_tokens = response["usage"]["outputTokens"] else: # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -205,20 +269,25 @@ def _handle_converse_response(self, model: str, credentials: dict, response: dic ) return result - def _extract_tool_use(self, content:dict)-> tuple[str, dict]: + def _extract_tool_use(self, content: dict) -> tuple[str, dict]: tool_use = {} - text = '' + text = "" for item in content: - if 'toolUse' in item: - tool_use = item['toolUse'] - elif 'text' in item: - text = item['text'] + if "toolUse" in item: + tool_use = item["toolUse"] + elif "text" in item: + text = item["text"] else: raise ValueError(f"Got unknown item: {item}") return text, tool_use - def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_converse_stream_response( + self, + model: str, + credentials: dict, + response: dict, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -230,7 +299,7 @@ def _handle_converse_stream_response(self, model: str, credentials: dict, respon """ try: - full_assistant_content = '' + full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -239,87 +308,85 @@ def _handle_converse_stream_response(self, model: str, credentials: dict, respon tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_use = {} - for chunk in response['stream']: - if 'messageStart' in chunk: + for chunk in response["stream"]: + if "messageStart" in chunk: return_model = model - elif 'messageStop' in chunk: - finish_reason = chunk['messageStop']['stopReason'] - elif 'contentBlockStart' in chunk: - tool = chunk['contentBlockStart']['start']['toolUse'] - tool_use['toolUseId'] = tool['toolUseId'] - tool_use['name'] = tool['name'] - elif 'metadata' in chunk: - input_tokens = chunk['metadata']['usage']['inputTokens'] - output_tokens = chunk['metadata']['usage']['outputTokens'] + elif "messageStop" in chunk: + finish_reason = chunk["messageStop"]["stopReason"] + elif "contentBlockStart" in chunk: + tool = chunk["contentBlockStart"]["start"]["toolUse"] + tool_use["toolUseId"] = tool["toolUseId"] + tool_use["name"] = tool["name"] + elif "metadata" in chunk: + input_tokens = chunk["metadata"]["usage"]["inputTokens"] + output_tokens = chunk["metadata"]["usage"]["outputTokens"] usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens) yield LLMResultChunk( model=return_model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage( - content='', - tool_calls=tool_calls - ), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) - elif 'contentBlockDelta' in chunk: - delta = chunk['contentBlockDelta']['delta'] - if 'text' in delta: - chunk_text = delta['text'] if delta['text'] else '' + elif "contentBlockDelta" in chunk: + delta = chunk["contentBlockDelta"]["delta"] + if "text" in delta: + chunk_text 
= delta["text"] or "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text or "", ) - index = chunk['contentBlockDelta']['contentBlockIndex'] + index = chunk["contentBlockDelta"]["contentBlockIndex"] yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index+1, + index=index + 1, message=assistant_prompt_message, - ) + ), ) - elif 'toolUse' in delta: - if 'input' not in tool_use: - tool_use['input'] = '' - tool_use['input'] += delta['toolUse']['input'] - elif 'contentBlockStop' in chunk: - if 'input' in tool_use: + elif "toolUse" in delta: + if "input" not in tool_use: + tool_use["input"] = "" + tool_use["input"] += delta["toolUse"]["input"] + elif "contentBlockStop" in chunk: + if "input" in tool_use: tool_call = AssistantPromptMessage.ToolCall( - id=tool_use['toolUseId'], - type='function', + id=tool_use["toolUseId"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_use['name'], - arguments=tool_use['input'] - ) + name=tool_use["name"], arguments=tool_use["input"] + ), ) tool_calls.append(tool_call) tool_use = {} except Exception as ex: raise InvokeError(str(ex)) - - def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]: + + def _convert_converse_api_model_parameters( + self, model_parameters: dict, stop: Optional[list[str]] = None + ) -> tuple[dict, dict]: inference_config = {} additional_model_fields = {} - if 'max_tokens' in model_parameters: - inference_config['maxTokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters: + inference_config["maxTokens"] = model_parameters["max_tokens"] + + if "temperature" in model_parameters: + inference_config["temperature"] = model_parameters["temperature"] - if 'temperature' in model_parameters: - inference_config['temperature'] = model_parameters['temperature'] - - if 'top_p' in model_parameters: - inference_config['topP'] = model_parameters['temperature'] + if "top_p" in model_parameters: + inference_config["topP"] = model_parameters["top_p"] if stop: - inference_config['stopSequences'] = stop - - if 'top_k' in model_parameters: - additional_model_fields['top_k'] = model_parameters['top_k'] - + inference_config["stopSequences"] = stop + + if "top_k" in model_parameters: + additional_model_fields["top_k"] = model_parameters["top_k"] + return inference_config, additional_model_fields def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: @@ -331,7 +398,7 @@ def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage] prompt_message_dicts = [] for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content=message.content.strip() + message.content = message.content.strip() system.append({"text": message.content}) else: prompt_message_dicts.append(self._convert_prompt_message_to_dict(message)) @@ -348,15 +415,13 @@ def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] "toolSpec": { "name": tool.name, "description": tool.description, - "inputSchema": { - "json": tool.parameters - } + "inputSchema": {"json": tool.parameters}, } } ) tool_config["tools"] = configs return tool_config - + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ Convert PromptMessage to dict @@ -364,15 +429,13 @@ def 
_convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": "user", "content": [{'text': message.content}]} + message_dict = {"role": "user", "content": [{"text": message.content}]} else: sub_messages = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -381,10 +444,10 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: try: url = message_content.data image_content = requests.get(url).content - if '?' in url: - url = url.split('?')[0] + if "?" in url: + url = url.split("?")[0] mime_type, _ = mimetypes.guess_type(url) - base64_data = base64.b64encode(image_content).decode('utf-8') + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -393,17 +456,14 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: base64_data = data_split[1] image_content = base64.b64decode(base64_data) - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { - "image": { - "format": mime_type.replace('image/', ''), - "source": { - "bytes": image_content - } - } + "image": {"format": mime_type.replace("image/", ""), "source": {"bytes": image_content}} } sub_messages.append(sub_message_dict) @@ -412,36 +472,46 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message = cast(AssistantPromptMessage, message) if message.tool_calls: message_dict = { - "role": "assistant", "content":[{ - "toolUse": { - "toolUseId": message.tool_calls[0].id, - "name": message.tool_calls[0].function.name, - "input": json.loads(message.tool_calls[0].function.arguments) + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": message.tool_calls[0].id, + "name": message.tool_calls[0].function.name, + "input": json.loads(message.tool_calls[0].function.arguments), + } } - }] + ], } else: - message_dict = {"role": "assistant", "content": [{'text': message.content}]} + message_dict = {"role": "assistant", "content": [{"text": message.content}]} elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = [{'text': message.content}] + message_dict = [{"text": message.content}] elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "toolResult": { - "toolUseId": message.tool_call_id, - "content": [{"json": {"text": message.content}}] - } - }] + "content": [ + { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"json": {"text": message.content}}], + } + } + ], } else: 
raise ValueError(f"Got unknown type {message}") return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage] | str, + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -451,15 +521,14 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr :param tools: tools for tool calling :return: number of tokens """ - prefix = model.split('.')[0] - model_name = model.split('.')[1] - + prefix = model.split(".")[0] + model_name = model.split(".")[1] + if isinstance(prompt_messages, str): prompt = prompt_messages else: prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name) - return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -476,30 +545,36 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "max_tokens": 32, } elif "ai21" in model: - # ValidationException: Malformed input request: #/temperature: expected type: Number, found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, please reformat your input and try again. + # ValidationException: Malformed input request: #/temperature: expected type: Number, + # found: Null#/maxTokens: expected type: Integer, found: Null#/topP: expected type: Number, found: Null, + # please reformat your input and try again. required_params = { "temperature": 0.7, "topP": 0.9, "maxTokens": 32, } - + try: ping_message = UserPromptMessage(content="ping") - self._invoke(model=model, - credentials=credentials, - prompt_messages=[ping_message], - model_parameters=required_params, - stream=False) - + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ping_message], + model_parameters=required_params, + stream=False, + ) + except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise CredentialsValidateFailedError(str(self._map_client_to_invoke_error(error_code, full_error_msg))) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_one_message_to_text( + self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Convert a single message to a string. 
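(Editor's aside between hunks: the signature hunk above and the body hunk below both touch the helpers that flatten chat messages into a single text-completion prompt for the non-Converse Bedrock code path. For orientation, here is a minimal, self-contained sketch of that "Human:/Assistant:" prompt layout. `Msg` and `messages_to_prompt` are illustrative stand-ins assumed for this sketch, not the repository's actual `PromptMessage` classes; the tag strings follow the conventional Anthropic completion format.)

```python
from dataclasses import dataclass


@dataclass
class Msg:
    role: str  # "user", "assistant", or "system"
    content: str


def messages_to_prompt(messages: list[Msg]) -> str:
    """Flatten chat messages into one completion prompt string."""
    if not messages:
        return ""
    msgs = list(messages)  # copy so the caller's list is not mutated
    # ensure the prompt ends with an open "Assistant:" turn for the model to complete
    if msgs[-1].role != "assistant":
        msgs.append(Msg("assistant", ""))
    parts = []
    for m in msgs:
        if m.role == "user":
            parts.append(f"\n\nHuman: {m.content}")
        elif m.role == "assistant":
            parts.append(f"\n\nAssistant: {m.content}")
        else:
            parts.append(m.content)  # system text carries no speaker tag
    # trim the trailing space left by the empty "Assistant: " turn
    return "".join(parts).rstrip()


print(repr(messages_to_prompt([Msg("system", "Be terse."), Msg("user", "ping")])))
# 'Be terse.\n\nHuman: ping\n\nAssistant:'
```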
@@ -514,7 +589,7 @@ def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str if isinstance(message, UserPromptMessage): body = content - if (isinstance(content, list)): + if isinstance(content, list): body = "".join([c.data for c in content if c.type == PromptMessageContentType.TEXT]) message_text = f"{human_prompt_prefix} {body} {human_prompt_postfix}" elif isinstance(message, AssistantPromptMessage): @@ -528,7 +603,9 @@ def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None + ) -> str: """ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models @@ -537,27 +614,31 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefi :return: Combined string with necessary human_prompt and ai_prompt tags. """ if not messages: - return '' + return "" messages = messages.copy() # don't mutate the original list if not isinstance(messages[-1], AssistantPromptMessage): messages.append(AssistantPromptMessage(content="")) - text = "".join( - self._convert_one_message_to_text(message, model_prefix, model_name) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message, model_prefix, model_name) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + def _create_payload( + self, + model: str, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} - model_prefix = model.split('.')[0] - model_name = model.split('.')[1] + model_prefix = model.split(".")[0] + model_name = model.split(".")[1] if model_prefix == "ai21": payload["temperature"] = model_parameters.get("temperature") @@ -571,21 +652,27 @@ def _create_payload(self, model: str, prompt_messages: list[PromptMessage], mode payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} if model_parameters.get("countPenalty"): payload["countPenalty"] = {model_parameters.get("countPenalty")} - + elif model_prefix == "cohere": - payload = { **model_parameters } + payload = {**model_parameters} payload["prompt"] = prompt_messages[0].content payload["stream"] = stream - + else: raise ValueError(f"Got unknown model prefix {model_prefix}") - + return payload - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -598,18 +685,16 @@ def _generate(self, model: str, credentials: dict, :param user: unique user id :return: full response or stream response chunk generator result """ - client_config = Config( - 
region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) runtime_client = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream) # need workaround for ai21 models which doesn't support streaming @@ -619,18 +704,13 @@ def _generate(self, model: str, credentials: dict, invoke = runtime_client.invoke_model try: - body_jsonstr=json.dumps(payload) - response = invoke( - modelId=model, - contentType="application/json", - accept= "*/*", - body=body_jsonstr - ) + body_jsonstr = json.dumps(payload) + response = invoke(modelId=model, contentType="application/json", accept="*/*", body=body_jsonstr) except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) @@ -639,15 +719,15 @@ def _generate(self, model: str, credentials: dict, except Exception as ex: raise InvokeError(str(ex)) - if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -657,7 +737,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: dic :param prompt_messages: prompt messages :return: llm response """ - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) finish_reason = response_body.get("error") @@ -665,25 +745,23 @@ def _handle_generate_response(self, model: str, credentials: dict, response: dic raise InvokeError(finish_reason) # get output text and calculate num tokens based on model / provider - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - output = response_body.get('completions')[0].get('data').get('text') + output = response_body.get("completions")[0].get("data").get("text") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) - + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) + elif model_prefix == "cohere": output = response_body.get("generations")[0].get("text") prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, output if output else '') - + completion_tokens = self.get_num_tokens(model, credentials, output or "") + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") # construct assistant message from output - 
assistant_prompt_message = AssistantPromptMessage( - content=output - ) + assistant_prompt_message = AssistantPromptMessage(content=output) # calculate usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) @@ -698,8 +776,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: dic return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: dict, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -709,65 +788,59 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator result """ - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "ai21": - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) - content = response_body.get('completions')[0].get('data').get('text') - finish_reason = response_body.get('completions')[0].get('finish_reason') + content = response_body.get("completions")[0].get("data").get("text") + finish_reason = response_body.get("completions")[0].get("finish_reason") prompt_tokens = len(response_body.get("prompt").get("tokens")) - completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) + completion_tokens = len(response_body.get("completions")[0].get("data").get("tokens")) usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( - model=model, - prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=content), - finish_reason=finish_reason, - usage=usage - ) - ) + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, message=AssistantPromptMessage(content=content), finish_reason=finish_reason, usage=usage + ), + ) return - - stream = response.get('body') + + stream = response.get("body") if not stream: - raise InvokeError('No response body') - + raise InvokeError("No response body") + index = -1 for event in stream: - chunk = event.get('chunk') - + chunk = event.get("chunk") + if not chunk: exception_name = next(iter(event)) full_ex_msg = f"{exception_name}: {event[exception_name]['message']}" raise self._map_client_to_invoke_error(exception_name, full_ex_msg) - payload = json.loads(chunk.get('bytes').decode()) + payload = json.loads(chunk.get("bytes").decode()) - model_prefix = model.split('.')[0] + model_prefix = model.split(".")[0] if model_prefix == "cohere": content_delta = payload.get("text") finish_reason = payload.get("finish_reason") - + else: raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content = content_delta if content_delta else '', + content=content_delta or "", ) index += 1 - + if not finish_reason: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: @@ -777,36 +850,33 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, 
respon # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=finish_reason, usage=usage + ), ) - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ Map model invoke error to unified error - The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller - The value is the md = genai.GenerativeModel(model)error type thrown by the model, + The key is the error type thrown to the caller + The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller. - :return: Invoke emd = genai.GenerativeModel(model)rror mapping + :return: Invoke error mapping """ return { InvokeConnectionError: [], InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - + def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: """ Map client error to invoke error @@ -818,11 +888,16 @@ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[I if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in { + "ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml new file mode 100644 index 00000000000000..9781965555967a --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml @@ -0,0 +1,60 @@ +model: us.anthropic.claude-3-5-haiku-20241022-v1:0 +label: + en_US: Claude 3.5 Haiku(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. 
Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.005' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml new file mode 100644 index 00000000000000..14f2450e0a09c5 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-haiku-v1.yaml @@ -0,0 +1,61 @@ +model: us.anthropic.claude-3-haiku-20240307-v1:0 +label: + en_US: Claude 3 Haiku(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. 
+ - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.00025' + output: '0.00125' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml new file mode 100644 index 00000000000000..eef7b766d9d844 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-opus-v1.yaml @@ -0,0 +1,61 @@ +model: us.anthropic.claude-3-opus-20240229-v1:0 +label: + en_US: Claude 3 Opus(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.015' + output: '0.075' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml new file mode 100644 index 00000000000000..a02fc350d1f107 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.5.yaml @@ -0,0 +1,60 @@ +model: us.anthropic.claude-3-5-sonnet-20240620-v1:0 +label: + en_US: Claude 3.5 Sonnet(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml new file mode 100644 index 00000000000000..7db35277618a35 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v1.yaml @@ -0,0 +1,60 @@ +model: us.anthropic.claude-3-sonnet-20240229-v1:0 +label: + en_US: Claude 3 Sonnet(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml new file mode 100644 index 00000000000000..31a403289b0f86 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml @@ -0,0 +1,60 @@ +model: us.anthropic.claude-3-5-sonnet-20241022-v2:0 +label: + en_US: Claude 3.5 Sonnet V2(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-11b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-11b-instruct-v1.0.yaml new file mode 100644 index 00000000000000..029f428776e0be --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-11b-instruct-v1.0.yaml @@ -0,0 +1,29 @@ +model: us.meta.llama3-2-11b-instruct-v1:0 +label: + en_US: US Meta Llama 3.2 11B Instruct +model_type: llm +features: + - vision + - tool-call +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + min: 0.0 + max: 1 + - name: top_p + use_template: top_p + - name: max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.00035' + output: '0.00035' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-1b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-1b-instruct-v1.0.yaml new file mode 100644 index 00000000000000..51c8474e54846d --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-1b-instruct-v1.0.yaml @@ -0,0 +1,26 @@ +model: us.meta.llama3-2-1b-instruct-v1:0 +label: + en_US: US Meta Llama 3.2 1B Instruct +model_type: llm +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + min: 0.0 + max: 1 + - name: top_p + use_template: top_p + - name: max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.0001' + output: '0.0001' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-3b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-3b-instruct-v1.0.yaml new file mode 100644 index 00000000000000..472cc7403e2d3e --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-3b-instruct-v1.0.yaml @@ -0,0 +1,26 @@ +model: us.meta.llama3-2-3b-instruct-v1:0 +label: + en_US: US Meta Llama 3.2 3B Instruct +model_type: llm +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + min: 0.0 + max: 1 + - name: top_p + use_template: top_p + - name: max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.00015' + output: '0.00015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-90b-instruct-v1.0.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-90b-instruct-v1.0.yaml new file mode 100644 index 00000000000000..cecd0236ca9e31 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.meta.llama3-2-90b-instruct-v1.0.yaml @@ -0,0 +1,31 @@ +model: us.meta.llama3-2-90b-instruct-v1:0 +label: + en_US: US Meta Llama 3.2 90B Instruct +model_type: llm +features: + - tool-call +model_properties: + mode: completion + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + min: 0.0 + max: 1 + - name: top_p + use_template: top_p + default: 0.9 + min: 0 + max: 1 + - name: 
max_gen_len + use_template: max_tokens + required: true + default: 512 + min: 1 + max: 2048 +pricing: + input: '0.002' + output: '0.002' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 993416cdc8ab4f..2f998d8bdaee90 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -13,6 +13,7 @@ UnknownServiceError, ) +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( @@ -27,12 +28,16 @@ logger = logging.getLogger(__name__) -class BedrockTextEmbeddingModel(TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: +class BedrockTextEmbeddingModel(TextEmbeddingModel): + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -40,69 +45,59 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - client_config = Config( - region_name=credentials["aws_region"] - ) + client_config = Config(region_name=credentials["aws_region"]) bedrock_runtime = boto3.client( - service_name='bedrock-runtime', + service_name="bedrock-runtime", config=client_config, aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key") + aws_secret_access_key=credentials.get("aws_secret_access_key"), ) embeddings = [] token_usage = 0 - - model_prefix = model.split('.')[0] - - if model_prefix == "amazon" : + + model_prefix = model.split(".")[0] + + if model_prefix == "amazon": for text in texts: body = { - "inputText": text, + "inputText": text, } response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend([response_body.get('embedding')]) - token_usage += response_body.get('inputTextTokenCount') - logger.warning(f'Total Tokens: {token_usage}') + embeddings.extend([response_body.get("embedding")]) + token_usage += response_body.get("inputTextTokenCount") + logger.warning(f"Total Tokens: {token_usage}") result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - if model_prefix == "cohere" : - input_type = 'search_document' if len(texts) > 1 else 'search_query' + if model_prefix == "cohere": + input_type = "search_document" if len(texts) > 1 else "search_query" for text in texts: body = { - "texts": [text], - "input_type": input_type, + "texts": [text], + "input_type": input_type, } response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend(response_body.get('embeddings')) + embeddings.extend(response_body.get("embeddings")) token_usage += len(text) result = 
TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result - #others + # others raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -125,35 +120,41 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ Map model invoke error to unified error - The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller - The value is the md = genai.GenerativeModel(model)error type thrown by the model, + The key is the error type thrown to the caller + The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller. - :return: Invoke emd = genai.GenerativeModel(model)rror mapping + :return: Invoke error mapping """ return { InvokeConnectionError: [], InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } - - def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): + + def _create_payload( + self, + model_prefix: str, + texts: list[str], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + ): """ Create payload for bedrock api call depending on model provider """ payload = {} if model_prefix == "amazon": - payload['inputText'] = texts + payload["inputText"] = texts - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -165,10 +166,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,7 +177,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -195,35 +193,41 @@ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[I if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: + elif error_code in { + 
"ModelTimeoutException", + "ModelErrorException", + "InternalServerException", + "ModelNotReadyException", + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) return InvokeError(error_msg) - - def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ): - accept = 'application/json' - content_type = 'application/json' + def _invoke_bedrock_embedding( + self, + model: str, + bedrock_runtime, + body: dict, + ): + accept = "application/json" + content_type = "application/json" try: response = bedrock_runtime.invoke_model( - body=json.dumps(body), - modelId=model, - accept=accept, - contentType=content_type + body=json.dumps(body), modelId=model, accept=accept, contentType=content_type ) - response_body = json.loads(response.get('body').read().decode('utf-8')) + response_body = json.loads(response.get("body").read().decode("utf-8")) return response_body except ClientError as ex: - error_code = ex.response['Error']['Code'] + error_code = ex.response["Error"]["Code"] full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" raise self._map_client_to_invoke_error(error_code, full_error_msg) - + except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: raise InvokeConnectionError(str(ex)) diff --git a/api/core/model_runtime/model_providers/chatglm/chatglm.py b/api/core/model_runtime/model_providers/chatglm/chatglm.py index e9dd5794f31ce2..71d9a1532281bd 100644 --- a/api/core/model_runtime/model_providers/chatglm/chatglm.py +++ b/api/core/model_runtime/model_providers/chatglm/chatglm.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `chatglm3-6b` model for validate, - model_instance.validate_credentials( - model='chatglm3-6b', - credentials=credentials - ) + model_instance.validate_credentials(model="chatglm3-6b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index e83d08af714469..b3eeb48e226e18 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -43,12 +43,19 @@ logger = logging.getLogger(__name__) + class ChatGLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -71,11 +78,16 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> 
int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -96,11 +108,16 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content="ping"), - ], model_parameters={ - "max_tokens": 16, - }) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[ + UserPromptMessage(content="ping"), + ], + model_parameters={ + "max_tokens": 16, + }, + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @@ -124,24 +141,24 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError - ], - InvokeRateLimitError: [ - RateLimitError + PermissionDeniedError, ], - InvokeAuthorizationError: [ - AuthenticationError - ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ Invoke large language model @@ -163,35 +180,31 @@ def _generate(self, model: str, credentials: dict, extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) if stream: return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - + def _check_chatglm_parameters(self, model: str, model_parameters: dict, tools: list[PromptMessageTool]) -> None: if model.find("chatglm2") != -1 and tools is not None and len(tools) > 0: raise InvokeBadRequestError("ChatGLM2 does not support function calling") @@ -212,7 +225,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": 
message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -223,12 +236,12 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message_dict = {"role": "function", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -239,19 +252,14 @@ def _extract_response_tool_calls(self, if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls - + def _to_client_kwargs(self, credentials: dict) -> dict: """ Convert invoke kwargs to client kwargs @@ -265,17 +273,20 @@ def _to_client_kwargs(self, credentials: dict) -> dict: client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['api_base']) / 'v1') + "base_url": str(URL(credentials["api_base"]) / "v1"), } return client_kwargs - - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> Generator: - - full_response = '' + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -283,35 +294,37 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue - + # check if there is a tool call in the response function_calls = None if delta.delta.function_call: function_calls = [delta.delta.function_call] - assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else []) + assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or []) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content or "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) 
prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) - + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -320,7 +333,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -335,11 +348,15 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r ) full_response += delta.delta.content - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) \ - -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -359,15 +376,14 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -378,7 +394,7 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response ) return response - + def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -395,17 +411,19 @@ def _num_tokens_from_string(self, text: str, tools: Optional[list[PromptMessageT return num_tokens - def _num_tokens_from_messages(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer. it's too complex to calculate num tokens for chatglm2 and chatglm3 with ChatGLM tokenizer, As a temporary solution we use GPT2 tokenizer instead. 
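For illustration, a counter of the shape used by `_get_num_tokens_by_gpt2` (a sketch only, assuming the Hugging Face `transformers` GPT-2 tokenizer backs the shared helper; GPT-2 BPE is merely a proxy for the ChatGLM vocabulary):

    from transformers import GPT2TokenizerFast

    _gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")

    def _get_num_tokens_by_gpt2(text: str) -> int:
        # Approximate count: number of GPT-2 BPE tokens in the text.
        return len(_gpt2.encode(text))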
""" + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) - + tokens_per_message = 3 tokens_per_name = 1 num_tokens = 0 @@ -414,10 +432,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "function_call": @@ -452,36 +470,37 @@ def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: :param tools: tools for tool calling :return: number of tokens """ + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) diff --git a/api/core/model_runtime/model_providers/cohere/cohere.py b/api/core/model_runtime/model_providers/cohere/cohere.py index cfbcb94d2624f1..8394a45fcf9ca1 100644 --- a/api/core/model_runtime/model_providers/cohere/cohere.py +++ b/api/core/model_runtime/model_providers/cohere/cohere.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.RERANK) # Use `rerank-english-v2.0` model for validate, - model_instance.validate_credentials( - model='rerank-english-v2.0', - credentials=credentials - ) + model_instance.validate_credentials(model="rerank-english-v2.0", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 89b04c0279f760..3863ad33081962 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -55,11 +55,17 @@ class CohereLargeLanguageModel(LargeLanguageModel): Model class for Cohere large language model. 
""" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -85,7 +91,7 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: return self._generate( @@ -95,11 +101,16 @@ def _invoke(self, model: str, credentials: dict, model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -136,30 +147,37 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._chat_generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) else: self._generate( model=model, credentials=credentials, - prompt_messages=[UserPromptMessage(content='ping')], + prompt_messages=[UserPromptMessage(content="ping")], model_parameters={ - 'max_tokens': 20, - 'temperature': 0, + "max_tokens": 20, + "temperature": 0, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm model @@ -173,17 +191,17 @@ def _generate(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['end_sequences'] = stop + model_parameters["end_sequences"] = stop if stream: response = client.generate_stream( prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_generate_stream_response(model, credentials, response, prompt_messages) @@ -192,14 +210,14 @@ def _generate(self, model: str, credentials: dict, prompt=prompt_messages[0].content, model=model, **model_parameters, - request_options=RequestOptions(max_retries=0) + 
request_options=RequestOptions(max_retries=0), ) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Generation, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Generation, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -212,9 +230,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen assistant_text = response.generations[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens prompt_tokens = int(response.meta.billed_units.input_tokens) @@ -225,17 +241,18 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[GenerateStreamedResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[GenerateStreamedResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -245,7 +262,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, :return: llm response chunk generator """ index = 1 - full_assistant_content = '' + full_assistant_content = "" for chunk in response: if isinstance(chunk, GenerateStreamedResponse_TextGeneration): chunk = cast(GenerateStreamedResponse_TextGeneration, chunk) @@ -255,9 +272,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -267,7 +282,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -277,9 +292,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) completion_tokens = self._num_tokens_from_messages( - model, - credentials, - [AssistantPromptMessage(content=full_assistant_content)] + model, credentials, [AssistantPromptMessage(content=full_assistant_content)] ) # transform usage @@ -290,20 +303,27 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content=''), + message=AssistantPromptMessage(content=""), finish_reason=chunk.finish_reason, - usage=usage - ) + usage=usage, + ), ) break elif isinstance(chunk, GenerateStreamedResponse_StreamError): chunk = cast(GenerateStreamedResponse_StreamError, chunk) raise InvokeBadRequestError(chunk.err) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], 
model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -318,27 +338,28 @@ def _chat_generate(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) if stop: - model_parameters['stop_sequences'] = stop + model_parameters["stop_sequences"] = stop if tools: if len(tools) == 1: raise ValueError("Cohere tool call requires at least two tools to be specified.") - model_parameters['tools'] = self._convert_tools(tools) + model_parameters["tools"] = self._convert_tools(tools) - message, chat_histories, tool_results \ - = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) + message, chat_histories, tool_results = self._convert_prompt_messages_to_message_and_chat_histories( + prompt_messages + ) if tool_results: - model_parameters['tool_results'] = tool_results + model_parameters["tool_results"] = tool_results # chat model real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") if stream: response = client.chat_stream( @@ -346,7 +367,7 @@ def _chat_generate(self, model: str, credentials: dict, chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages) @@ -356,14 +377,14 @@ def _chat_generate(self, model: str, credentials: dict, chat_history=chat_histories, model=real_model, **model_parameters, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) return self._handle_chat_generate_response(model, credentials, response, prompt_messages) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse, - prompt_messages: list[PromptMessage]) \ - -> LLMResult: + def _handle_chat_generate_response( + self, model: str, credentials: dict, response: NonStreamedChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -380,19 +401,15 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response for cohere_tool_call in response.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text, - tool_calls=tool_calls - ) + assistant_prompt_message = 
AssistantPromptMessage(content=assistant_text, tool_calls=tool_calls) # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) @@ -403,17 +420,18 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response # transform response response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, - response: Iterator[StreamedChatResponse], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Iterator[StreamedChatResponse], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -423,17 +441,16 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, :return: llm response chunk generator """ - def final_response(full_text: str, - tool_calls: list[AssistantPromptMessage.ToolCall], - index: int, - finish_reason: Optional[str] = None) -> LLMResultChunk: + def final_response( + full_text: str, + tool_calls: list[AssistantPromptMessage.ToolCall], + index: int, + finish_reason: Optional[str] = None, + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages) - full_assistant_prompt_message = AssistantPromptMessage( - content=full_text, - tool_calls=tool_calls - ) + full_assistant_prompt_message = AssistantPromptMessage(content=full_text, tool_calls=tool_calls) completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message]) # transform usage @@ -444,14 +461,14 @@ def final_response(full_text: str, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index, - message=AssistantPromptMessage(content='', tool_calls=tool_calls), + message=AssistantPromptMessage(content="", tool_calls=tool_calls), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) index = 1 - full_assistant_content = '' + full_assistant_content = "" tool_calls = [] for chunk in response: if isinstance(chunk, StreamedChatResponse_TextGeneration): @@ -462,9 +479,7 @@ def final_response(full_text: str, continue # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + assistant_prompt_message = AssistantPromptMessage(content=text) full_assistant_content += text @@ -474,7 +489,7 @@ def final_response(full_text: str, delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) index += 1 @@ -484,11 +499,10 @@ def final_response(full_text: str, for cohere_tool_call in chunk.tool_calls: tool_call = AssistantPromptMessage.ToolCall( id=cohere_tool_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=cohere_tool_call.name, - arguments=json.dumps(cohere_tool_call.parameters) - ) + name=cohere_tool_call.name, arguments=json.dumps(cohere_tool_call.parameters) + ), ) tool_calls.append(tool_call) elif isinstance(chunk, StreamedChatResponse_StreamEnd): @@ -496,8 +510,9 @@ def final_response(full_text: str, yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason) index += 1 - def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: 
list[PromptMessage]) \ - -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: + def _convert_prompt_messages_to_message_and_chat_histories( + self, prompt_messages: list[PromptMessage] + ) -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages @@ -510,13 +525,14 @@ def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages prompt_message = cast(AssistantPromptMessage, prompt_message) if prompt_message.tool_calls: for tool_call in prompt_message.tool_calls: - latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem( - call=ToolCall( - name=tool_call.function.name, - parameters=json.loads(tool_call.function.arguments) - ), - outputs=[] - )) + latest_tool_call_n_outputs.append( + ChatStreamRequestToolResultsItem( + call=ToolCall( + name=tool_call.function.name, parameters=json.loads(tool_call.function.arguments) + ), + outputs=[], + ) + ) else: cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message) if cohere_prompt_message: @@ -529,12 +545,9 @@ def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages if tool_call_n_outputs.call.name == prompt_message.tool_call_id: latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem( call=ToolCall( - name=tool_call_n_outputs.call.name, - parameters=tool_call_n_outputs.call.parameters + name=tool_call_n_outputs.call.name, parameters=tool_call_n_outputs.call.parameters ), - outputs=[{ - "result": prompt_message.content - }] + outputs=[{"result": prompt_message.content}], ) break i += 1 @@ -556,7 +569,7 @@ def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages latest_message = chat_histories.pop() message = latest_message.message else: - raise ValueError('Prompt messages is empty') + raise ValueError("Prompt messages is empty") return message, chat_histories, latest_tool_call_n_outputs @@ -569,7 +582,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[Ch if isinstance(message.content, str): chat_message = ChatMessage(role="USER", message=message.content) else: - sub_message_text = '' + sub_message_text = "" for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) @@ -597,8 +610,8 @@ def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]: """ cohere_tools = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] parameter_definitions = {} for p_key, p_val in properties.items(): @@ -606,21 +619,16 @@ def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]: if p_key in required_properties: required = True - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: [{', '.join(p_val['enum'])}]" parameter_definitions[p_key] = ToolParameterDefinitionsValue( - description=desc, - type=p_val['type'], - required=required + description=desc, type=p_val["type"], required=required ) cohere_tool = Tool( - name=tool.name, - description=tool.description, - 
parameter_definitions=parameter_definitions + name=tool.name, description=tool.description, parameter_definitions=parameter_definitions ) cohere_tools.append(cohere_tool) @@ -637,12 +645,9 @@ def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> i :return: number of tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model - ) + response = client.tokenize(text=text, model=model) return len(response.tokens) @@ -658,30 +663,30 @@ def _num_tokens_from_messages(self, model: str, credentials: dict, messages: lis real_model = model if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL: - real_model = model.removesuffix('-chat') + real_model = model.removesuffix("-chat") return self._num_tokens_from_string(real_model, credentials, message_str) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - Cohere supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + Cohere supports fine-tuning of their models. This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} - mode = credentials.get('mode') + mode = credentials.get("mode") - if mode == 'chat': - base_model_schema = model_map['command-light-chat'] + if mode == "chat": + base_model_schema = model_map["command-light-chat"] else: - base_model_schema = model_map['command-light'] + base_model_schema = model_map["command-light"] base_model_schema = cast(AIModelEntity, base_model_schema) @@ -691,16 +696,13 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) return entity @@ -716,22 +718,16 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, 
cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index d2fdb30c6feec9..aba8fedbc097e5 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -21,10 +21,16 @@ class CohereRerankModel(RerankModel): Model class for Cohere rerank model. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -38,20 +44,17 @@ def _invoke(self, model: str, credentials: dict, :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) response = client.rerank( query=query, documents=docs, model=model, top_n=top_n, return_documents=True, - request_options=RequestOptions(max_retries=0) + request_options=RequestOptions(max_retries=0), ) rerank_documents = [] @@ -70,10 +73,7 @@ def _invoke(self, model: str, credentials: dict, else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -94,7 +94,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. 
Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -110,22 +110,16 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 0540fb740f7e90..5fd4d637be7643 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -5,6 +5,7 @@ import numpy as np from cohere.core import RequestOptions +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( @@ -24,9 +25,14 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): Model class for Cohere text embedding model. 
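Texts longer than the model's context size are tokenized, cut into context-sized windows, embedded window by window, and merged back into one vector per input. A sketch of the merge step in `_invoke` below (`window_embeddings` is a stand-in name and the unweighted average is an assumption; only the L2 normalization appears verbatim in the code):

    import numpy as np

    average = np.average(window_embeddings, axis=0)  # combine the per-window vectors
    embeddings[i] = (average / np.linalg.norm(average)).tolist()  # unit-normalize the result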
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -34,6 +40,7 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ # get model properties @@ -46,14 +53,10 @@ def _invoke(self, model: str, credentials: dict, used_tokens = 0 for i, text in enumerate(texts): - tokenize_response = self._tokenize( - model=model, - credentials=credentials, - text=text - ) + tokenize_response = self._tokenize(model=model, credentials=credentials, text=text) for j in range(0, len(tokenize_response), context_size): - tokens += [tokenize_response[j: j + context_size]] + tokens += [tokenize_response[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -62,9 +65,7 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=["".join(token) for token in tokens[i: i + max_chunks]] + model=model, credentials=credentials, texts=["".join(token) for token in tokens[i : i + max_chunks]] ) used_tokens += embedding_used_tokens @@ -80,9 +81,7 @@ def _invoke(self, model: str, credentials: dict, _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - credentials=credentials, - texts=[" "] + model=model, credentials=credentials, texts=[" "] ) used_tokens += embedding_used_tokens @@ -92,17 +91,9 @@ def _invoke(self, model: str, credentials: dict, embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -116,14 +107,10 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int if len(texts) == 0: return 0 - full_text = ' '.join(texts) + full_text = " ".join(texts) try: - response = self._tokenize( - model=model, - credentials=credentials, - text=full_text - ) + response = self._tokenize(model=model, credentials=credentials, text=full_text) except Exception as e: raise self._transform_invoke_error(e) @@ -141,14 +128,9 @@ def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]: return [] # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) - response = client.tokenize( - text=text, - model=model, - offline=False, - request_options=RequestOptions(max_retries=0) - ) + response = client.tokenize(text=text, model=model, offline=False, request_options=RequestOptions(max_retries=0)) return response.token_strings @@ -162,11 +144,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: 
""" try: # call embedding model - self._embedding_invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -180,14 +158,14 @@ def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> :return: embeddings and used tokens """ # initialize client - client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url')) + client = cohere.Client(credentials.get("api_key"), base_url=credentials.get("base_url")) # call embedding model response = client.embed( texts=texts, model=model, - input_type='search_document' if len(texts) > 1 else 'search_query', - request_options=RequestOptions(max_retries=1) + input_type="search_document" if len(texts) > 1 else "search_query", + request_options=RequestOptions(max_retries=1), ) return response.embeddings, int(response.meta.billed_units.input_tokens) @@ -203,10 +181,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -217,7 +192,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -233,22 +208,16 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - cohere.errors.service_unavailable_error.ServiceUnavailableError - ], - InvokeServerUnavailableError: [ - cohere.errors.internal_server_error.InternalServerError - ], - InvokeRateLimitError: [ - cohere.errors.too_many_requests_error.TooManyRequestsError - ], + InvokeConnectionError: [cohere.errors.service_unavailable_error.ServiceUnavailableError], + InvokeServerUnavailableError: [cohere.errors.internal_server_error.InternalServerError], + InvokeRateLimitError: [cohere.errors.too_many_requests_error.TooManyRequestsError], InvokeAuthorizationError: [ cohere.errors.unauthorized_error.UnauthorizedError, - cohere.errors.forbidden_error.ForbiddenError + cohere.errors.forbidden_error.ForbiddenError, ], InvokeBadRequestError: [ cohere.core.api_error.ApiError, cohere.errors.bad_request_error.BadRequestError, cohere.errors.not_found_error.NotFoundError, - ] + ], } diff --git a/api/core/model_runtime/model_providers/deepseek/deepseek.py b/api/core/model_runtime/model_providers/deepseek/deepseek.py index d61fd4ddc80457..10feef897272db 100644 --- a/api/core/model_runtime/model_providers/deepseek/deepseek.py +++ b/api/core/model_runtime/model_providers/deepseek/deepseek.py @@ -7,9 +7,7 @@ logger = logging.getLogger(__name__) - class DeepSeekProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -22,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `deepseek-chat` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='deepseek-chat', - credentials=credentials - ) + 
model_instance.validate_credentials(model="deepseek-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml index 6588a4b5e01468..4973ac8ad6981c 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml @@ -62,7 +62,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py index bdb3823b60e739..6d0a3ee2628ea2 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -13,12 +13,17 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -27,10 +32,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -48,8 +51,9 @@ def _num_tokens_from_string(self, model: str, text: str, return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. 
Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -69,10 +73,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -103,11 +107,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.deepseek.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.deepseek.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" - + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py b/api/core/model_runtime/model_providers/fireworks/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py rename to api/core/model_runtime/model_providers/fireworks/__init__.py diff --git a/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg new file mode 100644 index 00000000000000..582605cc422cce --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/_assets/icon_l_en.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg new file mode 100644 index 00000000000000..86eeba66f9290a --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/_assets/icon_s_en.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/api/core/model_runtime/model_providers/fireworks/_common.py b/api/core/model_runtime/model_providers/fireworks/_common.py new file mode 100644 index 00000000000000..378ced3a4019ba --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/_common.py @@ -0,0 +1,52 @@ +from collections.abc import Mapping + +import openai + +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class _CommonFireworks: + def _to_credential_kwargs(self, credentials: Mapping) -> dict: + """ + Transform credentials to kwargs for model instance + + :param credentials: + :return: + """ + credentials_kwargs = { + "api_key": credentials["fireworks_api_key"], + "base_url": "https://api.fireworks.ai/inference/v1", + "max_retries": 1, + } + + return credentials_kwargs + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to 
the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], + InvokeServerUnavailableError: [openai.InternalServerError], + InvokeRateLimitError: [openai.RateLimitError], + InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], + InvokeBadRequestError: [ + openai.BadRequestError, + openai.NotFoundError, + openai.UnprocessableEntityError, + openai.APIError, + ], + } diff --git a/api/core/model_runtime/model_providers/fireworks/fireworks.py b/api/core/model_runtime/model_providers/fireworks/fireworks.py new file mode 100644 index 00000000000000..15f25badab994f --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/fireworks.py @@ -0,0 +1,27 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class FireworksProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials( + model="accounts/fireworks/models/llama-v3p1-8b-instruct", credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/fireworks/fireworks.yaml b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml new file mode 100644 index 00000000000000..ddbaa54eb15018 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml @@ -0,0 +1,103 @@ +provider: fireworks +label: + zh_Hans: Fireworks AI + en_US: Fireworks AI +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FCFDFF" +help: + title: + en_US: Get your API Key from Fireworks AI + zh_Hans: 从 Fireworks AI 获取 API Key + url: + en_US: https://fireworks.ai/account/api-keys +supported_model_types: + - llm + - text-embedding +configurate_methods: + - predefined-model + - customizable-model +provider_credential_schema: + credential_form_schemas: + - variable: fireworks_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model URL + zh_Hans: 模型URL + placeholder: + en_US: Enter your Model URL + zh_Hans: 输入模型URL + credential_form_schemas: + - variable: model_label_zh_Hanns + label: + zh_Hans: 模型中文名称 + en_US: The zh_Hans of Model + required: true + type: text-input + placeholder: + zh_Hans: 在此输入您的模型中文名称 + en_US: Enter your zh_Hans of Model + - variable: model_label_en_US + label: + zh_Hans: 模型英文名称 + en_US: The en_US of Model + required: true + type: text-input + placeholder: + zh_Hans: 在此输入您的模型英文名称 + en_US: Enter your en_US of Model + - variable: fireworks_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your 
API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + show_on: + - variable: __model_type + value: llm + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - value: function_call + label: + en_US: Support + zh_Hans: 支持 + show_on: + - variable: __model_type + value: llm diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py b/api/core/model_runtime/model_providers/fireworks/llm/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py rename to api/core/model_runtime/model_providers/fireworks/llm/__init__.py diff --git a/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml b/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml new file mode 100644 index 00000000000000..9f7c1af68cef72 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/_position.yaml @@ -0,0 +1,16 @@ +- llama-v3p1-405b-instruct +- llama-v3p1-70b-instruct +- llama-v3p1-8b-instruct +- llama-v3-70b-instruct +- mixtral-8x22b-instruct +- mixtral-8x7b-instruct +- firefunction-v2 +- firefunction-v1 +- gemma2-9b-it +- llama-v3-70b-instruct-hf +- llama-v3-8b-instruct +- llama-v3-8b-instruct-hf +- mixtral-8x7b-instruct-hf +- mythomax-l2-13b +- phi-3-vision-128k-instruct +- yi-large diff --git a/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml new file mode 100644 index 00000000000000..f6bac12832d646 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v1.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/firefunction-v1 +label: + zh_Hans: Firefunction V1 + en_US: Firefunction V1 +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.5' + output: '0.5' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml new file mode 100644 index 00000000000000..2979cb46d572a3 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/firefunction-v2.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/firefunction-v2 +label: + zh_Hans: Firefunction V2 + en_US: Firefunction V2 +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml b/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml new file mode 100644 index 00000000000000..61ab39482b09b4 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/gemma2-9b-it.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/gemma2-9b-it +label: + zh_Hans: Gemma2 9B Instruct + en_US: Gemma2 9B Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD +deprecated: true diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml new file mode 100644 index 00000000000000..2ae89b88165d12 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct-hf.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3-70b-instruct-hf +label: + zh_Hans: Llama3 70B Instruct(HF version) + en_US: Llama3 70B Instruct(HF version) +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml new file mode 100644 index 00000000000000..7c24b08ca5cca1 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-70b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3-70b-instruct +label: + zh_Hans: Llama3 70B Instruct + en_US: Llama3 70B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml new file mode 100644 index 00000000000000..83507ef3e5276e --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct-hf.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3-8b-instruct-hf +label: + zh_Hans: Llama3 8B Instruct(HF version) + en_US: Llama3 8B Instruct(HF version) +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml new file mode 100644 index 00000000000000..d8ac9537b80e7f --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3-8b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3-8b-instruct +label: + zh_Hans: Llama3 8B Instruct + en_US: Llama3 8B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml new file mode 100644 index 00000000000000..c4ddb3e9246d4a --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-405b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p1-405b-instruct +label: + zh_Hans: Llama3.1 405B Instruct + en_US: Llama3.1 405B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '3' + output: '3' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml new file mode 100644 index 00000000000000..62f84f87fa5609 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-70b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p1-70b-instruct +label: + zh_Hans: Llama3.1 70B Instruct + en_US: Llama3.1 70B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml new file mode 100644 index 00000000000000..9bb99c91b65b0b --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p1-8b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p1-8b-instruct +label: + zh_Hans: Llama3.1 8B Instruct + en_US: Llama3.1 8B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-11b-vision-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-11b-vision-instruct.yaml new file mode 100644 index 00000000000000..31415a24fa8b7e --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-11b-vision-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p2-11b-vision-instruct +label: + zh_Hans: Llama 3.2 11B Vision Instruct + en_US: Llama 3.2 11B Vision Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
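+ # NOTE: despite being a vision variant, `features` above only declares agent-thought and tool-call; `- vision` (which phi-3-vision-128k-instruct declares later in this diff) may also be needed for image input to surface.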
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-1b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-1b-instruct.yaml new file mode 100644 index 00000000000000..c2fd77d2568d29 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-1b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p2-1b-instruct +label: + zh_Hans: Llama 3.2 1B Instruct + en_US: Llama 3.2 1B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.1' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-3b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-3b-instruct.yaml new file mode 100644 index 00000000000000..4b3c459c7bf2fc --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-3b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p2-3b-instruct +label: + zh_Hans: Llama 3.2 3B Instruct + en_US: Llama 3.2 3B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.1' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-90b-vision-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-90b-vision-instruct.yaml new file mode 100644 index 00000000000000..0aece7455d6254 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llama-v3p2-90b-vision-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/llama-v3p2-90b-vision-instruct +label: + zh_Hans: Llama 3.2 90B Vision Instruct + en_US: Llama 3.2 90B Vision Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llm.py b/api/core/model_runtime/model_providers/fireworks/llm/llm.py new file mode 100644 index 00000000000000..ffe1ad5fcbf9ed --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/llm.py @@ -0,0 +1,667 @@ +import logging +from collections.abc import Generator +from typing import Optional, Union, cast + +from openai import OpenAI, Stream +from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall +from openai.types.chat.chat_completion_message import FunctionCall + +from core.model_runtime.callbacks.base_callback import Callback +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.fireworks._common import _CommonFireworks + +logger = logging.getLogger(__name__) + +FIREWORKS_BLOCK_MODE_PROMPT = """You should always 
follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. + + +{{instructions}} + +""" # noqa: E501 + + +class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel): + """ + Model class for Fireworks large language model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + + return self._chat_generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: + stop = stop or [] + self._transform_chat_json_prompts( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + response_format=model_parameters["response_format"], + ) + model_parameters.pop("response_format") + + return self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: + """ + Transform json prompts + """ + if stop is None: + stop = [] + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + prompt_messages[0] = SystemPromptMessage( + content=FIREWORKS_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) + ) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) + else: + prompt_messages.insert( + 0, + SystemPromptMessage( + content=FIREWORKS_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object."
+ ).replace("{{block}}", response_format) + ), + ) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + return self._num_tokens_from_messages(model, prompt_messages, tools) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + credentials_kwargs = self._to_credential_kwargs(credentials) + client = OpenAI(**credentials_kwargs) + + client.chat.completions.create( + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False + ) + except Exception as e: + raise CredentialsValidateFailedError(str(e)) + + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + credentials_kwargs = self._to_credential_kwargs(credentials) + client = OpenAI(**credentials_kwargs) + + extra_model_kwargs = {} + + if tools: + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] + + if stop: + extra_model_kwargs["stop"] = stop + + if user: + extra_model_kwargs["user"] = user + + # chat model + response = client.chat.completions.create( + messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], + model=model, + stream=stream, + **model_parameters, + **extra_model_kwargs, + ) + + if stream: + return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) + return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: + """ + Handle llm chat response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: llm response + """ + assistant_message = response.choices[0].message + # assistant_message_tool_calls = assistant_message.tool_calls + assistant_message_function_call = assistant_message.function_call + + # extract tool calls from response + # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) + function_call = self._extract_response_function_call(assistant_message_function_call) + tool_calls = [function_call] if function_call else [] + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) + + # calculate num tokens + if response.usage: + # transform usage + prompt_tokens = response.usage.prompt_tokens + completion_tokens = response.usage.completion_tokens + else: + # calculate num tokens + prompt_tokens = 
self._num_tokens_from_messages(model, prompt_messages, tools) + completion_tokens = self._num_tokens_from_messages(model, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + response = LLMResult( + model=response.model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + system_fingerprint=response.system_fingerprint, + ) + + return response + + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: + """ + Handle llm chat stream response + + :param model: model name + :param response: response + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: llm response chunk generator + """ + full_assistant_content = "" + delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None + prompt_tokens = 0 + completion_tokens = 0 + final_tool_calls = [] + final_chunk = LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=""), + ), + ) + + for chunk in response: + if len(chunk.choices) == 0: + if chunk.usage: + # calculate num tokens + prompt_tokens = chunk.usage.prompt_tokens + completion_tokens = chunk.usage.completion_tokens + continue + + delta = chunk.choices[0] + has_finish_reason = delta.finish_reason is not None + + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): + continue + + # assistant_message_tool_calls = delta.delta.tool_calls + assistant_message_function_call = delta.delta.function_call + + # extract tool calls from response + if delta_assistant_message_function_call_storage is not None: + # handle process of stream function call + if assistant_message_function_call: + # message has not ended yet + delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments + continue + else: + # message has ended + assistant_message_function_call = delta_assistant_message_function_call_storage + delta_assistant_message_function_call_storage = None + else: + if assistant_message_function_call: + # start of stream function call + delta_assistant_message_function_call_storage = assistant_message_function_call + if delta_assistant_message_function_call_storage.arguments is None: + delta_assistant_message_function_call_storage.arguments = "" + if not has_finish_reason: + continue + + # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) + function_call = self._extract_response_function_call(assistant_message_function_call) + tool_calls = [function_call] if function_call else [] + if tool_calls: + final_tool_calls.extend(tool_calls) + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) + + full_assistant_content += delta.delta.content or "" + + if has_finish_reason: + final_chunk = LLMResultChunk( + model=chunk.model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=delta.index, + message=assistant_prompt_message, + finish_reason=delta.finish_reason, + ), + ) + else: + yield LLMResultChunk(
model=chunk.model, + prompt_messages=prompt_messages, + system_fingerprint=chunk.system_fingerprint, + delta=LLMResultChunkDelta( + index=delta.index, + message=assistant_prompt_message, + ), + ) + + if not prompt_tokens: + prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools) + + if not completion_tokens: + full_assistant_prompt_message = AssistantPromptMessage( + content=full_assistant_content, tool_calls=final_tool_calls + ) + completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + final_chunk.delta.usage = usage + + yield final_chunk + + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: + """ + Extract tool calls from response + + :param response_tool_calls: response tool calls + :return: list of tool calls + """ + tool_calls = [] + if response_tool_calls: + for response_tool_call in response_tool_calls: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.id, type=response_tool_call.type, function=function + ) + tool_calls.append(tool_call) + + return tool_calls + + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: + """ + Extract function call from response + + :param response_function_call: response function call + :return: tool call + """ + tool_call = None + if response_function_call: + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_function_call.name, arguments=response_function_call.arguments + ) + + tool_call = AssistantPromptMessage.ToolCall( + id=response_function_call.name, type="function", function=function + ) + + return tool_call + + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for Fireworks API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + sub_message_dict = {"type": "text", "text": message_content.data} + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "type": "image_url", + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, + } + sub_messages.append(sub_message_dict) + + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + # message_dict["tool_calls"] = [tool_call.dict() for tool_call in + # message.tool_calls] + function_call = message.tool_calls[0] + message_dict["function_call"] = { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + } + elif isinstance(message, 
SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + # message_dict = { + # "role": "tool", + # "content": message.content, + # "tool_call_id": message.tool_call_id + # } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} + else: + raise ValueError(f"Got unknown type {message}") + + if message.name: + message_dict["name"] = message.name + + return message_dict + + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + credentials: Optional[dict] = None, + ) -> int: + """ + Approximate num tokens with GPT2 tokenizer. + """ + + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for key, value in message.items(): + # Cast str(value) in case the message value is not a string + # This occurs with function messages + # TODO: The current token calculation method for the image type is not implemented, + # which needs to download the image and then get the resolution for calculation, + # and will increase the request delay + if isinstance(value, list): + text = "" + for item in value: + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += self._get_num_tokens_by_gpt2(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += self._get_num_tokens_by_gpt2(f_key) + num_tokens += self._get_num_tokens_by_gpt2(f_value) + else: + num_tokens += self._get_num_tokens_by_gpt2(t_key) + num_tokens += self._get_num_tokens_by_gpt2(t_value) + else: + num_tokens += self._get_num_tokens_by_gpt2(str(value)) + + if key == "name": + num_tokens += tokens_per_name + + # every reply is primed with assistant + num_tokens += 3 + + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens + + def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: + """ + Calculate num tokens for tool calling, approximated with the GPT-2 tokenizer.
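+ + The serialized tool schema (name, description, parameters, enum values) is tokenized field by field, so the result only approximates the provider's server-side accounting.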
+ + :param tools: tools for tool calling + :return: number of tokens + """ + num_tokens = 0 + for tool in tools: + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2("function") + num_tokens += self._get_num_tokens_by_gpt2("function") + + # calculate num tokens for function object + num_tokens += self._get_num_tokens_by_gpt2("name") + num_tokens += self._get_num_tokens_by_gpt2(tool.name) + num_tokens += self._get_num_tokens_by_gpt2("description") + num_tokens += self._get_num_tokens_by_gpt2(tool.description) + parameters = tool.parameters + num_tokens += self._get_num_tokens_by_gpt2("parameters") + if "title" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("title") + num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) + if "properties" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("properties") + for key, value in parameters.get("properties").items(): + num_tokens += self._get_num_tokens_by_gpt2(key) + for field_key, field_value in value.items(): + num_tokens += self._get_num_tokens_by_gpt2(field_key) + if field_key == "enum": + for enum_field in field_value: + num_tokens += 3 + num_tokens += self._get_num_tokens_by_gpt2(enum_field) + else: + num_tokens += self._get_num_tokens_by_gpt2(field_key) + num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) + if "required" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("required") + for required_field in parameters["required"]: + num_tokens += 3 + num_tokens += self._get_num_tokens_by_gpt2(required_field) + + return num_tokens + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + return AIModelEntity( + model=model, + label=I18nObject( + en_US=credentials.get("model_label_en_US", model), + zh_Hans=credentials.get("model_label_zh_Hans", model), + ), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "function_call" + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + default=512, + min=1, + max=int(credentials.get("max_tokens", 4096)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), + type=ParameterType.INT, + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="top_k", + use_template="top_k", + label=I18nObject(en_US="Top K", zh_Hans="Top K"), + type=ParameterType.FLOAT, + ), + ], + ) diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml new file mode 100644 index 00000000000000..87d977e26cf1b2 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x22b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/mixtral-8x22b-instruct +label: + zh_Hans: Mixtral MoE 8x22B Instruct + en_US: Mixtral
MoE 8x22B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 65536 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '1.2' + output: '1.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml new file mode 100644 index 00000000000000..e3d5a90858c5ae --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct-hf.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/mixtral-8x7b-instruct-hf +label: + zh_Hans: Mixtral MoE 8x7B Instruct(HF version) + en_US: Mixtral MoE 8x7B Instruct(HF version) +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.5' + output: '0.5' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml new file mode 100644 index 00000000000000..45f632ceff2cfc --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/mixtral-8x7b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/mixtral-8x7b-instruct +label: + zh_Hans: Mixtral MoE 8x7B Instruct + en_US: Mixtral MoE 8x7B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.5' + output: '0.5' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml b/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml new file mode 100644 index 00000000000000..9c3486ba10751b --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/mythomax-l2-13b.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/mythomax-l2-13b +label: + zh_Hans: MythoMax L2 13b + en_US: MythoMax L2 13b +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml new file mode 100644 index 00000000000000..e399f2edb1b1bd --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/phi-3-vision-128k-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/phi-3-vision-128k-instruct +label: + zh_Hans: Phi3.5 Vision Instruct + en_US: Phi3.5 Vision Instruct +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.2' + output: '0.2' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/qwen2p5-72b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/qwen2p5-72b-instruct.yaml new file mode 100644 index 00000000000000..9728364340c518 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/qwen2p5-72b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/qwen2p5-72b-instruct +label: + zh_Hans: Qwen2.5 72B Instruct + en_US: Qwen2.5 72B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml b/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml new file mode 100644 index 00000000000000..bb4b6f994ec12a --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/yi-large.yaml @@ -0,0 +1,45 @@ +model: accounts/yi-01-ai/models/yi-large +label: + zh_Hans: Yi-Large + en_US: Yi-Large +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '3' + output: '3' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/UAE-Large-V1.yaml b/api/core/model_runtime/model_providers/fireworks/text_embedding/UAE-Large-V1.yaml new file mode 100644 index 00000000000000..d7c11691cf9bbc --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/UAE-Large-V1.yaml @@ -0,0 +1,12 @@ +model: WhereIsAI/UAE-Large-V1 +label: + zh_Hans: UAE-Large-V1 + en_US: UAE-Large-V1 +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 1 +pricing: + input: '0.008' + unit: '0.000001' + currency: 'USD' diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/__init__.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/gte-base.yaml b/api/core/model_runtime/model_providers/fireworks/text_embedding/gte-base.yaml new file mode 100644 index 00000000000000..d09bafb4d312f9 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/gte-base.yaml @@ -0,0 +1,12 @@ +model: thenlper/gte-base +label: + zh_Hans: GTE-base + en_US: GTE-base +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 1 +pricing: + input: '0.008' + unit: '0.000001' + currency: 'USD' diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/gte-large.yaml b/api/core/model_runtime/model_providers/fireworks/text_embedding/gte-large.yaml new file mode 100644 index 00000000000000..c41fa2f9d32361 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/gte-large.yaml @@ -0,0 +1,12 @@ +model: thenlper/gte-large +label: + zh_Hans: GTE-large + en_US: GTE-large +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 1 +pricing: + input: '0.008' + unit: '0.000001' + currency: 'USD' diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/nomic-embed-text-v1.5.yaml b/api/core/model_runtime/model_providers/fireworks/text_embedding/nomic-embed-text-v1.5.yaml new file mode 100644 index 00000000000000..c9098503d96529 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/nomic-embed-text-v1.5.yaml @@ -0,0 +1,12 @@ +model: nomic-ai/nomic-embed-text-v1.5 +label: + zh_Hans: nomic-embed-text-v1.5 + en_US: nomic-embed-text-v1.5 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 16 +pricing: + input: '0.008' + unit: '0.000001' + currency: 'USD' diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/nomic-embed-text-v1.yaml b/api/core/model_runtime/model_providers/fireworks/text_embedding/nomic-embed-text-v1.yaml new file mode 100644 index 00000000000000..89078d3ff69f93 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/nomic-embed-text-v1.yaml @@ -0,0 +1,12 @@ +model: nomic-ai/nomic-embed-text-v1 +label: + zh_Hans: nomic-embed-text-v1 + en_US: nomic-embed-text-v1 +model_type: text-embedding +model_properties: 
+ context_size: 8192 + max_chunks: 16 +pricing: + input: '0.008' + unit: '0.000001' + currency: 'USD' diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..c745a7e978f4be --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -0,0 +1,151 @@ +import time +from collections.abc import Mapping +from typing import Optional, Union + +import numpy as np +from openai import OpenAI + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.fireworks._common import _CommonFireworks + + +class FireworksTextEmbeddingModel(_CommonFireworks, TextEmbeddingModel): + """ + Model class for Fireworks text embedding model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + + credentials_kwargs = self._to_credential_kwargs(credentials) + client = OpenAI(**credentials_kwargs) + + extra_model_kwargs = {} + if user: + extra_model_kwargs["user"] = user + + extra_model_kwargs["encoding_format"] = "float" + + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + + inputs = [] + indices = [] + used_tokens = 0 + + for i, text in enumerate(texts): + # Here token count is only an approximation based on the GPT2 tokenizer + # TODO: Optimize for better token estimation and chunking + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + + for i in _iter: + embeddings_batch, embedding_used_tokens = self._embedding_invoke( + model=model, + client=client, + texts=inputs[i : i + max_chunks], + extra_model_kwargs=extra_model_kwargs, + ) + used_tokens += embedding_used_tokens + batched_embeddings += embeddings_batch + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def validate_credentials(self, model: str, credentials: Mapping) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model 
credentials + :return: + """ + try: + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + client = OpenAI(**credentials_kwargs) + + # call embedding model + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: + """ + Invoke embedding model + :param model: model name + :param client: model client + :param texts: texts to embed + :param extra_model_kwargs: extra model kwargs + :return: embeddings and used tokens + """ + response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs) + return [data.embedding for data in response.data], response.usage.total_tokens + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + input_price_info = self.get_price( + model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT + ) + + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage diff --git a/api/core/model_runtime/model_providers/fishaudio/__init__.py b/api/core/model_runtime/model_providers/fishaudio/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg new file mode 100644 index 00000000000000..d6f7723bd5ca4c --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg new file mode 100644 index 00000000000000..d6f7723bd5ca4c --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py new file mode 100644 index 00000000000000..3bc4b533e0815a --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py @@ -0,0 +1,26 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class FishAudioProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + Validation is delegated to the TTS model's own credential check. + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
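+ :raises CredentialsValidateFailedError: if the underlying TTS model rejects the credentials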
+ """ + try: + model_instance = self.get_model_instance(ModelType.TTS) + model_instance.validate_credentials(credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml b/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml new file mode 100644 index 00000000000000..479eb7fb85bd76 --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml @@ -0,0 +1,76 @@ +provider: fishaudio +label: + en_US: Fish Audio +description: + en_US: Models provided by Fish Audio, currently only support TTS. + zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。 +icon_small: + en_US: fishaudio_s_en.svg +icon_large: + en_US: fishaudio_l_en.svg +background: "#E5E7EB" +help: + title: + en_US: Get your API key from Fish Audio + zh_Hans: 从 Fish Audio 获取你的 API Key + url: + en_US: https://fish.audio/go-api/ +supported_model_types: + - tts +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: api_base + label: + en_US: API URL + type: text-input + required: false + default: https://api.fish.audio + placeholder: + en_US: Enter your API URL + zh_Hans: 在此输入您的 API URL + - variable: use_public_models + label: + en_US: Use Public Models + type: select + required: false + default: "false" + placeholder: + en_US: Toggle to use public models + zh_Hans: 切换以使用公共模型 + options: + - value: "true" + label: + en_US: Allow Public Models + zh_Hans: 使用公共模型 + - value: "false" + label: + en_US: Private Models Only + zh_Hans: 仅使用私有模型 + - variable: latency + label: + en_US: Latency + type: select + required: false + default: "normal" + placeholder: + en_US: Toggle to choice latency + zh_Hans: 切换以调整延迟 + options: + - value: "balanced" + label: + en_US: Low (may affect quality) + zh_Hans: 低延迟 (可能降低质量) + - value: "normal" + label: + en_US: Normal + zh_Hans: 标准 diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/__init__.py b/api/core/model_runtime/model_providers/fishaudio/tts/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py new file mode 100644 index 00000000000000..e518d7b95b6e33 --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py @@ -0,0 +1,158 @@ +from typing import Any, Optional + +import httpx + +from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.tts_model import TTSModel + + +class FishAudioText2SpeechModel(TTSModel): + """ + Model class for Fish.audio Text to Speech model. 
+ """ + + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + api_base = credentials.get("api_base", "https://api.fish.audio") + api_key = credentials.get("api_key") + use_public_models = credentials.get("use_public_models", "false") == "true" + + params = { + "self": str(not use_public_models).lower(), + "page_size": "100", + } + + if language is not None: + if "-" in language: + language = language.split("-")[0] + params["language"] = language + + results = httpx.get( + f"{api_base}/model", + headers={"Authorization": f"Bearer {api_key}"}, + params=params, + ) + + results.raise_for_status() + data = results.json() + + return [{"name": i["title"], "value": i["_id"]} for i in data["items"]] + + def _invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + user: Optional[str] = None, + ) -> Any: + """ + Invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param voice: model timbre + :param content_text: text content to be translated + :param user: unique user id + :return: generator yielding audio chunks + """ + + return self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None: + """ + Validate credentials for text2speech model + + :param credentials: model credentials + :param user: unique user id + """ + + try: + self.get_tts_model_voices( + None, + credentials={ + "api_key": credentials["api_key"], + "api_base": credentials["api_base"], + # Disable public models will trigger a 403 error if user is not logged in + "use_public_models": "false", + }, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: + """ + Invoke streaming text2speech model + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: ID of the reference audio (if any) + :return: generator yielding audio chunks + """ + + try: + word_limit = self._get_model_word_limit(model, credentials) + if len(content_text) > word_limit: + sentences = self._split_text_into_sentences(content_text, max_length=word_limit) + else: + sentences = [content_text.strip()] + + for i in range(len(sentences)): + yield from self._tts_invoke_streaming_sentence( + credentials=credentials, content_text=sentences[i], voice=voice + ) + + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + def _tts_invoke_streaming_sentence(self, credentials: dict, content_text: str, voice: Optional[str] = None) -> Any: + """ + Invoke streaming text2speech model + + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: ID of the reference audio (if any) + :return: generator yielding audio chunks + """ + api_key = credentials.get("api_key") + api_url = credentials.get("api_base", "https://api.fish.audio") + latency = credentials.get("latency") + + if not api_key: + raise InvokeBadRequestError("API key is required") + + with httpx.stream( + "POST", + api_url + "/v1/tts", + json={"text": content_text, "reference_id": voice, "latency": latency}, + headers={ + "Authorization": f"Bearer {api_key}", + }, + timeout=None, + ) as response: + if response.status_code != 200: + raise 
InvokeBadRequestError(f"Error: {response.status_code} - {response.text}") + yield from response.iter_bytes() + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeBadRequestError: [ + httpx.HTTPStatusError, + ], + } diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml b/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml new file mode 100644 index 00000000000000..b4a446a95701c1 --- /dev/null +++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml @@ -0,0 +1,5 @@ +model: tts-default +model_type: tts +model_properties: + word_limit: 1000 + audio_type: 'mp3' diff --git a/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg new file mode 100644 index 00000000000000..f9738b585b73cd --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo-full.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg new file mode 100644 index 00000000000000..1f51187f197843 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_assets/Gitee-AI-Logo.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/gitee_ai/_common.py b/api/core/model_runtime/model_providers/gitee_ai/_common.py new file mode 100644 index 00000000000000..0750f3b75d0542 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/_common.py @@ -0,0 +1,47 @@ +from dashscope.common.error import ( + AuthenticationError, + InvalidParameter, + RequestFailure, + ServiceUnavailableError, + UnsupportedHTTPMethod, + UnsupportedModel, +) + +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class _CommonGiteeAI: + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+                RequestFailure,
+            ],
+            InvokeServerUnavailableError: [
+                ServiceUnavailableError,
+            ],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [
+                AuthenticationError,
+            ],
+            InvokeBadRequestError: [
+                InvalidParameter,
+                UnsupportedModel,
+                UnsupportedHTTPMethod,
+            ],
+        }
diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py
new file mode 100644
index 00000000000000..14aa8119052ad0
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py
@@ -0,0 +1,36 @@
+import logging
+
+import requests
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class GiteeAIProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            api_key = credentials.get("api_key")
+            if not api_key:
+                raise CredentialsValidateFailedError("Credentials validation failed: api_key not given")
+
+            # send a get request to validate the credentials
+            headers = {"Authorization": f"Bearer {api_key}"}
+            response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300))
+
+            if response.status_code != 200:
+                raise CredentialsValidateFailedError(
+                    f"Credentials validation failed with status code {response.status_code}"
+                )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
+            raise ex
diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml
new file mode 100644
index 00000000000000..7f7d0f2e538ab6
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.yaml
@@ -0,0 +1,35 @@
+provider: gitee_ai
+label:
+  en_US: Gitee AI
+  zh_Hans: Gitee AI
+description:
+  en_US: Quickly experience large models and take the lead in exploring the open-source AI world.
+  zh_Hans: 快速体验大模型,领先探索 AI 开源世界
+icon_small:
+  en_US: Gitee-AI-Logo.svg
+icon_large:
+  en_US: Gitee-AI-Logo-full.svg
+help:
+  title:
+    en_US: Get your token from Gitee AI
+    zh_Hans: 从 Gitee AI 获取 token
+  url:
+    en_US: https://ai.gitee.com/dashboard/settings/tokens
+supported_model_types:
+  - llm
+  - text-embedding
+  - rerank
+  - speech2text
+  - tts
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml
new file mode 100644
index 00000000000000..0348438a75ed1a
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-72B-Instruct.yaml
@@ -0,0 +1,105 @@
+model: Qwen2-72B-Instruct
+label:
+  zh_Hans: Qwen2-72B-Instruct
+  en_US: Qwen2-72B-Instruct
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 6400
+parameter_rules:
+  - name: stream
+    use_template: boolean
+    label:
+      en_US: "Stream"
+      zh_Hans: "流式"
+    type: boolean
+
default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." 
+ zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml new file mode 100644 index 00000000000000..ba1ad788f50507 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Qwen2-7B-Instruct.yaml @@ -0,0 +1,105 @@ +model: Qwen2-7B-Instruct +label: + zh_Hans: Qwen2-7B-Instruct + en_US: Qwen2-7B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. 
After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml new file mode 100644 index 00000000000000..f7260c987b3c59 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/Yi-1.5-34B-Chat.yaml @@ -0,0 +1,105 @@ +model: Yi-1.5-34B-Chat +label: + zh_Hans: Yi-1.5-34B-Chat + en_US: Yi-1.5-34B-Chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." 
+ zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml new file mode 100644 index 00000000000000..21f6120742b1a1 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/_position.yaml @@ -0,0 +1,7 @@ +- Qwen2-7B-Instruct +- Qwen2-72B-Instruct +- Yi-1.5-34B-Chat +- glm-4-9b-chat +- deepseek-coder-33B-instruct-chat +- deepseek-coder-33B-instruct-completions +- codegeex4-all-9b diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml new file mode 100644 index 00000000000000..8632cd92aba705 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/codegeex4-all-9b.yaml @@ -0,0 +1,105 @@ +model: codegeex4-all-9b +label: + zh_Hans: codegeex4-all-9b + en_US: codegeex4-all-9b +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 40960 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." 
+ zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml new file mode 100644 index 00000000000000..2ac00761d508de --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-chat.yaml @@ -0,0 +1,105 @@ +model: deepseek-coder-33B-instruct-chat +label: + zh_Hans: deepseek-coder-33B-instruct-chat + en_US: deepseek-coder-33B-instruct-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 9000 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." 
+ zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." + zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." 
+ zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml new file mode 100644 index 00000000000000..7c364d89f7c3be --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/deepseek-coder-33B-instruct-completions.yaml @@ -0,0 +1,91 @@ +model: deepseek-coder-33B-instruct-completions +label: + zh_Hans: deepseek-coder-33B-instruct-completions + en_US: deepseek-coder-33B-instruct-completions +model_type: llm +features: + - agent-thought +model_properties: + mode: completion + context_size: 9000 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." 
+ zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml b/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml new file mode 100644 index 00000000000000..2afe1cf959f3fc --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/glm-4-9b-chat.yaml @@ -0,0 +1,105 @@ +model: glm-4-9b-chat +label: + zh_Hans: glm-4-9b-chat + en_US: glm-4-9b-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: stream + use_template: boolean + label: + en_US: "Stream" + zh_Hans: "流式" + type: boolean + default: true + required: true + help: + en_US: "Whether to return the results in batches through streaming. If set to true, the generated text will be pushed to the user in real time during the generation process." + zh_Hans: "是否通过流式分批返回结果。如果设置为 true,生成过程中实时地向用户推送每一部分生成的文本。" + + - name: max_tokens + use_template: max_tokens + label: + en_US: "Max Tokens" + zh_Hans: "最大Token数" + type: int + default: 512 + min: 1 + required: true + help: + en_US: "The maximum number of tokens that can be generated by the model varies depending on the model." + zh_Hans: "模型可生成的最大 token 个数,不同模型上限不同。" + + - name: temperature + use_template: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + use_template: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_k + use_template: top_k + label: + en_US: "Top K" + zh_Hans: "Top K" + type: int + default: 50 + min: 0 + max: 100 + required: true + help: + en_US: "The value range is [0,100], which limits the model to only select from the top k words with the highest probability when choosing the next word at each step. The larger the value, the more diverse text generation will be." 
+ zh_Hans: "取值范围为 [0,100],限制模型在每一步选择下一个词时,只从概率最高的前 k 个词中选取。数值越大,文本生成越多样。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: -1.0 + max: 1.0 + precision: 1 + required: false + help: + en_US: "Used to adjust the frequency of repeated content in automatically generated text. Positive numbers reduce repetition, while negative numbers increase repetition. After setting this parameter, if a word has already appeared in the text, the model will decrease the probability of choosing that word for subsequent generation." + zh_Hans: "用于调整自动生成文本中重复内容的频率。正数减少重复,负数增加重复。设置此参数后,如果一个词在文本中已经出现过,模型在后续生成中选择该词的概率会降低。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py new file mode 100644 index 00000000000000..b65db6f6658221 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py @@ -0,0 +1,47 @@ +from collections.abc import Generator +from typing import Optional, Union + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + MODEL_TO_IDENTITY: dict[str, str] = { + "Yi-1.5-34B-Chat": "Yi-34B-Chat", + "deepseek-coder-33B-instruct-completions": "deepseek-coder-33B-instruct", + "deepseek-coder-33B-instruct-chat": "deepseek-coder-33B-instruct", + } + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials, model, model_parameters) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials, model, None) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict, model: str, model_parameters: dict) -> None: + if model is None: + model = "bge-large-zh-v1.5" + + model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model) + credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/" + if model.endswith("completions"): + credentials["mode"] = LLMMode.COMPLETION.value + else: + credentials["mode"] = LLMMode.CHAT.value diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/__init__.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml new file mode 100644 index 00000000000000..83162fd338d676 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/_position.yaml @@ -0,0 +1 @@ +- 
bge-reranker-v2-m3 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml b/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml new file mode 100644 index 00000000000000..f0681641e1210d --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/bge-reranker-v2-m3.yaml @@ -0,0 +1,4 @@ +model: bge-reranker-v2-m3 +model_type: rerank +model_properties: + context_size: 1024 diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py new file mode 100644 index 00000000000000..231345c2f4e231 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -0,0 +1,128 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class GiteeAIRerankModel(RerankModel): + """ + Model class for rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get("base_url", "https://ai.gitee.com/api/serverless") + base_url = base_url.removesuffix("/") + + try: + body = {"model": model, "query": query, "documents": docs} + if top_n is not None: + body["top_n"] = top_n + response = httpx.post( + f"{base_url}/{model}/rerank", + json=body, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, + ) + + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["results"]: + rerank_document = RerankDocument( + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], + ) + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. 
At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.01,
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        """
+        return {
+            InvokeConnectionError: [httpx.ConnectError],
+            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [httpx.HTTPStatusError],
+            InvokeBadRequestError: [httpx.RequestError],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+        Generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.RERANK,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
+        )
+
+        return entity
diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/__init__.py b/api/core/model_runtime/model_providers/gitee_ai/speech2text/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml
new file mode 100644
index 00000000000000..8e9b47598bd3a7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/_position.yaml
@@ -0,0 +1,2 @@
+- whisper-base
+- whisper-large
diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py
new file mode 100644
index 00000000000000..5597f5b43e57df
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/speech2text.py
@@ -0,0 +1,53 @@
+import os
+from typing import IO, Optional
+
+import requests
+
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
+
+
+class GiteeAISpeech2TextModel(_CommonGiteeAI, Speech2TextModel):
+    """
+    Model class for Gitee AI speech-to-text model.
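+    Transcription is performed against Gitee AI's serverless speech-to-text endpoint.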
+ """ + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + # doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/speech-to-text + + endpoint_url = f"https://ai.gitee.com/api/serverless/{model}/speech-to-text" + files = [("file", file)] + _, file_ext = os.path.splitext(file.name) + headers = {"Content-Type": f"audio/{file_ext}", "Authorization": f"Bearer {credentials.get('api_key')}"} + response = requests.post(endpoint_url, headers=headers, files=files) + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + response_data = response.json() + return response_data["text"] + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + audio_file_path = self._get_demo_file_path() + + with open(audio_file_path, "rb") as audio_file: + self._invoke(model, credentials, audio_file) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml new file mode 100644 index 00000000000000..a50bf5fc2d60c4 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-base.yaml @@ -0,0 +1,5 @@ +model: whisper-base +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml new file mode 100644 index 00000000000000..1be7b1a3919f99 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/speech2text/whisper-large.yaml @@ -0,0 +1,5 @@ +model: whisper-large +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml new file mode 100644 index 00000000000000..e8abe6440de110 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/_position.yaml @@ -0,0 +1,3 @@ +- bge-large-zh-v1.5 +- bge-small-zh-v1.5 +- bge-m3 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml new file mode 100644 index 00000000000000..9e3ca76e8824f7 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-large-zh-v1.5.yaml @@ -0,0 +1,8 @@ +model: bge-large-zh-v1.5 +label: + zh_Hans: bge-large-zh-v1.5 + en_US: bge-large-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml new file mode 100644 index 00000000000000..a7a99a98a3e7bd --- /dev/null +++ 
b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-m3.yaml @@ -0,0 +1,8 @@ +model: bge-m3 +label: + zh_Hans: bge-m3 + en_US: bge-m3 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml new file mode 100644 index 00000000000000..bd760408fa6cc3 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/bge-small-zh-v1.5.yaml @@ -0,0 +1,8 @@ +model: bge-small-zh-v1.5 +label: + zh_Hans: bge-small-zh-v1.5 + en_US: bge-small-zh-v1.5 +model_type: text-embedding +model_properties: + context_size: 200000 + max_chunks: 20 diff --git a/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..b833c5652c650a --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/text_embedding/text_embedding.py @@ -0,0 +1,31 @@ +from typing import Optional + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) + + +class GiteeAIEmbeddingModel(OAICompatEmbeddingModel): + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + self._add_custom_parameters(credentials, model) + return super()._invoke(model, credentials, texts, user, input_type) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials, None) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict, model: str) -> None: + if model is None: + model = "bge-m3" + + credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model}/v1/" diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml new file mode 100644 index 00000000000000..940391dfab0c1c --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/ChatTTS.yaml @@ -0,0 +1,11 @@ +model: ChatTTS +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml new file mode 100644 index 00000000000000..8fc573480158f7 --- /dev/null +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/FunAudioLLM-CosyVoice-300M.yaml @@ -0,0 +1,11 @@ +model: FunAudioLLM-CosyVoice-300M +model_type: tts +model_properties: + default_voice: 'default' + voices: + - mode: 'default' + name: 'Default' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + word_limit: 3500 + audio_type: 'mp3' + max_workers: 5 diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/__init__.py 
b/api/core/model_runtime/model_providers/gitee_ai/tts/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml
new file mode 100644
index 00000000000000..13c6ec84540864
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/tts/_position.yaml
@@ -0,0 +1,4 @@
+- speecht5_tts
+- ChatTTS
+- fish-speech-1.2-sft
+- FunAudioLLM-CosyVoice-300M
diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml
new file mode 100644
index 00000000000000..93cc28bc9dcca4
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/tts/fish-speech-1.2-sft.yaml
@@ -0,0 +1,11 @@
+model: fish-speech-1.2-sft
+model_type: tts
+model_properties:
+  default_voice: 'default'
+  voices:
+    - mode: 'default'
+      name: 'Default'
+      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
+  word_limit: 3500
+  audio_type: 'mp3'
+  max_workers: 5
diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml b/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml
new file mode 100644
index 00000000000000..f9c843bd412573
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/tts/speecht5_tts.yaml
@@ -0,0 +1,11 @@
+model: speecht5_tts
+model_type: tts
+model_properties:
+  default_voice: 'default'
+  voices:
+    - mode: 'default'
+      name: 'Default'
+      language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ]
+  word_limit: 3500
+  audio_type: 'mp3'
+  max_workers: 5
diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
new file mode 100644
index 00000000000000..ed2bd5b13ddce4
--- /dev/null
+++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
@@ -0,0 +1,79 @@
+from typing import Any, Optional
+
+import requests
+
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
+from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI
+
+
+class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
+    """
+    Model class for Gitee AI text-to-speech model.
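+    Audio is generated via Gitee AI's serverless text-to-speech endpoint and yielded in 1024-byte chunks.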
+ """ + + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> any: + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :param user: unique user id + :return: text translated to audio file + """ + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + validate credentials text2speech model + + :param model: model name + :param credentials: model credentials + :return: text translated to audio file + """ + try: + self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text="Hello Dify!", + voice=self._get_model_default_voice(model, credentials), + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: + """ + _tts_invoke_streaming text2speech model + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + try: + # doc: https://ai.gitee.com/docs/openapi/serverless#tag/serverless/POST/{service}/text-to-speech + endpoint_url = "https://ai.gitee.com/api/serverless/" + model + "/text-to-speech" + + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + payload = {"inputs": content_text} + response = requests.post(endpoint_url, headers=headers, json=payload) + + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + + data = response.content + + for i in range(0, len(data), 1024): + yield data[i : i + 1024] + except Exception as ex: + raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/google/google.py b/api/core/model_runtime/model_providers/google/google.py index ba25c74e71e857..70f56a8337b2e6 100644 --- a/api/core/model_runtime/model_providers/google/google.py +++ b/api/core/model_runtime/model_providers/google/google.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-pro` model for validate, - model_instance.validate_credentials( - model='gemini-pro', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-pro", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/google/llm/_position.yaml b/api/core/model_runtime/model_providers/google/llm/_position.yaml new file mode 100644 index 00000000000000..63b9ca3a292e77 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/_position.yaml @@ -0,0 +1,15 @@ +- gemini-1.5-pro +- gemini-1.5-pro-latest +- gemini-1.5-pro-001 +- gemini-1.5-pro-002 +- gemini-1.5-pro-exp-0801 +- gemini-1.5-pro-exp-0827 +- gemini-1.5-flash +- gemini-1.5-flash-latest +- gemini-1.5-flash-001 +- gemini-1.5-flash-002 +- 
gemini-1.5-flash-exp-0827 +- gemini-1.5-flash-8b-exp-0827 +- gemini-1.5-flash-8b-exp-0924 +- gemini-pro +- gemini-pro-vision diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml new file mode 100644 index 00000000000000..8d8cd248474bbe --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-flash-001 +label: + en_US: Gemini 1.5 Flash 001 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml new file mode 100644 index 00000000000000..ae6b85cb23044d --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-flash-002 +label: + en_US: Gemini 1.5 Flash 002 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml new file mode 100644 index 00000000000000..bbc697e934e055 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-flash-8b-exp-0827 +label: + en_US: Gemini 1.5 Flash 8B 0827 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
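Like the Gitee AI TTS _position.yaml earlier, the Google _position.yaml above is a plain list of model names that fixes display order. A hedged sketch of how such a file can drive sorting; this loader is illustrative, not the project's actual implementation:

import yaml

def sort_by_position(model_names: list[str], position_file: str) -> list[str]:
    # The position file is just a YAML list of model names.
    with open(position_file) as f:
        order: list[str] = yaml.safe_load(f)
    rank = {name: i for i, name in enumerate(order)}
    # Models not listed sink to the end while keeping a stable relative order.
    return sorted(model_names, key=lambda n: rank.get(n, len(rank)))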
+ required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml new file mode 100644 index 00000000000000..890faf8c3f497a --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-flash-8b-exp-0924 +label: + en_US: Gemini 1.5 Flash 8B 0924 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml new file mode 100644 index 00000000000000..c5695e5dda8eb0 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-flash-exp-0827 +label: + en_US: Gemini 1.5 Flash 0827 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml index 24b1c5af8a3fd8..d1c264c3a7f662 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml @@ -1,6 +1,6 @@ model: gemini-1.5-flash-latest label: - en_US: Gemini 1.5 Flash + en_US: Gemini 1.5 Flash Latest model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml new file mode 100644 index 00000000000000..6b794e9beeec60 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-flash +label: + en_US: Gemini 1.5 Flash +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml new file mode 100644 index 00000000000000..9ac5e3ad1b95a9 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-pro-001 +label: + en_US: Gemini 1.5 Pro 001 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 2097152 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml new file mode 100644 index 00000000000000..f1d01d0763d710 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-pro-002 +label: + en_US: Gemini 1.5 Pro 002 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 2097152 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml new file mode 100644 index 00000000000000..0a918e0d7b1ac3 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-pro-exp-0801 +label: + en_US: Gemini 1.5 Pro 0801 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 2097152 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml new file mode 100644 index 00000000000000..7452ce46e7dcb6 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-pro-exp-0827 +label: + en_US: Gemini 1.5 Pro 0827 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 2097152 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
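Each of these Gemini YAMLs exposes max_tokens_to_sample through the shared max_tokens template; the llm.py hunk further below renames that key to Gemini's max_output_tokens before calling the SDK. A condensed restatement of that mapping:

def to_generation_kwargs(model_parameters: dict, stop: list[str] | None = None) -> dict:
    # Mirrors _generate() in google/llm/llm.py: the YAML rule name
    # max_tokens_to_sample becomes GenerationConfig's max_output_tokens.
    config = dict(model_parameters)
    config["max_output_tokens"] = config.pop("max_tokens_to_sample", None)
    if stop:
        config["stop_sequences"] = stop
    return config

For example, to_generation_kwargs({"temperature": 0.7, "max_tokens_to_sample": 8192}) returns {"temperature": 0.7, "max_output_tokens": 8192}.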
+ required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml index d65dc026749797..65c2d97e924ef5 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml @@ -1,6 +1,6 @@ model: gemini-1.5-pro-latest label: - en_US: Gemini 1.5 Pro + en_US: Gemini 1.5 Pro Latest model_type: llm features: - agent-thought @@ -9,7 +9,7 @@ features: - stream-tool-call model_properties: mode: chat - context_size: 1048576 + context_size: 2097152 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml new file mode 100644 index 00000000000000..12620b57b61ee3 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml @@ -0,0 +1,39 @@ +model: gemini-1.5-pro +label: + en_US: Gemini 1.5 Pro +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 2097152 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_tokens_to_sample + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index ebcd0af35b2138..b1b07a611bfb0f 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,17 +1,18 @@ import base64 +import io import json import logging -import mimetypes from collections.abc import Generator from typing import Optional, Union, cast import google.ai.generativelanguage as glm -import google.api_core.exceptions as exceptions import google.generativeai as genai -import google.generativeai.client as client import requests -from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory +from google.api_core import exceptions +from google.generativeai.client import _ClientManager +from google.generativeai.types import ContentType, GenerateContentResponse from google.generativeai.types.content_types import to_part +from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -44,16 +45,21 @@ {{instructions}} -""" +""" # noqa: E501 class GoogleLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -69,9 +75,14 @@ def _invoke(self, model: str, credentials: dict, """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -84,7 +95,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -94,13 +105,10 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ 
Convert tool messages to glm tools @@ -108,24 +116,33 @@ def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool :param tools: tool messages :return: glm tools """ - return glm.Tool( - function_declarations=[ - glm.FunctionDeclaration( - name=tool.name, - parameters=glm.Schema( - type=glm.Type.OBJECT, - properties={ - key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() - }, - required=tool.parameters.get('required', []) - ), - ) for tool in tools - ] - ) + function_declarations = [] + for tool in tools: + properties = {} + for key, value in tool.parameters.get("properties", {}).items(): + properties[key] = { + "type_": glm.Type.STRING, + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + + if properties: + parameters = glm.Schema( + type=glm.Type.OBJECT, + properties=properties, + required=tool.parameters.get("required", []), + ) + else: + parameters = None + + function_declaration = glm.FunctionDeclaration( + name=tool.name, + parameters=parameters, + description=tool.description, + ) + function_declarations.append(function_declaration) + + return glm.Tool(function_declarations=function_declarations) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -135,20 +152,25 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -162,14 +184,12 @@ def _generate(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop - google_model = genai.GenerativeModel( - model_name=model - ) + google_model = genai.GenerativeModel(model_name=model) history = [] @@ -179,7 +199,7 @@ def _generate(self, model: str, credentials: dict, content = self._format_message_to_glm_content(last_msg) history.append(content) else: - for msg in prompt_messages: # makes message roles strictly alternating + for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) @@ -187,28 +207,18 @@ def _generate(self, model: str, credentials: dict, history.append(content) # Create a new ClientManager with tenant's API key - new_client_manager = 
client._ClientManager() + new_client_manager = _ClientManager() new_client_manager.configure(api_key=credentials["google_api_key"]) new_custom_client = new_client_manager.make_client("generative") google_model._client = new_custom_client - safety_settings={ - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - } - response = google_model.generate_content( contents=history, - generation_config=genai.types.GenerationConfig( - **config_kwargs - ), + generation_config=genai.types.GenerationConfig(**config_kwargs), stream=stream, - safety_settings=safety_settings, tools=self._convert_tools_to_glm_tool(tools) if tools else None, - request_options={"timeout": 600} + request_options={"timeout": 600}, ) if stream: @@ -216,8 +226,9 @@ def _generate(self, model: str, credentials: dict, return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -228,9 +239,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -249,8 +258,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: GenerateContentResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -263,9 +273,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index = -1 for chunk in response: for part in chunk.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if part.text: assistant_prompt_message.content += part.text @@ -274,36 +282,31 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - + if not response._done: - # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens 
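One note on the _convert_tools_to_glm_tool rewrite above: every JSON-schema property is now declared as glm.Type.STRING, and a tool without properties produces parameters=None rather than an empty object schema. A usage sketch with an invented sample tool showing the shape the converter emits:

import google.ai.generativelanguage as glm

# Hypothetical tool, shaped like the converter's output for one PromptMessageTool
declaration = glm.FunctionDeclaration(
    name="get_weather",
    description="Look up the current weather for a city",
    parameters=glm.Schema(
        type=glm.Type.OBJECT,
        properties={"city": {"type_": glm.Type.STRING, "description": "City name", "enum": []}},
        required=["city"],
    ),
)
tool = glm.Tool(function_declarations=[declaration])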
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -311,8 +314,8 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index=index, message=assistant_prompt_message, finish_reason=str(chunk.candidates[0].finish_reason), - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -327,17 +330,13 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") @@ -352,94 +351,86 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType: :return: glm Content representation of message """ if isinstance(message, UserPromptMessage): - glm_content = { - "role": "user", - "parts": [] - } - if (isinstance(message.content, str)): - glm_content['parts'].append(to_part(message.content)) + glm_content = {"role": "user", "parts": []} + if isinstance(message.content, str): + glm_content["parts"].append(to_part(message.content)) else: for c in message.content: if c.type == PromptMessageContentType.TEXT: - glm_content['parts'].append(to_part(c.data)) + glm_content["parts"].append(to_part(c.data)) elif c.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, c) if message_content.data.startswith("data:"): - metadata, base64_data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] + metadata, base64_data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] else: # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) - base64_data = base64.b64encode(image_content).decode('utf-8') + with Image.open(io.BytesIO(image_content)) as img: + mime_type = f"image/{img.format.lower()}" + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") - blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}} - glm_content['parts'].append(blob) + blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} + glm_content["parts"].append(blob) return glm_content elif isinstance(message, AssistantPromptMessage): - glm_content = { - "role": "model", - "parts": [] - } + glm_content = {"role": "model", "parts": []} if message.content: - glm_content['parts'].append(to_part(message.content)) + glm_content["parts"].append(to_part(message.content)) if message.tool_calls: - 
glm_content["parts"].append(to_part(glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ))) + glm_content["parts"].append( + to_part( + glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ) + ) + ) return glm_content elif isinstance(message, SystemPromptMessage): - return { - "role": "user", - "parts": [to_part(message.content)] - } + return {"role": "user", "parts": [to_part(message.content)]} elif isinstance(message, ToolPromptMessage): return { "role": "function", - "parts": [glm.Part(function_response=glm.FunctionResponse( - name=message.name, - response={ - "response": message.content - } - ))] + "parts": [ + glm.Part( + function_response=glm.FunctionResponse( + name=message.name, response={"response": message.content} + ) + ) + ], } else: raise ValueError(f"Got unknown type {message}") - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ Map model invoke error to unified error - The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller - The value is the md = genai.GenerativeModel(model)error type thrown by the model, + The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller + The value is the md = genai.GenerativeModel(model) error type thrown by the model, which needs to be converted into a unified error type for the caller. - :return: Invoke emd = genai.GenerativeModel(model)rror mapping + :return: Invoke emd = genai.GenerativeModel(model) error mapping """ return { - InvokeConnectionError: [ - exceptions.RetryError - ], + InvokeConnectionError: [exceptions.RetryError], InvokeServerUnavailableError: [ exceptions.ServiceUnavailable, exceptions.InternalServerError, exceptions.BadGateway, exceptions.GatewayTimeout, - exceptions.DeadlineExceeded - ], - InvokeRateLimitError: [ - exceptions.ResourceExhausted, - exceptions.TooManyRequests + exceptions.DeadlineExceeded, ], + InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], InvokeAuthorizationError: [ exceptions.Unauthenticated, exceptions.PermissionDenied, exceptions.Unauthenticated, - exceptions.Forbidden + exceptions.Forbidden, ], InvokeBadRequestError: [ exceptions.BadRequest, @@ -455,5 +446,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] exceptions.PreconditionFailed, exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, - ] + ], } diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png new file mode 100644 index 00000000000000..dfe8e78049c4de Binary files /dev/null and b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg new file mode 100644 index 00000000000000..bb23bffcf1c039 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png new file mode 100644 index 00000000000000..b154821db91ab0 Binary files /dev/null and b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png differ diff --git 
a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg new file mode 100644 index 00000000000000..c5c608cd7c603d --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.py b/api/core/model_runtime/model_providers/gpustack/gpustack.py new file mode 100644 index 00000000000000..321100167ee02e --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/gpustack.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class GPUStackProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.yaml b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml new file mode 100644 index 00000000000000..ee4a3c159a0b25 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml @@ -0,0 +1,120 @@ +provider: gpustack +label: + en_US: GPUStack +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: endpoint_url + label: + zh_Hans: 服务器地址 + en_US: Server URL + type: text-input + required: true + placeholder: + zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100 + en_US: Enter the GPUStack server URL, e.g. 
http://192.168.1.100 + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 输入您的 API Key + en_US: Enter your API Key + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择补全类型 + en_US: Select completion type + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: "8192" + placeholder: + zh_Hans: 输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens_to_sample + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: "8192" + type: text-input + - variable: function_calling_type + show_on: + - variable: __model_type + value: llm + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: function_call + label: + en_US: Function Call + zh_Hans: Function Call + - value: tool_call + label: + en_US: Tool Call + zh_Hans: Tool Call + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - variable: vision_support + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: Vision 支持 + en_US: Vision Support + type: select + required: false + default: no_support + options: + - value: support + label: + en_US: Support + zh_Hans: 支持 + - value: no_support + label: + en_US: Not Support + zh_Hans: 不支持 diff --git a/api/core/model_runtime/model_providers/gpustack/llm/__init__.py b/api/core/model_runtime/model_providers/gpustack/llm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/gpustack/llm/llm.py b/api/core/model_runtime/model_providers/gpustack/llm/llm.py new file mode 100644 index 00000000000000..ce6780b6a7c83b --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/llm/llm.py @@ -0,0 +1,45 @@ +from collections.abc import Generator + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import ( + OAIAPICompatLargeLanguageModel, +) + + +class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return super()._invoke( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") + credentials["mode"] = "chat" diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py b/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git 
a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py new file mode 100644 index 00000000000000..5ea7532564098d --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py @@ -0,0 +1,146 @@ +from json import dumps +from typing import Optional + +import httpx +from requests import HTTPError, post +from yarl import URL + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, +) +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class GPUStackRerankModel(RerankModel): + """ + Model class for GPUStack rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + endpoint_url = credentials["endpoint_url"] + headers = { + "Authorization": f"Bearer {credentials.get('api_key')}", + "Content-Type": "application/json", + } + + data = {"model": model, "query": query, "documents": docs, "top_n": top_n} + + try: + response = post( + str(URL(endpoint_url) / "v1" / "rerank"), + headers=headers, + data=dumps(data), + timeout=10, + ) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["results"]: + index = result["index"] + if "document" in result: + text = result["document"]["text"] + else: + text = docs[index] + + rerank_document = RerankDocument( + index=index, + text=text, + score=result["relevance_score"], + ) + + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except HTTPError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..eb324491a2dace --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py @@ -0,0 +1,35 @@ +from typing import Optional + +from yarl import URL + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.text_embedding_entities import ( + TextEmbeddingResult, +) +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) + + +class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel): + """ + Model class for GPUStack text embedding model. 
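For context on the rerank flow added in rerank.py above, the request and score filtering reduce to a few lines. A minimal sketch against the same /v1/rerank route; endpoint, key, and model name are placeholders:

import requests

def rerank_scores(endpoint_url: str, api_key: str, model: str, query: str, docs: list[str], threshold: float | None = None) -> list[tuple[int, float]]:
    response = requests.post(
        f"{endpoint_url.rstrip('/')}/v1/rerank",
        headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
        json={"model": model, "query": query, "documents": docs, "top_n": len(docs)},
        timeout=10,
    )
    response.raise_for_status()
    # Keep (index, relevance_score) pairs that clear the threshold, as the model class does.
    return [
        (r["index"], r["relevance_score"])
        for r in response.json()["results"]
        if threshold is None or r["relevance_score"] >= threshold
    ]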
+ """ + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + return super()._invoke(model, credentials, texts, user, input_type) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") diff --git a/api/core/model_runtime/model_providers/groq/groq.py b/api/core/model_runtime/model_providers/groq/groq.py index b3f37b39678388..d0d5ff68f8090e 100644 --- a/api/core/model_runtime/model_providers/groq/groq.py +++ b/api/core/model_runtime/model_providers/groq/groq.py @@ -6,8 +6,8 @@ logger = logging.getLogger(__name__) -class GroqProvider(ModelProvider): +class GroqProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -18,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama3-8b-8192', - credentials=credentials - ) + model_instance.validate_credentials(model="llama3-8b-8192", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/groq/groq.yaml b/api/core/model_runtime/model_providers/groq/groq.yaml index db17cc8bdd086a..d6534e1bf11c2b 100644 --- a/api/core/model_runtime/model_providers/groq/groq.yaml +++ b/api/core/model_runtime/model_providers/groq/groq.yaml @@ -18,6 +18,7 @@ help: en_US: https://console.groq.com/ supported_model_types: - llm + - speech2text configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/groq/llm/_position.yaml b/api/core/model_runtime/model_providers/groq/llm/_position.yaml index be115ca920df08..0613b19f87ee5e 100644 --- a/api/core/model_runtime/model_providers/groq/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/groq/llm/_position.yaml @@ -5,3 +5,4 @@ - llama3-8b-8192 - mixtral-8x7b-32768 - llama2-70b-4096 +- llama-guard-3-8b diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml new file mode 100644 index 00000000000000..019d45372361d3 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-text-preview.yaml @@ -0,0 +1,25 @@ +model: llama-3.2-11b-text-preview +label: + zh_Hans: Llama 3.2 11B Text (Preview) + en_US: Llama 3.2 11B Text (Preview) +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml 
b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml new file mode 100644 index 00000000000000..56322187973a0a --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml @@ -0,0 +1,26 @@ +model: llama-3.2-11b-vision-preview +label: + zh_Hans: Llama 3.2 11B Vision (Preview) + en_US: Llama 3.2 11B Vision (Preview) +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-1b-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-1b-preview.yaml new file mode 100644 index 00000000000000..a44e4ff508eb82 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-1b-preview.yaml @@ -0,0 +1,25 @@ +model: llama-3.2-1b-preview +label: + zh_Hans: Llama 3.2 1B Text (Preview) + en_US: Llama 3.2 1B Text (Preview) +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-3b-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-3b-preview.yaml new file mode 100644 index 00000000000000..f2fdd0a05e027a --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-3b-preview.yaml @@ -0,0 +1,25 @@ +model: llama-3.2-3b-preview +label: + zh_Hans: Llama 3.2 3B Text (Preview) + en_US: Llama 3.2 3B Text (Preview) +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml new file mode 100644 index 00000000000000..3b34e7c07996bd --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-text-preview.yaml @@ -0,0 +1,25 @@ +model: llama-3.2-90b-text-preview +label: + zh_Hans: Llama 3.2 90B Text (Preview) + en_US: Llama 3.2 90B Text (Preview) +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml new file mode 100644 index 00000000000000..e7b93101e868f5 --- /dev/null +++ 
b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml @@ -0,0 +1,26 @@ +model: llama-3.2-90b-vision-preview +label: + zh_Hans: Llama 3.2 90B Vision (Preview) + en_US: Llama 3.2 90B Vision (Preview) +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-guard-3-8b.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-guard-3-8b.yaml new file mode 100644 index 00000000000000..03779ccc66f63a --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-guard-3-8b.yaml @@ -0,0 +1,25 @@ +model: llama-guard-3-8b +label: + zh_Hans: Llama-Guard-3-8B + en_US: Llama-Guard-3-8B +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.20' + output: '0.20' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llm.py b/api/core/model_runtime/model_providers/groq/llm/llm.py index 915f7a4e1a7e0d..352a7b519ee168 100644 --- a/api/core/model_runtime/model_providers/groq/llm/llm.py +++ b/api/core/model_runtime/model_providers/groq/llm/llm.py @@ -7,11 +7,17 @@ class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,6 +27,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.groq.com/openai/v1' - + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.groq.com/openai/v1" diff --git a/api/core/model_runtime/model_providers/groq/speech2text/__init__.py b/api/core/model_runtime/model_providers/groq/speech2text/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/groq/speech2text/distil-whisper-large-v3-en.yaml b/api/core/model_runtime/model_providers/groq/speech2text/distil-whisper-large-v3-en.yaml new file mode 100644 index 00000000000000..202d006a66c94f --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/speech2text/distil-whisper-large-v3-en.yaml @@ -0,0 +1,5 @@ +model: distil-whisper-large-v3-en +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: 
flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/groq/speech2text/speech2text.py b/api/core/model_runtime/model_providers/groq/speech2text/speech2text.py new file mode 100644 index 00000000000000..75feeb9cb99cc7 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/speech2text/speech2text.py @@ -0,0 +1,30 @@ +from typing import IO, Optional + +from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel + + +class GroqSpeech2TextModel(OAICompatSpeech2TextModel): + """ + Model class for Groq Speech to text model. + """ + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, file) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + return super().validate_credentials(model, credentials) + + @classmethod + def _add_custom_parameters(cls, credentials: dict) -> None: + credentials["endpoint_url"] = "https://api.groq.com/openai/v1" diff --git a/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3-turbo.yaml b/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3-turbo.yaml new file mode 100644 index 00000000000000..3882a3f4f24653 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3-turbo.yaml @@ -0,0 +1,5 @@ +model: whisper-large-v3-turbo +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3.yaml b/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3.yaml new file mode 100644 index 00000000000000..ed02477d709fdf --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3.yaml @@ -0,0 +1,5 @@ +model: whisper-large-v3 +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index dd8ae526e6a759..3c4020b6eedf24 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -4,12 +4,6 @@ class _CommonHuggingfaceHub: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - HfHubHTTPError, - BadRequestError - ] - } + return {InvokeBadRequestError: [HfHubHTTPError, BadRequestError]} diff --git a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py index 15e2a4fed41be7..54d2a2bf399623 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/huggingface_hub.py @@ -6,6 +6,5 @@ class HuggingfaceHubProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git 
a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index f43a8aedaf2c69..9d29237fdde573 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -29,16 +29,23 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: - - client = InferenceClient(token=credentials['huggingfacehub_api_token']) - - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] - - if 'baichuan' in model.lower(): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) + + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] + + if "baichuan" in model.lower(): stream = False response = client.text_generation( @@ -47,98 +54,100 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes stream=stream, model=model, stop_sequences=stop, - **model_parameters) + **model_parameters, + ) if stream: return self._handle_generate_stream_response(model, credentials, prompt_messages, response) return self._handle_generate_response(model, credentials, prompt_messages, response) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if credentials['huggingfacehub_api_type'] not in ('inference_endpoints', 'hosted_inference_api'): - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Access Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Access Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be 
provided.') + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - credentials['task_type'] = self._get_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + credentials["task_type"] = self._get_hosted_model_task_type( + credentials["huggingfacehub_api_token"], model + ) - if credentials['task_type'] not in ("text2text-generation", "text-generation"): - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be one of text2text-generation, ' - 'text-generation.') + if credentials["task_type"] not in {"text2text-generation", "text-generation"}: + raise CredentialsValidateFailedError( + "Huggingface Hub Task Type must be one of text2text-generation, text-generation." + ) - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + model = credentials["huggingfacehub_endpoint_url"] try: - client.text_generation( - prompt='Who are you?', - stream=True, - model=model) + client.text_generation(prompt="Who are you?", stream=True, model=model) except BadRequestError as e: - raise CredentialsValidateFailedError('Only available for models running on with the `text-generation-inference`. ' - 'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.') + raise CredentialsValidateFailedError( + "Only available for models running with `text-generation-inference`. " + "To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference."
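# ---- editor's note: illustrative sketch, not part of the patch ----
# The probe above works as a compatibility check because stream=True is only
# honored by endpoints backed by text-generation-inference (TGI). A minimal
# standalone version of the same check (the helper name is hypothetical):
from huggingface_hub import InferenceClient

def is_tgi_endpoint(endpoint_url: str, token: str) -> bool:
    client = InferenceClient(token=token)
    try:
        # consume a single chunk; non-TGI backends raise before yielding
        for _ in client.text_generation(prompt="Who are you?", stream=True, model=endpoint_url):
            break
        return True
    except Exception:
        return False
# --------------------------------------------------------------------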
+ ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: LLMMode.COMPLETION.value - }, - parameter_rules=self._get_customizable_model_parameter_rules() + model_properties={ModelPropertyKey.MODE: LLMMode.COMPLETION.value}, + parameter_rules=self._get_customizable_model_parameter_rules(), ) return entity @staticmethod def _get_customizable_model_parameter_rules() -> list[ParameterRule]: - temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get( - DefaultParameterName.TEMPERATURE).copy() - temperature_rule_dict['name'] = 'temperature' + temperature_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TEMPERATURE).copy() + temperature_rule_dict["name"] = "temperature" temperature_rule = ParameterRule(**temperature_rule_dict) temperature_rule.default = 0.5 top_p_rule_dict = PARAMETER_RULE_TEMPLATE.get(DefaultParameterName.TOP_P).copy() - top_p_rule_dict['name'] = 'top_p' + top_p_rule_dict["name"] = "top_p" top_p_rule = ParameterRule(**top_p_rule_dict) top_p_rule.default = 0.5 top_k_rule = ParameterRule( - name='top_k', + name="top_k", label={ - 'en_US': 'Top K', - 'zh_Hans': 'Top K', + "en_US": "Top K", + "zh_Hans": "Top K", }, - type='int', + type="int", help={ - 'en_US': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.', - 'zh_Hans': '保留的最高概率词汇标记的数量。', + "en_US": "The number of highest probability vocabulary tokens to keep for top-k-filtering.", + "zh_Hans": "保留的最高概率词汇标记的数量。", }, required=False, default=2, @@ -148,15 +157,15 @@ def _get_customizable_model_parameter_rules() -> list[ParameterRule]: ) max_new_tokens = ParameterRule( - name='max_new_tokens', + name="max_new_tokens", label={ - 'en_US': 'Max New Tokens', - 'zh_Hans': '最大新标记', + "en_US": "Max New Tokens", + "zh_Hans": "最大新标记", }, - type='int', + type="int", help={ - 'en_US': 'Maximum number of generated tokens.', - 'zh_Hans': '生成的标记的最大数量。', + "en_US": "Maximum number of generated tokens.", + "zh_Hans": "生成的标记的最大数量。", }, required=False, default=20, @@ -166,30 +175,30 @@ def _get_customizable_model_parameter_rules() -> list[ParameterRule]: ) seed = ParameterRule( - name='seed', + name="seed", label={ - 'en_US': 'Random sampling seed', - 'zh_Hans': '随机采样种子', + "en_US": "Random sampling seed", + "zh_Hans": "随机采样种子", }, - type='int', + type="int", help={ - 'en_US': 'Random sampling seed.', - 'zh_Hans': '随机采样种子。', + "en_US": "Random sampling seed.", + "zh_Hans": "随机采样种子。", }, required=False, precision=0, ) repetition_penalty = ParameterRule( - name='repetition_penalty', + name="repetition_penalty", label={ - 'en_US': 'Repetition Penalty', - 'zh_Hans': '重复惩罚', + "en_US": "Repetition Penalty", + "zh_Hans": "重复惩罚", }, - type='float', + type="float", help={ - 'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.', - 'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。', + "en_US": "The parameter for repetition penalty. 
1.0 means no penalty.", + "zh_Hans": "重复惩罚的参数。1.0 表示没有惩罚。", }, required=False, precision=1, @@ -197,11 +206,9 @@ def _get_customizable_model_parameter_rules() -> list[ParameterRule]: return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty] - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - response: Generator) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: Generator + ) -> Generator: index = -1 for chunk in response: # skip special tokens @@ -210,9 +217,7 @@ def _handle_generate_stream_response(self, index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=chunk.token.text - ) + assistant_prompt_message = AssistantPromptMessage(content=chunk.token.text) if chunk.details: prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -240,15 +245,15 @@ def _handle_generate_stream_response(self, ), ) - def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any + ) -> LLMResult: if isinstance(response, str): content = response else: content = response.generated_text - assistant_prompt_message = AssistantPromptMessage( - content=content - ) + assistant_prompt_message = AssistantPromptMessage(content=content) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -270,15 +275,14 @@ def _get_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str): try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = ("text2text-generation", "text-generation") if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") @@ -287,10 +291,7 @@ def _get_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str): def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index 0f0c166f3ec179..8278d1e64def89 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -6,6 +6,7 @@ import requests from 
huggingface_hub import HfApi, InferenceClient +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -13,40 +14,45 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub -HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' +HUGGINGFACE_ENDPOINT_API = "https://api.endpoints.huggingface.cloud/v2/endpoint/" class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel): - - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: - client = InferenceClient(token=credentials['huggingfacehub_api_token']) + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) execute_model = model - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - execute_model = credentials['huggingfacehub_endpoint_url'] + if credentials["huggingfacehub_api_type"] == "inference_endpoints": + execute_model = credentials["huggingfacehub_endpoint_url"] output = client.post( - json={ - "inputs": texts, - "options": { - "wait_for_model": False, - "use_cache": False - } - }, - model=execute_model) + json={"inputs": texts, "options": {"wait_for_model": False, "use_cache": False}}, model=execute_model + ) embeddings = json.loads(output.decode()) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - embeddings=self._mean_pooling(embeddings), - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=self._mean_pooling(embeddings), usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -56,52 +62,48 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int def validate_credentials(self, model: str, credentials: dict) -> None: try: - if 'huggingfacehub_api_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type must be provided.') + if "huggingfacehub_api_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if 'huggingfacehub_api_token' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub API Token must be provided.') + if "huggingfacehub_api_token" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub API Token must be provided.") - if credentials['huggingfacehub_api_type'] == 'inference_endpoints': - if 'huggingface_namespace' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub User Name / Organization Name must be provided.') + if credentials["huggingfacehub_api_type"] == 
"inference_endpoints": + if "huggingface_namespace" not in credentials: + raise CredentialsValidateFailedError( + "Huggingface Hub User Name / Organization Name must be provided." + ) - if 'huggingfacehub_endpoint_url' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint URL must be provided.') + if "huggingfacehub_endpoint_url" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Endpoint URL must be provided.") - if 'task_type' not in credentials: - raise CredentialsValidateFailedError('Huggingface Hub Task Type must be provided.') + if "task_type" not in credentials: + raise CredentialsValidateFailedError("Huggingface Hub Task Type must be provided.") - if credentials['task_type'] != 'feature-extraction': - raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.') + if credentials["task_type"] != "feature-extraction": + raise CredentialsValidateFailedError("Huggingface Hub Task Type is invalid.") self._check_endpoint_url_model_repository_name(credentials, model) - model = credentials['huggingfacehub_endpoint_url'] + model = credentials["huggingfacehub_endpoint_url"] - elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api': - self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'], - model) + elif credentials["huggingfacehub_api_type"] == "hosted_inference_api": + self._check_hosted_model_task_type(credentials["huggingfacehub_api_token"], model) else: - raise CredentialsValidateFailedError('Huggingface Hub Endpoint Type is invalid.') + raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") - client = InferenceClient(token=credentials['huggingfacehub_api_token']) - client.feature_extraction(text='hello world', model=model) + client = InferenceClient(token=credentials["huggingfacehub_api_token"]) + client.feature_extraction(text="hello world", model=model) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 10000, - 'max_chunks': 1 - } + model_properties={"context_size": 10000, "max_chunks": 1}, ) return entity @@ -128,24 +130,20 @@ def _check_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str try: if not model_info: - raise ValueError(f'Model {model_name} not found.') + raise ValueError(f"Model {model_name} not found.") - if 'inference' in model_info.cardData and not model_info.cardData['inference']: - raise ValueError(f'Inference API has been turned off for this model {model_name}.') + if "inference" in model_info.cardData and not model_info.cardData["inference"]: + raise ValueError(f"Inference API has been turned off for this model {model_name}.") valid_tasks = "feature-extraction" if model_info.pipeline_tag not in valid_tasks: - raise ValueError(f"Model {model_name} is not a valid task, " - f"must be one of {valid_tasks}.") + raise ValueError(f"Model {model_name} is not a valid task, must be one of {valid_tasks}.") except Exception as e: raise CredentialsValidateFailedError(f"{str(e)}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - 
tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -156,7 +154,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -166,25 +164,26 @@ def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str try: url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}' headers = { - 'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}', - 'Content-Type': 'application/json' + "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}', + "Content-Type": "application/json", } response = requests.get(url=url, headers=headers) if response.status_code != 200: - raise ValueError('User Name or Organization Name is invalid.') + raise ValueError("User Name or Organization Name is invalid.") - model_repository_name = '' + model_repository_name = "" for item in response.json().get("items", []): - if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']: + if item.get("status", {}).get("url") == credentials["huggingfacehub_endpoint_url"]: model_repository_name = item.get("model", {}).get("repository") break if model_repository_name != model_name: raise ValueError( - f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.') + f"Model Name {model_name} is invalid. Please check it on the inference endpoints console." + ) except Exception as e: raise ValueError(str(e)) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py index 94544662503974..97d7e28dc646f8 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py @@ -6,6 +6,5 @@ class HuggingfaceTeiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index 34013426de5b77..0bb9a9c8b58449 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -47,29 +47,28 @@ def _invoke( """ if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") try: results = TeiHelper.invoke_rerank(server_url, query, docs) rerank_documents = [] - for result in results: + for result in results: rerank_document = RerankDocument( - index=result['index'], - text=result['text'], - score=result['score'], + index=result["index"], + text=result["text"], + score=result["score"], ) - if score_threshold is None or result['score'] >= score_threshold: + if score_threshold is None or result["score"] >= score_threshold: rerank_documents.append(rerank_document) if top_n is not None and len(rerank_documents) >= top_n: break return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) 
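# ---- editor's note: illustrative sketch, not part of the patch ----
# TeiHelper.invoke_rerank, called above, reduces to a single POST against the
# text-embeddings-inference /rerank route (see tei_helper.py further down).
# Assuming a reachable TEI server, the equivalent standalone call is:
import httpx

def tei_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
    response = httpx.post(
        f"{server_url.removesuffix('/')}/rerank",
        json={"query": query, "texts": docs, "return_text": True},
    )
    response.raise_for_status()
    # each item carries "index", "text" and "score" -- the keys the rerank
    # loop above consumes
    return response.json()
# --------------------------------------------------------------------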
+ raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -80,21 +79,21 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - server_url = credentials['server_url'] + server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) - if extra_args.model_type != 'reranker': - raise CredentialsValidateFailedError('Current model is not a rerank model') + if extra_args.model_type != "reranker": + raise CredentialsValidateFailedError("Current model is not a rerank model") - credentials['context_size'] = extra_args.max_input_length + credentials["context_size"] = extra_args.max_input_length self.invoke( model=model, credentials=credentials, - query='Whose kasumi', + query="Whose kasumi", docs=[ 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', - 'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ', - 'and she leads a team named PopiParty.', + "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", + "and she leads a team named PopiParty.", ], score_threshold=0.8, ) @@ -119,7 +118,7 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ used to define customizable model schema """ @@ -129,7 +128,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 2aa785c89d27e6..81ab2492144e86 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -31,16 +31,16 @@ def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraPa with cache_lock: if model_name not in cache: cache[model_name] = { - 'expires': time() + 300, - 'value': TeiHelper._get_tei_extra_parameter(server_url), + "expires": time() + 300, + "value": TeiHelper._get_tei_extra_parameter(server_url), } - return cache[model_name]['value'] + return cache[model_name]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: @@ -52,40 +52,39 @@ def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: get tei model extra parameter like model_type, max_input_length, max_batch_requests """ - url = str(URL(server_url) / 'info') + url = str(URL(server_url) / "info") - # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 + # this method is surrounded by a lock, and default requests may hang forever, + # so we just set an Adapter with 
max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) try: response = session.get(url, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}') + raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: raise RuntimeError( - f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}' + f"get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}" ) response_json = response.json() - model_type = response_json.get('model_type', {}) + model_type = response_json.get("model_type", {}) if len(model_type.keys()) < 1: - raise RuntimeError('model_type is empty') + raise RuntimeError("model_type is empty") model_type = list(model_type.keys())[0] - if model_type not in ['embedding', 'reranker']: - raise RuntimeError(f'invalid model_type: {model_type}') - - max_input_length = response_json.get('max_input_length', 512) - max_client_batch_size = response_json.get('max_client_batch_size', 1) + if model_type not in {"embedding", "reranker"}: + raise RuntimeError(f"invalid model_type: {model_type}") + + max_input_length = response_json.get("max_input_length", 512) + max_client_batch_size = response_json.get("max_client_batch_size", 1) return TeiModelExtraParameter( - model_type=model_type, - max_input_length=max_input_length, - max_client_batch_size=max_client_batch_size + model_type=model_type, max_input_length=max_input_length, max_client_batch_size=max_client_batch_size ) - + @staticmethod def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: """ @@ -116,12 +115,12 @@ def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: :param texts: texts to tokenize """ resp = httpx.post( - f'{server_url}/tokenize', - json={'inputs': texts}, + f"{server_url}/tokenize", + json={"inputs": texts}, ) resp.raise_for_status() return resp.json() - + @staticmethod def invoke_embeddings(server_url: str, texts: list[str]) -> dict: """ @@ -149,8 +148,8 @@ def invoke_embeddings(server_url: str, texts: list[str]) -> dict: """ # Use OpenAI compatible API here, which has usage tracking resp = httpx.post( - f'{server_url}/v1/embeddings', - json={'input': texts}, + f"{server_url}/v1/embeddings", + json={"input": texts}, ) resp.raise_for_status() return resp.json() @@ -173,11 +172,11 @@ def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]: :param texts: texts to rerank :param candidates: candidates to rerank """ - params = {'query': query, 'texts': docs, 'return_text': True} + params = {"query": query, "texts": docs, "return_text": True} response = httpx.post( - server_url + '/rerank', + server_url + "/rerank", json=params, ) - response.raise_for_status() + response.raise_for_status() return response.json() diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index 6897b87f6d7525..a0917630a9f9c7 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -1,6 
+1,7 @@ import time from typing import Optional +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -23,7 +24,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): """ def _invoke( - self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -38,13 +44,12 @@ def _invoke( :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - server_url = credentials['server_url'] - - if server_url.endswith('/'): - server_url = server_url[:-1] + server_url = credentials["server_url"] + server_url = server_url.removesuffix("/") # get model properties context_size = self._get_context_size(model, credentials) @@ -58,7 +63,6 @@ def _invoke( batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): - # Check if the number of tokens is larger than the context size num_tokens = len(tokenize_result) @@ -66,20 +70,22 @@ def _invoke( # Find the best cutoff point pre_special_token_count = 0 for token in tokenize_result: - if token['special']: + if token["special"]: pre_special_token_count += 1 else: break - rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count + rest_special_token_count = ( + len([token for token in tokenize_result if token["special"]]) - pre_special_token_count + ) # Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit token_cutoff = context_size - rest_special_token_count - 20 # Find the cutoff index cutpoint_token = tokenize_result[token_cutoff] - cutoff = cutpoint_token['start'] + cutoff = cutpoint_token["start"] - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -92,12 +98,12 @@ def _invoke( for i in _iter: iter_texts = inputs[i : i + max_chunks] results = TeiHelper.invoke_embeddings(server_url, iter_texts) - embeddings = results['data'] - embeddings = [embedding['embedding'] for embedding in embeddings] + embeddings = results["data"] + embeddings = [embedding["embedding"] for embedding in embeddings] batched_embeddings.extend(embeddings) - usage = results['usage'] - used_tokens += usage['total_tokens'] + usage = results["usage"] + used_tokens += usage["total_tokens"] except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) @@ -117,10 +123,9 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int :return: """ num_tokens = 0 - server_url = credentials['server_url'] + server_url = credentials["server_url"] - if server_url.endswith('/'): - server_url = server_url[:-1] + server_url = server_url.removesuffix("/") batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) num_tokens = sum(len(tokens) for tokens in batch_tokens) @@ -135,15 +140,15 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - server_url = credentials['server_url'] + 
server_url = credentials["server_url"] extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) print(extra_args) - if extra_args.model_type != 'embedding': - raise CredentialsValidateFailedError('Current model is not a embedding model') + if extra_args.model_type != "embedding": + raise CredentialsValidateFailedError("Current model is not an embedding model") - credentials['context_size'] = extra_args.max_input_length - credentials['max_chunks'] = extra_args.max_client_batch_size - self._invoke(model=model, credentials=credentials, texts=['ping']) + credentials["context_size"] = extra_args.max_input_length + credentials["max_chunks"] = extra_args.max_client_batch_size + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -184,7 +189,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em return usage - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ used to define customizable model schema """ @@ -195,8 +200,8 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ - ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)), - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)), + ModelPropertyKey.MAX_CHUNKS: int(credentials.get("max_chunks", 1)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), }, parameter_rules=[], ) diff --git a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py index 5a298d33acac5c..e65772e7dda3a1 100644 --- a/api/core/model_runtime/model_providers/hunyuan/hunyuan.py +++ b/api/core/model_runtime/model_providers/hunyuan/hunyuan.py @@ -6,8 +6,8 @@ logger = logging.getLogger(__name__) -class HunyuanProvider(ModelProvider): +class HunyuanProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +19,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `hunyuan-standard` model for validate, - model_instance.validate_credentials( - model='hunyuan-standard', - credentials=credentials - ) + model_instance.validate_credentials(model="hunyuan-standard", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml index 2c1b981f8504a4..f494984443cb42 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/llm/_position.yaml @@ -2,3 +2,5 @@ - hunyuan-standard - hunyuan-standard-256k - hunyuan-pro +- hunyuan-turbo +- hunyuan-vision diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml index 1f94a8623b494c..8504b90eb3cf74 100644 --- 
a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml @@ -1,7 +1,7 @@ -model: hunyuan-standard-256k +model: hunyuan-standard-256K label: - zh_Hans: hunyuan-standard-256k - en_US: hunyuan-standard-256k + zh_Hans: hunyuan-standard-256K + en_US: hunyuan-standard-256K model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml new file mode 100644 index 00000000000000..4837fed4bae563 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-turbo.yaml @@ -0,0 +1,38 @@ +model: hunyuan-turbo +label: + zh_Hans: hunyuan-turbo + en_US: hunyuan-turbo +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. + required: false + default: true +pricing: + input: '0.015' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-vision.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-vision.yaml new file mode 100644 index 00000000000000..9edc7f4710f9a5 --- /dev/null +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-vision.yaml @@ -0,0 +1,39 @@ +model: hunyuan-vision +label: + zh_Hans: hunyuan-vision + en_US: hunyuan-vision +model_type: llm +features: + - agent-thought + - tool-call + - multi-tool-call + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 8000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8000 + - name: enable_enhance + label: + zh_Hans: 功能增强 + en_US: Enable Enhancement + type: boolean + help: + zh_Hans: 功能增强(如搜索)开关,关闭时将直接由主模型生成回复内容,可以降低响应时延(对于流式输出时的首字时延尤为明显)。但在少数场景里,回复效果可能会下降。 + en_US: Allow the model to perform external search to enhance the generation results. 
+ required: false + default: true +pricing: + input: '0.018' + output: '0.018' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py index 0bdf6ec005056b..2014de8516bc11 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/hunyuan/llm/llm.py @@ -1,6 +1,7 @@ import json import logging from collections.abc import Generator +from typing import cast from tencentcloud.common import credential from tencentcloud.common.exception import TencentCloudSDKException @@ -11,9 +12,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, + PromptMessageContentType, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) @@ -23,21 +27,27 @@ logger = logging.getLogger(__name__) -class HunyuanLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: +class HunyuanLargeLanguageModel(LargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = self._setup_hunyuan_client(credentials) request = models.ChatCompletionsRequest() messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages) custom_parameters = { - 'Temperature': model_parameters.get('temperature', 0.0), - 'TopP': model_parameters.get('top_p', 1.0), - 'EnableEnhancement': model_parameters.get('enable_enhance', True) + "Temperature": model_parameters.get("temperature", 0.0), + "TopP": model_parameters.get("top_p", 1.0), + "EnableEnhancement": model_parameters.get("enable_enhance", True), } params = { @@ -47,16 +57,19 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes **custom_parameters, } # add Tools and ToolChoice - if (tools and len(tools) > 0): - params['ToolChoice'] = "auto" - params['Tools'] = [{ - "Type": "function", - "Function": { - "Name": tool.name, - "Description": tool.description, - "Parameters": json.dumps(tool.parameters) + if tools and len(tools) > 0: + params["ToolChoice"] = "auto" + params["Tools"] = [ + { + "Type": "function", + "Function": { + "Name": tool.name, + "Description": tool.description, + "Parameters": json.dumps(tool.parameters), + }, } - } for tool in tools] + for tool in tools + ] request.from_json_string(json.dumps(params)) response = client.ChatCompletions(request) @@ -76,22 +89,19 @@ def validate_credentials(self, model: str, credentials: dict) -> None: req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise 
CredentialsValidateFailedError(f"Credentials validation failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -106,92 +116,116 @@ def _convert_prompt_messages_to_dicts(self, prompt_messages: list[PromptMessage] for message in prompt_messages: if isinstance(message, AssistantPromptMessage): tool_calls = message.tool_calls - if (tool_calls and len(tool_calls) > 0): + if tool_calls and len(tool_calls) > 0: dict_tool_calls = [ { "Id": tool_call.id, "Type": tool_call.type, "Function": { "Name": tool_call.function.name, - "Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}" - } - } for tool_call in tool_calls] - - dict_list.append({ - "Role": message.role.value, - # fix set content = "" while tool_call request - # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time. - "Content": " ", # message.content if (message.content is not None) else "", - "ToolCalls": dict_tool_calls - }) + "Arguments": tool_call.function.arguments + if (tool_call.function.arguments == "") + else "{}", + }, + } + for tool_call in tool_calls + ] + + dict_list.append( + { + "Role": message.role.value, + # fix set content = "" while tool_call request + # fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter + # message:Messages Content and Contents not allowed empty at the same time. + "Content": " ", # message.content if (message.content is not None) else "", + "ToolCalls": dict_tool_calls, + } + ) else: - dict_list.append({ "Role": message.role.value, "Content": message.content }) + dict_list.append({"Role": message.role.value, "Content": message.content}) elif isinstance(message, ToolPromptMessage): - tool_execute_result = { "result": message.content } - content =json.dumps(tool_execute_result, ensure_ascii=False) - dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id }) + tool_execute_result = {"result": message.content} + content = json.dumps(tool_execute_result, ensure_ascii=False) + dict_list.append({"Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id}) + elif isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + dict_list.append({"Role": message.role.value, "Content": message.content}) + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + sub_message_dict = {"Type": "text", "Text": message_content.data} + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "Type": "image_url", + "ImageUrl": {"Url": message_content.data}, + } + sub_messages.append(sub_message_dict) + dict_list.append({"Role": message.role.value, "Contents": sub_messages}) else: - dict_list.append({ "Role": message.role.value, "Content": message.content }) + dict_list.append({"Role": message.role.value, "Content": message.content}) return dict_list def 
_handle_stream_chat_response(self, model, credentials, prompt_messages, resp): - tool_call = None tool_calls = [] for index, event in enumerate(resp): logging.debug("_handle_stream_chat_response, event: %s", event) - data_str = event['data'] + data_str = event["data"] data = json.loads(data_str) - choices = data.get('Choices', []) + choices = data.get("Choices", []) if not choices: continue choice = choices[0] - delta = choice.get('Delta', {}) - message_content = delta.get('Content', '') - finish_reason = choice.get('FinishReason', '') + delta = choice.get("Delta", {}) + message_content = delta.get("Content", "") + finish_reason = choice.get("FinishReason", "") - usage = data.get('Usage', {}) - prompt_tokens = usage.get('PromptTokens', 0) - completion_tokens = usage.get('CompletionTokens', 0) + usage = data.get("Usage", {}) + prompt_tokens = usage.get("PromptTokens", 0) + completion_tokens = usage.get("CompletionTokens", 0) - response_tool_calls = delta.get('ToolCalls') - if (response_tool_calls is not None): + response_tool_calls = delta.get("ToolCalls") + if response_tool_calls is not None: new_tool_calls = self._extract_response_tool_calls(response_tool_calls) - if (len(new_tool_calls) > 0): + if len(new_tool_calls) > 0: new_tool_call = new_tool_calls[0] - if (tool_call is None): tool_call = new_tool_call - elif (tool_call.id != new_tool_call.id): + if tool_call is None: + tool_call = new_tool_call + elif tool_call.id != new_tool_call.id: tool_calls.append(tool_call) tool_call = new_tool_call else: tool_call.function.name += new_tool_call.function.name tool_call.function.arguments += new_tool_call.function.arguments - if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0): + if tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0: tool_calls.append(tool_call) tool_call = None - assistant_prompt_message = AssistantPromptMessage( - content=message_content, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=message_content, tool_calls=[]) # rewrite content = "" while tool_call to avoid show content on web page - if (len(tool_calls) > 0): assistant_prompt_message.content = "" - + if len(tool_calls) > 0: + assistant_prompt_message.content = "" + # add tool_calls to assistant_prompt_message - if (finish_reason == 'tool_calls'): + if finish_reason == "tool_calls": assistant_prompt_message.tool_calls = tool_calls tool_call = None tool_calls = [] - if (len(finish_reason) > 0): + if len(finish_reason) > 0: usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) delta_chunk = LLMResultChunkDelta( index=index, - role=delta.get('Role', 'assistant'), + role=delta.get("Role", "assistant"), message=assistant_prompt_message, usage=usage, finish_reason=finish_reason, @@ -212,8 +246,9 @@ def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp ) def _handle_chat_response(self, credentials, model, prompt_messages, response): - usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens, - response.Usage.CompletionTokens) + usage = self._calc_response_usage( + model, credentials, response.Usage.PromptTokens, response.Usage.CompletionTokens + ) assistant_prompt_message = AssistantPromptMessage() assistant_prompt_message.content = response.Choices[0].Message.Content result = LLMResult( @@ -225,8 +260,13 @@ def _handle_chat_response(self, credentials, model, prompt_messages, response): return result - def 
get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: if len(prompt_messages) == 0: return 0 prompt = self._convert_messages_to_prompt(prompt_messages) @@ -241,10 +281,7 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -287,10 +324,8 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] return { InvokeError: [TencentCloudSDKException], } - - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -300,17 +335,14 @@ def _extract_response_tool_calls(self, tool_calls = [] if response_tool_calls: for response_tool_call in response_tool_calls: - response_function = response_tool_call.get('Function', {}) + response_function = response_tool_call.get("Function", {}) function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function.get('Name', ''), - arguments=response_function.get('Arguments', '') + name=response_function.get("Name", ""), arguments=response_function.get("Arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get('Id', 0), - type='function', - function=function + id=response_tool_call.get("Id", 0), type="function", function=function ) tool_calls.append(tool_call) - return tool_calls \ No newline at end of file + return tool_calls diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 64d8dcf795f1c8..b6d857cb37cba0 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -9,6 +9,7 @@ from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.hunyuan.v20230901 import hunyuan_client, models +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( @@ -19,14 +20,20 @@ logger = logging.getLogger(__name__) + class HunyuanTextEmbeddingModel(TextEmbeddingModel): """ Model class for Hunyuan text embedding model. 
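# ---- editor's note: illustrative sketch, not part of the patch ----
# The _invoke below issues one GetEmbedding request per input text. Stripped
# to its essentials it looks like this; the client comes from
# _setup_hunyuan_client, and response.Data[0].Embedding is an assumption
# about the Hunyuan response shape (the extraction line is elided from the
# hunk shown below).
import json
from tencentcloud.hunyuan.v20230901 import models

def embed_texts(client, texts: list[str]) -> tuple[list[list[float]], int]:
    embeddings, total_tokens = [], 0
    for text in texts:
        request = models.GetEmbeddingRequest()
        request.from_json_string(json.dumps({"Input": text}))
        response = client.GetEmbedding(request)
        embeddings.append(response.Data[0].Embedding)
        total_tokens += response.Usage.TotalTokens
    return embeddings, total_tokens
# --------------------------------------------------------------------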
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -34,12 +41,13 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - if model != 'hunyuan-embedding': - raise ValueError('Invalid model name') - + if model != "hunyuan-embedding": + raise ValueError("Invalid model name") + client = self._setup_hunyuan_client(credentials) embeddings = [] @@ -47,9 +55,7 @@ def _invoke(self, model: str, credentials: dict, for input in texts: request = models.GetEmbeddingRequest() - params = { - "Input": input - } + params = {"Input": input} request.from_json_string(json.dumps(params)) response = client.GetEmbedding(request) usage = response.Usage.TotalTokens @@ -60,11 +66,7 @@ def _invoke(self, model: str, credentials: dict, result = TextEmbeddingResult( model=model, embeddings=embeddings, - usage=self._calc_response_usage( - model=model, - credentials=credentials, - tokens=token_usage - ) + usage=self._calc_response_usage(model=model, credentials=credentials, tokens=token_usage), ) return result @@ -79,22 +81,19 @@ def validate_credentials(self, model: str, credentials: dict) -> None: req = models.ChatCompletionsRequest() params = { "Model": model, - "Messages": [{ - "Role": "user", - "Content": "hello" - }], + "Messages": [{"Role": "user", "Content": "hello"}], "TopP": 1, "Temperature": 0, - "Stream": False + "Stream": False, } req.from_json_string(json.dumps(params)) client.ChatCompletions(req) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") def _setup_hunyuan_client(self, credentials): - secret_id = credentials['secret_id'] - secret_key = credentials['secret_key'] + secret_id = credentials["secret_id"] + secret_key = credentials["secret_key"] cred = credential.Credential(secret_id, secret_key) httpProfile = HttpProfile() httpProfile.endpoint = "hunyuan.tencentcloudapi.com" @@ -102,7 +101,7 @@ def _setup_hunyuan_client(self, credentials): clientProfile.httpProfile = httpProfile client = hunyuan_client.HunyuanClient(cred, "", clientProfile) return client - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -114,10 +113,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -128,11 +124,11 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -146,7 +142,7 @@ def _invoke_error_mapping(self) -> 
dict[type[InvokeError], list[type[Exception]] return { InvokeError: [TencentCloudSDKException], } - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ Get number of tokens for given prompt messages @@ -170,4 +166,4 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int # response = client.GetTokenCount(request) # num_tokens += response.TokenCount - return num_tokens \ No newline at end of file + return num_tokens diff --git a/api/core/model_runtime/model_providers/jina/jina.py b/api/core/model_runtime/model_providers/jina/jina.py index cde4313495b4a8..186a0a0fa7e6dc 100644 --- a/api/core/model_runtime/model_providers/jina/jina.py +++ b/api/core/model_runtime/model_providers/jina/jina.py @@ -8,7 +8,6 @@ class JinaProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,14 +18,11 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) - # Use `jina-embeddings-v2-base-en` model for validate, + # Use `jina-embeddings-v3` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='jina-embeddings-v2-base-en', - credentials=credentials - ) + model_instance.validate_credentials(model="jina-embeddings-v3", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/jina/jina.yaml b/api/core/model_runtime/model_providers/jina/jina.yaml index 23e18ad75f6885..970b22965b5d29 100644 --- a/api/core/model_runtime/model_providers/jina/jina.yaml +++ b/api/core/model_runtime/model_providers/jina/jina.yaml @@ -1,6 +1,6 @@ provider: jina label: - en_US: Jina + en_US: Jina AI description: en_US: Embedding and Rerank Model Supported icon_small: @@ -11,7 +11,7 @@ background: "#EFFDFD" help: title: en_US: Get your API key from Jina AI - zh_Hans: 从 Jina 获取 API Key + zh_Hans: 从 Jina AI 获取 API Key url: en_US: https://jina.ai/ supported_model_types: diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index de7e038b9f31a6..03502076517f70 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -22,9 +22,16 @@ class JinaRerankModel(RerankModel): Model class for Jina rerank model. 
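# ---- editor's note: illustrative sketch, not part of the patch ----
# The Jina rerank _invoke below is one authenticated POST. A minimal
# standalone equivalent (the model name is only an example) would be:
import httpx

def jina_rerank(api_key: str, query: str, docs: list[str], top_n: int = 3) -> list[dict]:
    response = httpx.post(
        "https://api.jina.ai/v1/rerank",
        json={"model": "jina-reranker-v2-base-multilingual", "query": query, "documents": docs, "top_n": top_n},
        headers={"Authorization": f"Bearer {api_key}"},
    )
    response.raise_for_status()
    # results expose "index" and "relevance_score"; "document" may be absent
    # on some OpenAI-compatible backends (e.g. llama.cpp), hence the
    # docs[index] fallback in _invoke
    return response.json()["results"]
# --------------------------------------------------------------------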
""" - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -40,37 +47,39 @@ def _invoke(self, model: str, credentials: dict, if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.jina.ai/v1') - if base_url.endswith('/'): - base_url = base_url[:-1] + base_url = credentials.get("base_url", "https://api.jina.ai/v1") + base_url = base_url.removesuffix("/") try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) - response.raise_for_status() + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: + index = result["index"] + if "document" in result: + text = result["document"]["text"] + else: + # llama.cpp rerank maynot return original documents + text = docs[index] + rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=index, + text=text, + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -81,7 +90,6 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -92,7 +100,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. 
Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -105,23 +113,21 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml b/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml new file mode 100644 index 00000000000000..4e5374dc9d733d --- /dev/null +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina-embeddings-v3.yaml @@ -0,0 +1,9 @@ +model: jina-embeddings-v3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 2048 +pricing: + input: '0.001' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index 50f8c73ed9e929..d80cbfa83d6425 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -14,19 +14,19 @@ def _get_tokenizer(cls): with cls._lock: if cls._tokenizer is None: base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') + gpt2_tokenizer_path = join(dirname(base_path), "tokenizer") cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) return cls._tokenizer @classmethod def _get_num_tokens_by_jina_base(cls, text: str) -> int: """ - use jina tokenizer to get num tokens + use jina tokenizer to get num tokens """ tokenizer = cls._get_tokenizer() tokens = tokenizer.encode(text) return len(tokens) - + @classmethod def get_num_tokens(cls, text: str) -> int: - return cls._get_num_tokens_by_jina_base(text) \ No newline at end of file + return cls._get_num_tokens_by_jina_base(text) diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 23203491e656fe..49c558f4a44ffa 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -4,6 +4,7 @@ from requests import post +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from 
core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -24,11 +25,41 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Jina text embedding model. """ - api_base: str = 'https://api.jina.ai/v1' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.jina.ai/v1" + + def _to_payload(self, model: str, texts: list[str], credentials: dict, input_type: EmbeddingInputType) -> dict: + """ + Build the request payload + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: request payload + """ + + def transform_jina_input_text(model, text): + if model == "jina-clip-v1": + return {"text": text} + return text + + data = {"model": model, "input": [transform_jina_input_text(model, text) for text in texts]} + + # model specific parameters + if model == "jina-embeddings-v3": + # set `task` type according to input type for the best performance + data["task"] = "retrieval.query" if input_type == EmbeddingInputType.QUERY else "retrieval.passage" + + return data + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -36,31 +67,20 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") - base_url = credentials.get('base_url', self.api_base) - if base_url.endswith('/'): - base_url = base_url[:-1] + base_url = credentials.get("base_url", self.api_base) + base_url = base_url.removesuffix("/") - url = base_url + '/embeddings' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - def transform_jina_input_text(model, text): - if model == 'jina-clip-v1': - return {"text": text} - return text - - data = { - 'model': model, - 'input': [transform_jina_input_text(model, text) for text in texts] - } + data = self._to_payload(model=model, texts=texts, credentials=credentials, input_type=input_type) try: response = post(url, headers=headers, data=dumps(data)) @@ -70,7 +90,7 @@ def transform_jina_input_text(model, text): if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -81,25 +101,20 @@ def transform_jina_input_text(model, text): raise InvokeBadRequestError(msg) except JSONDecodeError as e: raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: - raise InvokeServerUnavailableError( - f"Failed to convert response to json: {e} with text: {response.text}") + raise
InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -128,30 +143,18 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except Exception as e: - raise CredentialsValidateFailedError( - f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError, - InvokeBadRequestError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -165,10 +168,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -179,24 +179,21 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get('context_size')) - } + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, ) return entity diff --git a/api/core/model_runtime/model_providers/leptonai/leptonai.py b/api/core/model_runtime/model_providers/leptonai/leptonai.py index b035c31ac51453..34a55ff1924cf8 100644 --- a/api/core/model_runtime/model_providers/leptonai/leptonai.py +++ b/api/core/model_runtime/model_providers/leptonai/leptonai.py @@ -6,8 +6,8 @@ logger = logging.getLogger(__name__) -class LeptonAIProvider(ModelProvider): +class LeptonAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider 
credentials @@ -18,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='llama2-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="llama2-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/leptonai/llm/llm.py b/api/core/model_runtime/model_providers/leptonai/llm/llm.py index 523309bac579a3..3d69417e45da72 100644 --- a/api/core/model_runtime/model_providers/leptonai/llm/llm.py +++ b/api/core/model_runtime/model_providers/leptonai/llm/llm.py @@ -8,18 +8,25 @@ class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_PREFIX_MAP = { - 'llama2-7b': 'llama2-7b', - 'gemma-7b': 'gemma-7b', - 'mistral-7b': 'mistral-7b', - 'mixtral-8x7b': 'mixtral-8x7b', - 'llama3-70b': 'llama3-70b', - 'llama2-13b': 'llama2-13b', - } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + "llama2-7b": "llama2-7b", + "gemma-7b": "gemma-7b", + "mistral-7b": "mistral-7b", + "mixtral-8x7b": "mixtral-8x7b", + "llama3-70b": "llama3-70b", + "llama2-13b": "llama2-13b", + } + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -29,6 +36,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = f'https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1' - \ No newline at end of file + credentials["mode"] = "chat" + credentials["endpoint_url"] = f"https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1" diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 1009995c5868a1..756c5571d404bf 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import cast +from typing import Optional, cast from httpx import Timeout from openai import ( @@ -52,29 +52,48 @@ class LocalAILanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) - - def 
get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages, tools=tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for baichuan model - LocalAI does not supports + Calculate num tokens for baichuan model + LocalAI does not support it """ def tokens(text: str): """ - We could not determine which tokenizer to use, cause the model is customized. - So we use gpt2 tokenizer to calculate the num tokens for convenience. + We could not determine which tokenizer to use, because the model is customized. + So we use gpt2 tokenizer to calculate the num tokens for convenience. """ return self._get_num_tokens_by_gpt2(text) @@ -87,10 +106,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -142,30 +161,30 @@ def tokens(text: str): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -180,102 +199,104 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={ - 'max_tokens': 10, - }, stop=[],
stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={ + "max_tokens": 10, + }, + stop=[], + stream=False, + ) except Exception as ex: - raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') + raise CredentialsValidateFailedError(f"Invalid credentials {str(ex)}") - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: completion_model = None - if credentials['completion_type'] == 'chat_completion': + if credentials["completion_type"] == "chat_completion": completion_model = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_model = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, max=2048, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] - model_properties = { - ModelPropertyKey.MODE: completion_model, - } if completion_model else {} + model_properties = ( + { + ModelPropertyKey.MODE: completion_model, + } + if completion_model + else {} + ) - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) + model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get("context_size", "2048")) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, - parameter_rules=rules + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: kwargs = self._to_client_kwargs(credentials) # init model client client = OpenAI(**kwargs) model_name = model - completion_type = credentials['completion_type'] + completion_type = credentials["completion_type"] extra_model_kwargs = { "timeout": 60, } if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if tools and len(tools) > 0: - extra_model_kwargs['functions'] = [ - helper.dump_model(tool) for 
tool in tools - ] + extra_model_kwargs["functions"] = [helper.dump_model(tool) for tool in tools] - if completion_type == 'chat_completion': + if completion_type == "chat_completion": result = client.chat.completions.create( messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], model=model_name, @@ -283,36 +304,32 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM **model_parameters, **extra_model_kwargs, ) - elif completion_type == 'completion': + elif completion_type == "completion": result = client.completions.create( prompt=self._convert_prompt_message_to_completion_prompts(prompt_messages), model=model, stream=stream, **model_parameters, - **extra_model_kwargs + **extra_model_kwargs, ) else: raise ValueError(f"Unknown completion type {completion_type}") if stream: - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) return self._handle_chat_generate_stream_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) - if completion_type == 'completion': + if completion_type == "completion": return self._handle_completion_generate_response( - model=model, credentials=credentials, response=result, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, prompt_messages=prompt_messages ) return self._handle_chat_generate_response( - model=model, credentials=credentials, response=result, tools=tools, - prompt_messages=prompt_messages + model=model, credentials=credentials, response=result, tools=tools, prompt_messages=prompt_messages ) def _to_client_kwargs(self, credentials: dict) -> dict: @@ -322,13 +339,13 @@ def _to_client_kwargs(self, credentials: dict) -> dict: :param credentials: credentials dict :return: client kwargs """ - if not credentials['server_url'].endswith('/'): - credentials['server_url'] += '/' + if not credentials["server_url"].endswith("/"): + credentials["server_url"] += "/" client_kwargs = { "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "api_key": "1", - "base_url": str(URL(credentials['server_url']) / 'v1'), + "base_url": str(URL(credentials["server_url"]) / "v1"), } return client_kwargs @@ -349,7 +366,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -359,11 +376,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message = cast(ToolPromptMessage, message) message_dict = { "role": "user", - "content": [{ - "type": "tool_result", - "tool_use_id": message.tool_call_id, - "content": message.content - }] + "content": [{"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}], } else: raise ValueError(f"Unknown message type {type(message)}") @@ -374,27 +387,29 @@ def 
_convert_prompt_message_to_completion_prompts(self, messages: list[PromptMes """ Convert PromptMessage to completion prompts """ - prompts = '' + prompts = "" for message in messages: if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - prompts += f'{message.content}\n' + prompts += f"{message.content}\n" else: raise ValueError(f"Unknown message type {type(message)}") return prompts - def _handle_completion_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Completion, - ) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Completion, + ) -> LLMResult: """ Handle llm chat response @@ -411,18 +426,16 @@ def _handle_completion_generate_response(self, model: str, assistant_message = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) ) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -434,11 +447,14 @@ def _handle_completion_generate_response(self, model: str, return response - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ChatCompletion, - tools: list[PromptMessageTool]) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: ChatCompletion, + tools: list[PromptMessageTool], + ) -> LLMResult: """ Handle llm chat response @@ -459,16 +475,14 @@ def _handle_chat_generate_response(self, model: str, tool_calls = self._extract_response_tool_calls([function_calls] if function_calls else []) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -480,12 +494,15 @@ def 
_handle_chat_generate_response(self, model: str, return response - def _handle_completion_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[Completion], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_completion_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[Completion], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -494,17 +511,11 @@ def _handle_completion_generate_stream_response(self, model: str, delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) prompt_tokens = self._get_num_tokens_by_gpt2( self._convert_prompt_message_to_completion_prompts(prompt_messages) @@ -512,8 +523,12 @@ def _handle_completion_generate_stream_response(self, model: str, completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -523,7 +538,7 @@ def _handle_completion_generate_stream_response(self, model: str, index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -539,12 +554,15 @@ def _handle_completion_generate_stream_response(self, model: str, full_response += delta.text - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Stream[ChatCompletionChunk], - tools: list[PromptMessageTool]) -> Generator: - full_response = '' + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Stream[ChatCompletionChunk], + tools: list[PromptMessageTool], + ) -> Generator: + full_response = "" for chunk in response: if len(chunk.choices) == 0: @@ -552,7 +570,7 @@ def _handle_chat_generate_stream_response(self, model: str, delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -560,26 +578,28 @@ def _handle_chat_generate_stream_response(self, model: str, if delta.delta.function_call: function_calls = [delta.delta.function_call] - assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else []) + assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or []) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - 
content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content or "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, @@ -589,7 +609,7 @@ def _handle_chat_generate_stream_response(self, model: str, index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage + usage=usage, ), ) else: @@ -605,9 +625,9 @@ def _handle_chat_generate_stream_response(self, model: str, full_response += delta.delta.content - def _extract_response_tool_calls(self, - response_function_calls: list[FunctionCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_function_calls: list[FunctionCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -618,15 +638,10 @@ def _extract_response_tool_calls(self, if response_function_calls: for response_tool_call in response_function_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.name, - arguments=response_tool_call.arguments + name=response_tool_call.name, arguments=response_tool_call.arguments ) - tool_call = AssistantPromptMessage.ToolCall( - id=0, - type='function', - function=function - ) + tool_call = AssistantPromptMessage.ToolCall(id=0, type="function", function=function) tool_calls.append(tool_call) return tool_calls @@ -651,15 +666,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError - ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError + PermissionDeniedError, ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/localai/localai.py b/api/core/model_runtime/model_providers/localai/localai.py index 6d2278fd541b1f..4ff898052b380d 100644 --- a/api/core/model_runtime/model_providers/localai/localai.py +++ b/api/core/model_runtime/model_providers/localai/localai.py @@ -6,6 +6,5 @@ class LocalAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/localai/rerank/rerank.py b/api/core/model_runtime/model_providers/localai/rerank/rerank.py index c8ba9a6c7c1c8d..075b44658dcc33 100644 --- a/api/core/model_runtime/model_providers/localai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/localai/rerank/rerank.py @@ -25,9 +25,16 @@ class LocalaiRerankModel(RerankModel): LocalAI 
rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here. """ - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -43,45 +50,45 @@ def _invoke(self, model: str, credentials: dict, if len(docs) == 0: return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model - + if not server_url: - raise CredentialsValidateFailedError('server_url is required') + raise CredentialsValidateFailedError("server_url is required") if not model_name: - raise CredentialsValidateFailedError('model_name is required') - + raise CredentialsValidateFailedError("model_name is required") + url = server_url - headers = { - 'Authorization': f"Bearer {credentials.get('api_key')}", - 'Content-Type': 'application/json' - } + headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"} - data = { - "model": model_name, - "query": query, - "documents": docs, - "top_n": top_n - } + data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n} try: - response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10) - response.raise_for_status() + response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=10) + response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: + index = result["index"] + if "document" in result: + text = result["document"]["text"] + else: + # llama.cpp rerank may not return original documents + text = docs[index] + rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=index, + text=text, + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) except httpx.HTTPStatusError as e: - raise InvokeServerUnavailableError(str(e)) + raise InvokeServerUnavailableError(str(e)) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -92,7 +99,6 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke( model=model, credentials=credentials, @@ -103,7 +109,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States.
Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -116,21 +122,21 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] return { InvokeConnectionError: [httpx.ConnectError], InvokeServerUnavailableError: [httpx.RemoteProtocolError], - InvokeRateLimitError: [], - InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], } - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={} + model_properties={}, ) return entity diff --git a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py index d7403aff4ffa7f..260a35d9d4f13c 100644 --- a/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/localai/speech2text/speech2text.py @@ -32,8 +32,8 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional :param user: unique user id :return: text for given audio file """ - - url = str(URL(credentials['server_url']) / "v1/audio/transcriptions") + + url = str(URL(credentials["server_url"]) / "v1/audio/transcriptions") data = {"model": model} files = {"file": file} @@ -42,7 +42,7 @@ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional prepared_request = session.prepare_request(request) response = session.send(prepared_request) - if 'error' in response.json(): + if "error" in response.json(): raise InvokeServerUnavailableError("Empty response") return response.json()["text"] @@ -58,7 +58,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -66,36 +66,24 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError - ], + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError], } - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( 
- en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 954c9d10f2a67f..0111c3362a5929 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -5,6 +5,7 @@ from requests import post from yarl import URL +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -22,11 +23,17 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): """ - Model class for Jina text embedding model. + Model class for LocalAI text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -34,42 +41,37 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ if len(texts) != 1: - raise InvokeBadRequestError('Only one text is supported') + raise InvokeBadRequestError("Only one text is supported") - server_url = credentials['server_url'] + server_url = credentials["server_url"] model_name = model if not server_url: - raise CredentialsValidateFailedError('server_url is required') + raise CredentialsValidateFailedError("server_url is required") if not model_name: - raise CredentialsValidateFailedError('model_name is required') - + raise CredentialsValidateFailedError("model_name is required") + url = server_url - headers = { - 'Authorization': 'Bearer 123', - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer 123", "Content-Type": "application/json"} - data = { - 'model': model_name, - 'input': texts[0] - } + data = {"model": model_name, "input": texts[0]} try: - response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) + response = post(str(URL(url) / "embeddings"), headers=headers, data=dumps(data), timeout=10) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - code = resp['error']['code'] - msg = resp['error']['message'] + code = resp["error"]["code"] + msg = resp["error"]["message"] if code == 500: raise InvokeServerUnavailableError(msg) - + if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -79,23 +81,21 @@ def _invoke(self, model: str, credentials: dict, else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to 
convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -114,8 +114,8 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int # use GPT2Tokenizer to get num tokens num_tokens += self._get_num_tokens_by_gpt2(text) return num_tokens - - def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + + def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ Get customizable model schema @@ -130,10 +130,10 @@ def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIMod features=[], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512")), ModelPropertyKey.MAX_CHUNKS: 1, }, - parameter_rules=[] + parameter_rules=[], ) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -145,32 +145,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid credentials') + raise CredentialsValidateFailedError("Invalid credentials") except InvokeConnectionError as e: - raise CredentialsValidateFailedError(f'Invalid credentials: {e}') + raise CredentialsValidateFailedError(f"Invalid credentials: {e}") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -182,10 +172,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -196,7 +183,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, 
total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab6.5t-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab6.5t-chat.yaml new file mode 100644 index 00000000000000..cc8a3aa0ecdb70 --- /dev/null +++ b/api/core/model_runtime/model_providers/minimax/llm/abab6.5t-chat.yaml @@ -0,0 +1,44 @@ +model: abab6.5t-chat +label: + en_US: Abab6.5t-Chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.01 + max: 1 + default: 0.9 + - name: top_p + use_template: top_p + min: 0.01 + max: 1 + default: 0.95 + - name: max_tokens + use_template: max_tokens + required: true + default: 3072 + min: 1 + max: 8192 + - name: mask_sensitive_info + type: boolean + default: true + label: + zh_Hans: 隐私保护 + en_US: Moderate + help: + zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码 + en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id.. + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0.005' + output: '0.005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 6c41e0d2a5ed6b..88cc0e8e0f32d0 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -17,42 +17,48 @@ class MinimaxChatCompletion: """ - Minimax Chat Completion API + Minimax Chat Completion API """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') - - url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}' + raise InvalidAPIKeyError("Invalid API key or group ID") + + url = f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == 
float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - prompt = '你是一个什么都懂的专家' + prompt = "你是一个什么都懂的专家" - role_meta = { - 'user_name': '我', - 'bot_name': '专家' - } + role_meta = {"user_name": "我", "bot_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') - + raise BadRequestError("At least one message is required") + if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: prompt = prompt_messages[0].content @@ -60,44 +66,43 @@ def generate(self, model: str, api_key: str, group_id: str, # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') - - messages = [{ - 'sender_type': message.role, - 'text': message.content, - } for message in prompt_messages] - - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + raise BadRequestError("At least one user message is required") + + messages = [ + { + "sender_type": message.role, + "text": message.content, + } + for message in prompt_messages + ] + + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'prompt': prompt, - 'role_meta': role_meta, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "prompt": prompt, + "role_meta": role_meta, + "stream": stream, + **extra_kwargs, } try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) - + if response.status_code != 200: raise InternalServerError(response.text) - + if stream: return self._handle_stream_chat_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) @@ -110,65 +115,52 @@ def _handle_error(self, code: int, msg: str): def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = 
response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) - if data['reply']: - total_tokens = data['usage']['total_tokens'] - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens - } - message.stop_reason = data['choices'][0]['finish_reason'] + if data["reply"]: + total_tokens = data["usage"]["total_tokens"] + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.usage = {"prompt_tokens": 0, "completion_tokens": total_tokens, "total_tokens": total_tokens} + message.stop_reason = data["choices"][0]["finish_reason"] yield message return - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['delta'] - yield MinimaxMessage( - content=message, - role=MinimaxMessage.Role.ASSISTANT.value - ) \ No newline at end of file + message = choice["delta"] + yield MinimaxMessage(content=message, role=MinimaxMessage.Role.ASSISTANT.value) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 55747057c9ff3b..8b8fdbb6bdf558 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -17,86 +17,83 @@ class MinimaxChatCompletionPro: """ - Minimax Chat Completion Pro API, supports function calling - however, we do not have enough time and energy to implement it, but the parameters are reserved + Minimax Chat Completion Pro API, supports function calling + however, function calling is not implemented here yet; the parameters are reserved + + def generate( + self, + model: str, + api_key: str, + group_id: str, + prompt_messages: list[MinimaxMessage], + model_parameters: dict, + tools: list[dict[str, Any]], + stop: list[str] | None, + stream: bool, + user: str, + ) -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: list[MinimaxMessage], model_parameters: dict, - tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ - -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ - generate chat completion + generate chat completion """ if not api_key or not group_id: - raise InvalidAPIKeyError('Invalid API key or group ID') + raise InvalidAPIKeyError("Invalid API key or group ID") - url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}' + url =
f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}" extra_kwargs = {} - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - extra_kwargs['tokens_to_generate'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + extra_kwargs["tokens_to_generate"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - extra_kwargs['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + extra_kwargs["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - extra_kwargs['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + extra_kwargs["top_p"] = model_parameters["top_p"] - if 'mask_sensitive_info' in model_parameters and type(model_parameters['mask_sensitive_info']) == bool: - extra_kwargs['mask_sensitive_info'] = model_parameters['mask_sensitive_info'] - - if model_parameters.get('plugin_web_search'): - extra_kwargs['plugins'] = [ - 'plugin_web_search' - ] + if "mask_sensitive_info" in model_parameters and type(model_parameters["mask_sensitive_info"]) == bool: + extra_kwargs["mask_sensitive_info"] = model_parameters["mask_sensitive_info"] - bot_setting = { - 'bot_name': '专家', - 'content': '你是一个什么都懂的专家' - } + if model_parameters.get("plugin_web_search"): + extra_kwargs["plugins"] = ["plugin_web_search"] - reply_constraints = { - 'sender_type': 'BOT', - 'sender_name': '专家' - } + bot_setting = {"bot_name": "专家", "content": "你是一个什么都懂的专家"} + + reply_constraints = {"sender_type": "BOT", "sender_name": "专家"} # check if there is a system message if len(prompt_messages) == 0: - raise BadRequestError('At least one message is required') + raise BadRequestError("At least one message is required") if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value: if prompt_messages[0].content: - bot_setting['content'] = prompt_messages[0].content + bot_setting["content"] = prompt_messages[0].content prompt_messages = prompt_messages[1:] # check if there is a user message if len(prompt_messages) == 0: - raise BadRequestError('At least one user message is required') + raise BadRequestError("At least one user message is required") messages = [message.to_dict() for message in prompt_messages] - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} body = { - 'model': model, - 'messages': messages, - 'bot_setting': [bot_setting], - 'reply_constraints': reply_constraints, - 'stream': stream, - **extra_kwargs + "model": model, + "messages": messages, + "bot_setting": [bot_setting], + "reply_constraints": reply_constraints, + "stream": stream, + **extra_kwargs, } if tools: - body['functions'] = tools - body['function_call'] = {'type': 'auto'} + body["functions"] = tools + body["function_call"] = {"type": "auto"} try: - response = post( - url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) + response = post(url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) except Exception as e: raise InternalServerError(e) @@ -108,9 +105,9 @@ def generate(self, model: str, api_key: str, group_id: str, return self._handle_chat_generate_response(response) def 
_handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) @@ -123,78 +120,72 @@ def _handle_error(self, code: int, msg: str): def _handle_chat_generate_response(self, response: Response) -> MinimaxMessage: """ - handle chat generate response + handle chat generate response """ response = response.json() - if 'base_resp' in response and response['base_resp']['status_code'] != 0: - code = response['base_resp']['status_code'] - msg = response['base_resp']['status_msg'] + if "base_resp" in response and response["base_resp"]["status_code"] != 0: + code = response["base_resp"]["status_code"] + msg = response["base_resp"]["status_msg"] self._handle_error(code, msg) - message = MinimaxMessage( - content=response['reply'], - role=MinimaxMessage.Role.ASSISTANT.value - ) + message = MinimaxMessage(content=response["reply"], role=MinimaxMessage.Role.ASSISTANT.value) message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': response['usage']['total_tokens'], - 'total_tokens': response['usage']['total_tokens'] + "prompt_tokens": 0, + "completion_tokens": response["usage"]["total_tokens"], + "total_tokens": response["usage"]["total_tokens"], } - message.stop_reason = response['choices'][0]['finish_reason'] + message.stop_reason = response["choices"][0]["finish_reason"] return message def _handle_stream_chat_generate_response(self, response: Response) -> Generator[MinimaxMessage, None, None]: """ - handle stream chat generate response + handle stream chat generate response """ for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() data = loads(line) - if 'base_resp' in data and data['base_resp']['status_code'] != 0: - code = data['base_resp']['status_code'] - msg = data['base_resp']['status_msg'] + if "base_resp" in data and data["base_resp"]["status_code"] != 0: + code = data["base_resp"]["status_code"] + msg = data["base_resp"]["status_msg"] self._handle_error(code, msg) # final chunk - if data['reply'] or data.get('usage'): - total_tokens = data['usage']['total_tokens'] - minimax_message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) + if data["reply"] or data.get("usage"): + total_tokens = data["usage"]["total_tokens"] + minimax_message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") minimax_message.usage = { - 'prompt_tokens': 0, - 'completion_tokens': total_tokens, - 'total_tokens': total_tokens + "prompt_tokens": 0, + "completion_tokens": total_tokens, + "total_tokens": total_tokens, } - minimax_message.stop_reason = data['choices'][0]['finish_reason'] + minimax_message.stop_reason = data["choices"][0]["finish_reason"] - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) > 0: for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append function_call message - if 'function_call' in message: - function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) - function_call_message.function_call = message['function_call'] + if "function_call" in message: + function_call_message = MinimaxMessage(content="", 
role=MinimaxMessage.Role.ASSISTANT.value) + function_call_message.function_call = message["function_call"] yield function_call_message yield minimax_message return # partial chunk - choices = data.get('choices', []) + choices = data.get("choices", []) if len(choices) == 0: continue for choice in choices: - message = choice['messages'][0] + message = choice["messages"][0] # append text message - if 'text' in message: - minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value) + if "text" in message: + minimax_message = MinimaxMessage(content=message["text"], role=MinimaxMessage.Role.ASSISTANT.value) yield minimax_message diff --git a/api/core/model_runtime/model_providers/minimax/llm/errors.py b/api/core/model_runtime/model_providers/minimax/llm/errors.py index d9d279e6ca0ed1..309b5cf413bd54 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/errors.py +++ b/api/core/model_runtime/model_providers/minimax/llm/errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index 1fab20ebbc4575..4250c40cfb94b1 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -34,18 +34,25 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): model_apis = { - 'abab6.5s-chat': MinimaxChatCompletionPro, - 'abab6.5-chat': MinimaxChatCompletionPro, - 'abab6-chat': MinimaxChatCompletionPro, - 'abab5.5s-chat': MinimaxChatCompletionPro, - 'abab5.5-chat': MinimaxChatCompletionPro, - 'abab5-chat': MinimaxChatCompletion + "abab6.5s-chat": MinimaxChatCompletionPro, + "abab6.5-chat": MinimaxChatCompletionPro, + "abab6-chat": MinimaxChatCompletionPro, + "abab5.5s-chat": MinimaxChatCompletionPro, + "abab5.5-chat": MinimaxChatCompletionPro, + "abab5-chat": MinimaxChatCompletion, } - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: @@ -53,82 +60,97 @@ def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials for Minimax model """ if model not in self.model_apis: - raise CredentialsValidateFailedError(f'Invalid model: {model}') + raise CredentialsValidateFailedError(f"Invalid model: {model}") - if not credentials.get('minimax_api_key'): - raise CredentialsValidateFailedError('Invalid API key') + if not credentials.get("minimax_api_key"): + raise CredentialsValidateFailedError("Invalid API key") + + if not credentials.get("minimax_group_id"): + raise
CredentialsValidateFailedError("Invalid group ID") - if not credentials.get('minimax_group_id'): - raise CredentialsValidateFailedError('Invalid group ID') - # ping instance = MinimaxChatCompletionPro() try: instance.generate( - model=model, api_key=credentials['minimax_api_key'], group_id=credentials['minimax_group_id'], - prompt_messages=[ - MinimaxMessage(content='ping', role='USER') - ], + model=model, + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], + prompt_messages=[MinimaxMessage(content="ping", role="USER")], model_parameters={}, - tools=[], stop=[], + tools=[], + stop=[], stream=False, - user='' + user="", ) except (InvalidAuthenticationError, InsufficientAccountBalanceError) as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for minimax model + Calculate num tokens for minimax model - not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way - to caculate the num tokens, so we use str() to convert the prompt to string + not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way + to calculate the num tokens, so we use str() to convert the prompt to string - Minimax does not provide their own tokenizer of adab5.5 and abab5 model - therefore, we use gpt2 tokenizer instead + Minimax does not provide their own tokenizer of adab5.5 and abab5 model + therefore, we use gpt2 tokenizer instead """ messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages] return self._get_num_tokens_by_gpt2(str(messages_dict)) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface + use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface """ client: MinimaxChatCompletionPro = self.model_apis[model]() if tools: - tools = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + tools = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] response = client.generate( model=model, - api_key=credentials['minimax_api_key'], - group_id=credentials['minimax_group_id'], + api_key=credentials["minimax_api_key"], + group_id=credentials["minimax_group_id"], prompt_messages=[self._convert_prompt_message_to_minimax_message(message) for message in prompt_messages], model_parameters=model_parameters, 
tools=tools, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessage) -> MinimaxMessage: """ - convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface + convert PromptMessage to MinimaxMessage so that we can use MinimaxChatCompletionPro interface """ if isinstance(prompt_message, SystemPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.SYSTEM.value, content=prompt_message.content) @@ -136,26 +158,27 @@ def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessa return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - message = MinimaxMessage( - role=MinimaxMessage.Role.ASSISTANT.value, - content='' - ) - message.function_call={ - 'name': prompt_message.tool_calls[0].function.name, - 'arguments': prompt_message.tool_calls[0].function.arguments + message = MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content="") + message.function_call = { + "name": prompt_message.tool_calls[0].function.name, + "arguments": prompt_message.tool_calls[0].function.arguments, } return message return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) elif isinstance(prompt_message, ToolPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -166,31 +189,33 @@ def _handle_chat_generate_response(self, model: str, prompt_messages: list[Promp usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[MinimaxMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[MinimaxMessage, None, None], + ) -> 
Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason or None, ), ) elif message.function_call: - if 'name' not in message.function_call or 'arguments' not in message.function_call: + if "name" not in message.function_call or "arguments" not in message.function_call: continue yield LLMResultChunk( @@ -199,15 +224,16 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages: lis delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content='', - tool_calls=[AssistantPromptMessage.ToolCall( - id='', - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=message.function_call['name'], - arguments=message.function_call['arguments'] + content="", + tool_calls=[ + AssistantPromptMessage.ToolCall( + id="", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=message.function_call["name"], arguments=message.function_call["arguments"] + ), ) - )] + ], ), ), ) @@ -217,11 +243,8 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages: lis prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content=message.content, tool_calls=[]), + finish_reason=message.stop_reason or None, ), ) @@ -236,22 +259,13 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index b33a7ca9ac20d0..88ebe5e2e00e7a 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -4,32 +4,27 @@ class MinimaxMessage: class Role(Enum): - USER = 'USER' - ASSISTANT = 'BOT' - SYSTEM = 'SYSTEM' - FUNCTION = 'FUNCTION' + USER = "USER" + ASSISTANT = "BOT" + SYSTEM = "SYSTEM" + FUNCTION = "FUNCTION" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" function_call: dict[str, Any] = None def to_dict(self) -> dict[str, Any]: if self.function_call and 
self.role == MinimaxMessage.Role.ASSISTANT.value: - return { - 'sender_type': 'BOT', - 'sender_name': '专家', - 'text': '', - 'function_call': self.function_call - } - + return {"sender_type": "BOT", "sender_name": "专家", "text": "", "function_call": self.function_call} + return { - 'sender_type': self.role, - 'sender_name': '我' if self.role == 'USER' else '专家', - 'text': self.content, + "sender_type": self.role, + "sender_name": "我" if self.role == "USER" else "专家", + "text": self.content, } - - def __init__(self, content: str, role: str = 'USER') -> None: + + def __init__(self, content: str, role: str = "USER") -> None: self.content = content - self.role = role \ No newline at end of file + self.role = role diff --git a/api/core/model_runtime/model_providers/minimax/minimax.py b/api/core/model_runtime/model_providers/minimax/minimax.py index 52f6c2f1d3a098..5a761903a1eb12 100644 --- a/api/core/model_runtime/model_providers/minimax/minimax.py +++ b/api/core/model_runtime/model_providers/minimax/minimax.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class MinimaxProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `abab5.5-chat` model for validation - model_instance.validate_credentials( - model='abab5.5-chat', - credentials=credentials - ) + model_instance.validate_credentials(model="abab5.5-chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise CredentialsValidateFailedError(f'{ex}') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise CredentialsValidateFailedError(f"{ex}") diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 85dc6ef51d1a1d..29be5888af3fc2 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -4,6 +4,7 @@ from requests import post +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( @@ -30,11 +31,17 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): """ Model class for Minimax text embedding model.
""" - api_base: str = 'https://api.minimax.chat/v1/embeddings' - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://api.minimax.chat/v1/embeddings" + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -42,56 +49,47 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - api_key = credentials['minimax_api_key'] - group_id = credentials['minimax_group_id'] - if model != 'embo-01': - raise ValueError('Invalid model name') + api_key = credentials["minimax_api_key"] + group_id = credentials["minimax_group_id"] + if model != "embo-01": + raise ValueError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') - url = f'{self.api_base}?GroupId={group_id}' - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + raise CredentialsValidateFailedError("api_key is required") + url = f"{self.api_base}?GroupId={group_id}" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': 'embo-01', - 'texts': texts, - 'type': 'db' - } + embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query" + data = {"model": "embo-01", "texts": texts, "type": embedding_type} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json() # check if there is an error - if resp['base_resp']['status_code'] != 0: - code = resp['base_resp']['status_code'] - msg = resp['base_resp']['status_msg'] + if resp["base_resp"]["status_code"] != 0: + code = resp["base_resp"]["status_code"] + msg = resp["base_resp"]["status_msg"] self._handle_error(code, msg) - embeddings = resp['vectors'] - total_tokens = resp['total_tokens'] + embeddings = resp["vectors"] + total_tokens = resp["total_tokens"] except InvalidAuthenticationError: - raise InvalidAPIKeyError('Invalid api key') + raise InvalidAPIKeyError("Invalid api key") except KeyError as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -119,12 +117,12 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvalidAPIKeyError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001: + if code in {1000, 1001}: raise InternalServerError(msg) elif code == 1002: raise RateLimitReachedError(msg) @@ -148,25 +146,17 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] 
:return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -178,10 +168,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -192,7 +179,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml index 751003d71e8886..bdb06b7fff6376 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml @@ -1,3 +1,8 @@ +- pixtral-12b-2409 +- codestral-latest +- mistral-embed +- open-mistral-nemo +- open-codestral-mamba - open-mistral-7b - open-mixtral-8x7b - open-mixtral-8x22b diff --git a/api/core/model_runtime/model_providers/mistralai/llm/codestral-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/codestral-latest.yaml new file mode 100644 index 00000000000000..5f1260233fe97b --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/codestral-latest.yaml @@ -0,0 +1,51 @@ +model: codestral-latest +label: + zh_Hans: codestral-latest + en_US: codestral-latest +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 4096 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. 
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/llm/llm.py b/api/core/model_runtime/model_providers/mistralai/llm/llm.py index 01ed8010de873c..da60bd7661d597 100644 --- a/api/core/model_runtime/model_providers/mistralai/llm/llm.py +++ b/api/core/model_runtime/model_providers/mistralai/llm/llm.py @@ -7,14 +7,19 @@ class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) - + # mistral does not support user/stop arguments stop = [] user = None @@ -27,5 +32,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.mistral.ai/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.mistral.ai/v1" diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-embed.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-embed.yaml new file mode 100644 index 00000000000000..d759103d08a944 --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-embed.yaml @@ -0,0 +1,51 @@ +model: mistral-embed +label: + zh_Hans: mistral-embed + en_US: mistral-embed +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 1024 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-codestral-mamba.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-codestral-mamba.yaml new file mode 100644 index 00000000000000..d7ffb9ea020848 --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/open-codestral-mamba.yaml @@ -0,0 +1,51 @@ +model: open-codestral-mamba +label: + zh_Hans: open-codestral-mamba + en_US: open-codestral-mamba +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 256000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 16384 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. + zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-nemo.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-nemo.yaml new file mode 100644 index 00000000000000..dcda4fbce7e82c --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-nemo.yaml @@ -0,0 +1,51 @@ +model: open-mistral-nemo +label: + zh_Hans: open-mistral-nemo + en_US: open-mistral-nemo +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8192 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. 
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/llm/pixtral-12b-2409.yaml b/api/core/model_runtime/model_providers/mistralai/llm/pixtral-12b-2409.yaml new file mode 100644 index 00000000000000..0b002b49cac8e0 --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/llm/pixtral-12b-2409.yaml @@ -0,0 +1,51 @@ +model: pixtral-12b-2409 +label: + zh_Hans: pixtral-12b-2409 + en_US: pixtral-12b-2409 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 0.7 + min: 0 + max: 1 + - name: top_p + use_template: top_p + default: 1 + min: 0 + max: 1 + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 8192 + - name: safe_prompt + default: false + type: boolean + help: + en_US: Whether to inject a safety prompt before all conversations. + zh_Hans: 是否开启提示词审查 + label: + en_US: SafePrompt + zh_Hans: 提示词审查 + - name: random_seed + type: int + help: + en_US: The seed to use for random sampling. If set, different calls will generate deterministic results. + zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.008' + output: '0.024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.py b/api/core/model_runtime/model_providers/mistralai/mistralai.py index f1d825f6c6f042..7f9db8da1c1ddf 100644 --- a/api/core/model_runtime/model_providers/mistralai/mistralai.py +++ b/api/core/model_runtime/model_providers/mistralai/mistralai.py @@ -8,7 +8,6 @@ class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='open-mistral-7b', - credentials=credentials - ) + model_instance.validate_credentials(model="open-mistral-7b", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/mixedbread/__init__.py b/api/core/model_runtime/model_providers/mixedbread/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/mixedbread/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/mixedbread/_assets/icon_l_en.png new file mode 100644 index 00000000000000..2027611bd5e8b4 Binary files /dev/null and b/api/core/model_runtime/model_providers/mixedbread/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/mixedbread/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/mixedbread/_assets/icon_s_en.png new file mode 100644 index 00000000000000..5c357bddbddb15 Binary files /dev/null and b/api/core/model_runtime/model_providers/mixedbread/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/mixedbread/mixedbread.py 
b/api/core/model_runtime/model_providers/mixedbread/mixedbread.py new file mode 100644 index 00000000000000..3c78150e6f806e --- /dev/null +++ b/api/core/model_runtime/model_providers/mixedbread/mixedbread.py @@ -0,0 +1,27 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class MixedBreadProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validation fails, raise an exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) + + # Use `mxbai-embed-large-v1` model for validation + model_instance.validate_credentials(model="mxbai-embed-large-v1", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/mixedbread/mixedbread.yaml b/api/core/model_runtime/model_providers/mixedbread/mixedbread.yaml new file mode 100644 index 00000000000000..2f43aea6ade2c6 --- /dev/null +++ b/api/core/model_runtime/model_providers/mixedbread/mixedbread.yaml @@ -0,0 +1,31 @@ +provider: mixedbread +label: + en_US: MixedBread +description: + en_US: Embedding and Rerank Models Supported +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +background: "#EFFDFD" +help: + title: + en_US: Get your API key from MixedBread AI + zh_Hans: 从 MixedBread 获取 API Key + url: + en_US: https://www.mixedbread.ai/ +supported_model_types: + - text-embedding + - rerank +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/mixedbread/rerank/__init__.py b/api/core/model_runtime/model_providers/mixedbread/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/mixedbread/rerank/mxbai-rerank-large-v1-en.yaml b/api/core/model_runtime/model_providers/mixedbread/rerank/mxbai-rerank-large-v1-en.yaml new file mode 100644 index 00000000000000..beda2199537450 --- /dev/null +++ b/api/core/model_runtime/model_providers/mixedbread/rerank/mxbai-rerank-large-v1-en.yaml @@ -0,0 +1,4 @@ +model: mxbai-rerank-large-v1 +model_type: rerank +model_properties: + context_size: 512 diff --git a/api/core/model_runtime/model_providers/mixedbread/rerank/rerank.py b/api/core/model_runtime/model_providers/mixedbread/rerank/rerank.py new file mode 100644 index 00000000000000..bf3c12fd86dc35 --- /dev/null +++ b/api/core/model_runtime/model_providers/mixedbread/rerank/rerank.py @@ -0,0 +1,125 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, +
InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class MixedBreadRerankModel(RerankModel): + """ + Model class for MixedBread rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get("base_url", "https://api.mixedbread.ai/v1") + base_url = base_url.removesuffix("/") + + try: + response = httpx.post( + base_url + "/reranking", + json={"model": model, "query": query, "input": docs, "top_k": top_n, "return_input": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"}, + ) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["data"]: + rerank_document = RerankDocument( + index=result["index"], + text=result["input"], + score=result["score"], + ) + if score_threshold is None or result["score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/mixedbread/text_embedding/__init__.py b/api/core/model_runtime/model_providers/mixedbread/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/mixedbread/text_embedding/mxbai-embed-2d-large-v1-en.yaml b/api/core/model_runtime/model_providers/mixedbread/text_embedding/mxbai-embed-2d-large-v1-en.yaml new file mode 100644 index 00000000000000..0c3c863d06b89a --- /dev/null +++ b/api/core/model_runtime/model_providers/mixedbread/text_embedding/mxbai-embed-2d-large-v1-en.yaml @@ -0,0 +1,8 @@ +model: mxbai-embed-2d-large-v1 +model_type: text-embedding +model_properties: + context_size: 512 +pricing: + input: '0.0001' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mixedbread/text_embedding/mxbai-embed-large-v1-en.yaml b/api/core/model_runtime/model_providers/mixedbread/text_embedding/mxbai-embed-large-v1-en.yaml new file mode 100644 index 00000000000000..0c5cda2a72a99e --- /dev/null +++ b/api/core/model_runtime/model_providers/mixedbread/text_embedding/mxbai-embed-large-v1-en.yaml @@ -0,0 +1,8 @@ +model: mxbai-embed-large-v1 +model_type: text-embedding +model_properties: + context_size: 512 +pricing: + input: '0.0001' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..ca949cb9532daa --- /dev/null +++ b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py @@ -0,0 +1,170 @@ +import time +from json import JSONDecodeError, dumps +from typing import Optional + +import requests + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + + +class MixedBreadTextEmbeddingModel(TextEmbeddingModel): + """ + Model 
class for MixedBread text embedding model. + """ + + api_base: str = "https://api.mixedbread.ai/v1" + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + api_key = credentials["api_key"] + if not api_key: + raise CredentialsValidateFailedError("api_key is required") + + base_url = credentials.get("base_url", self.api_base) + base_url = base_url.removesuffix("/") + + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} + + data = {"model": model, "input": texts} + + try: + response = requests.post(url, headers=headers, data=dumps(data)) + except Exception as e: + raise InvokeConnectionError(str(e)) + + if response.status_code != 200: + try: + resp = response.json() + msg = resp["detail"] + if response.status_code == 401: + raise InvokeAuthorizationError(msg) + elif response.status_code == 429: + raise InvokeRateLimitError(msg) + elif response.status_code == 500: + raise InvokeServerUnavailableError(msg) + else: + raise InvokeBadRequestError(msg) + except JSONDecodeError as e: + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) + + try: + resp = response.json() + embeddings = resp["data"] + usage = resp["usage"] + except Exception as e: + raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) + + result = TextEmbeddingResult( + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage + ) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke(model=model, credentials=credentials, texts=["ping"]) + except Exception as e: + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], + } + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens + ) + + # transform usage + 
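# NOTE (assumed behavior): tokens and total_tokens are intentionally the same value, since an embedding request has no completion side + # latency relies on self.started_at, which the TextEmbeddingModel base class is assumed to set via time.perf_counter() just before _invoke runs +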
usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "512"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b1660afafb12e4..e2d17e32575920 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from core.helper.module_import_helper import load_single_subclass_from_source -from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map +from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -234,7 +234,7 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: ] # get _position.yaml file path - position_map = get_position_map(model_providers_path) + position_map = get_provider_position_map(model_providers_path) # traverse all model_provider_dir_paths model_providers: list[ModelProviderExtension] = [] diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index c233596637fa21..5c955c86d3a3e2 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -30,69 +30,78 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None + # {"response_format": "json_object"} needs to be converted to {"response_format": {"type": "json_object"}} + if "response_format" in model_parameters: + model_parameters["response_format"] = {"type": model_parameters.get("response_format")} return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials)
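# _add_custom_parameters (defined below) forces chat mode and defaults endpoint_url to https://api.moonshot.cn/v1 when none is configured, so that super().validate_credentials from the OpenAI-API-compatible base class has a usable endpoint to check against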
super().validate_credentials(model, credentials) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 4096)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 4096)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' + credentials["mode"] = "chat" + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["endpoint_url"] = "https://api.moonshot.cn/v1" def _add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ @@ -107,19 +116,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + 
"image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -129,14 +132,16 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -162,21 +167,26 @@ def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[ if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -186,11 +196,12 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -201,12 +212,7 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -220,9 +226,9 @@ def get_tool_call(tool_name: str): tool_call = next((tool_call for tool_call 
in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -244,9 +250,9 @@ def get_tool_call(tool_name: str): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -255,21 +261,21 @@ def get_tool_call(tool_name: str): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -277,19 +283,18 @@ def get_tool_call(tool_name: str): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -305,26 +310,21 @@ def get_tool_call(tool_name: str): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml index 0d2e51c47f9dc6..59c0915ee9b610 100644 --- 
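Reviewer note: the stream loop above splits the response on blank lines, skips SSE comments, and strips the data: prefix before json.loads. A self-contained sketch of that per-line handling; it uses str.removeprefix, whereas the hunk's lstrip("data: ") strips a character set rather than a literal prefix (harmless for payloads that begin with { or ", but worth knowing):

import json
from typing import Optional

def parse_sse_chunk(raw: str) -> Optional[dict]:
    """Return the decoded JSON payload of one SSE chunk, or None to skip it."""
    if not raw or raw.startswith(":"):  # lines starting with ':' are SSE comments
        return None
    payload = raw.strip().removeprefix("data:").strip()
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        # the real handler yields a final chunk with finish_reason="Non-JSON encountered."
        return None

assert parse_sse_chunk(": keep-alive") is None
assert parse_sse_chunk('data: {"choices": []}') == {"choices": []}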
a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 1024 min: 1 max: 128000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.06' output: '0.06' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml index 9ff537014a86c6..724f2aa5a29f96 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 1024 min: 1 max: 32000 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.024' output: '0.024' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml index 0f308d36766752..5872295bfad1ef 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml @@ -21,6 +21,18 @@ parameter_rules: default: 512 min: 1 max: 8192 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.012' output: '0.012' diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.py b/api/core/model_runtime/model_providers/moonshot/moonshot.py index 5654ae1459cc16..4995e235f54c69 100644 --- a/api/core/model_runtime/model_providers/moonshot/moonshot.py +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.py @@ -8,7 +8,6 @@ class MoonshotProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='moonshot-v1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="moonshot-v1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nomic/__init__.py b/api/core/model_runtime/model_providers/nomic/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/nomic/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/nomic/_assets/icon_l_en.svg new file mode 100644 index 00000000000000..6c4a1058ab9c70 --- /dev/null +++ b/api/core/model_runtime/model_providers/nomic/_assets/icon_l_en.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git 
a/api/core/model_runtime/model_providers/nomic/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/nomic/_assets/icon_s_en.png new file mode 100644 index 00000000000000..3eba3b82bc1e3f Binary files /dev/null and b/api/core/model_runtime/model_providers/nomic/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/nomic/_common.py b/api/core/model_runtime/model_providers/nomic/_common.py new file mode 100644 index 00000000000000..406577dcd7e701 --- /dev/null +++ b/api/core/model_runtime/model_providers/nomic/_common.py @@ -0,0 +1,28 @@ +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class _CommonNomic: + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], + } diff --git a/api/core/model_runtime/model_providers/nomic/nomic.py b/api/core/model_runtime/model_providers/nomic/nomic.py new file mode 100644 index 00000000000000..d4e5da2e98ec97 --- /dev/null +++ b/api/core/model_runtime/model_providers/nomic/nomic.py @@ -0,0 +1,26 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class NomicAtlasProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
+ """ + try: + model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) + model_instance.validate_credentials(model="nomic-embed-text-v1.5", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/nomic/nomic.yaml b/api/core/model_runtime/model_providers/nomic/nomic.yaml new file mode 100644 index 00000000000000..60dcf1facb475d --- /dev/null +++ b/api/core/model_runtime/model_providers/nomic/nomic.yaml @@ -0,0 +1,29 @@ +provider: nomic +label: + zh_Hans: Nomic Atlas + en_US: Nomic Atlas +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.svg +background: "#EFF1FE" +help: + title: + en_US: Get your API key from Nomic Atlas + zh_Hans: 从Nomic Atlas获取 API Key + url: + en_US: https://atlas.nomic.ai/data +supported_model_types: + - text-embedding +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: nomic_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/__init__.py b/api/core/model_runtime/model_providers/nomic/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/nomic-embed-text-v1.5.yaml b/api/core/model_runtime/model_providers/nomic/text_embedding/nomic-embed-text-v1.5.yaml new file mode 100644 index 00000000000000..111452df579f8f --- /dev/null +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/nomic-embed-text-v1.5.yaml @@ -0,0 +1,8 @@ +model: nomic-embed-text-v1.5 +model_type: text-embedding +model_properties: + context_size: 8192 +pricing: + input: "0.1" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/nomic-embed-text-v1.yaml b/api/core/model_runtime/model_providers/nomic/text_embedding/nomic-embed-text-v1.yaml new file mode 100644 index 00000000000000..ac59f106ed2928 --- /dev/null +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/nomic-embed-text-v1.yaml @@ -0,0 +1,8 @@ +model: nomic-embed-text-v1 +model_type: text-embedding +model_properties: + context_size: 8192 +pricing: + input: "0.1" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..56a707333c40e9 --- /dev/null +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -0,0 +1,165 @@ +import time +from functools import wraps +from typing import Optional + +from nomic import embed +from nomic import login as nomic_login + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import ( + EmbeddingUsage, + TextEmbeddingResult, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import ( + TextEmbeddingModel, +) +from core.model_runtime.model_providers.nomic._common import _CommonNomic + + +def nomic_login_required(func): + @wraps(func) + def 
wrapper(*args, **kwargs): + try: + if not kwargs.get("credentials"): + raise ValueError("missing credentials parameters") + credentials = kwargs.get("credentials") + if "nomic_api_key" not in credentials: + raise ValueError("missing nomic_api_key in credentials parameters") + # nomic login + nomic_login(credentials["nomic_api_key"]) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + return func(*args, **kwargs) + + return wrapper + + +class NomicTextEmbeddingModel(_CommonNomic, TextEmbeddingModel): + """ + Model class for nomic text embedding model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + embeddings, prompt_tokens, total_tokens = self.embed_text( + model=model, + credentials=credentials, + texts=texts, + ) + + # calc usage + usage = self._calc_response_usage( + model=model, credentials=credentials, tokens=prompt_tokens, total_tokens=total_tokens + ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + # call embedding model + self.embed_text(model=model, credentials=credentials, texts=["ping"]) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @nomic_login_required + def embed_text(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int, int]: + """Call out to Nomic's embedding endpoint. + + Args: + model: The model to use for embedding. + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text, and tokens usage. 
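Reviewer note: nomic_login_required above is a decorator that authenticates before every embedding call. A generic sketch of the same gate, with the key name and login function as stand-ins for nomic_api_key / nomic.login:

from functools import wraps

class CredentialsValidateFailedError(Exception):
    pass

def login_required(login_fn, key_name: str):
    """Decorator factory: log in with the credential named key_name before the call."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            credentials = kwargs.get("credentials")
            if not credentials or key_name not in credentials:
                raise CredentialsValidateFailedError(f"missing {key_name} in credentials parameters")
            try:
                login_fn(credentials[key_name])  # stand-in for the provider's login call
            except Exception as ex:
                raise CredentialsValidateFailedError(str(ex))
            return func(*args, **kwargs)
        return wrapper
    return decorator

@login_required(login_fn=lambda key: None, key_name="api_key")  # no-op login for the demo
def embed_text(*, credentials: dict, texts: list[str]) -> int:
    return len(texts)

assert embed_text(credentials={"api_key": "k"}, texts=["ping"]) == 1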
+ """ + embeddings: list[list[float]] = [] + prompt_tokens = 0 + total_tokens = 0 + + response = embed.text( + model=model, + texts=texts, + ) + + if not (response and "embeddings" in response): + raise ValueError("Embedding data is missing in the response.") + + if not (response and "usage" in response): + raise ValueError("Response usage is missing.") + + if "prompt_tokens" not in response["usage"]: + raise ValueError("Response usage does not contain prompt tokens.") + + if "total_tokens" not in response["usage"]: + raise ValueError("Response usage does not contain total tokens.") + + embeddings = [list(map(float, e)) for e in response["embeddings"]] + total_tokens = response["usage"]["total_tokens"] + prompt_tokens = response["usage"]["prompt_tokens"] + return embeddings, prompt_tokens, total_tokens + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: prompt tokens + :param total_tokens: total tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens, + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=total_tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage diff --git a/api/core/model_runtime/model_providers/novita/llm/llm.py b/api/core/model_runtime/model_providers/novita/llm/llm.py index c7b223d1b7bdbe..23367ed1b4309e 100644 --- a/api/core/model_runtime/model_providers/novita/llm/llm.py +++ b/api/core/model_runtime/model_providers/novita/llm/llm.py @@ -8,19 +8,25 @@ class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - credentials['endpoint_url'] = "https://api.novita.ai/v3/openai" - credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' } + credentials["endpoint_url"] = "https://api.novita.ai/v3/openai" + credentials["extra_headers"] = {"X-Novita-Source": "dify.ai"} return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + def validate_credentials(self, model: str, credentials: dict) -> None: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) self._add_custom_parameters(credentials, model) @@ -28,21 +34,36 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @classmethod def _add_custom_parameters(cls, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' + credentials["mode"] = "chat" - def _generate(self, model: str, 
credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_customizable_model_schema(model, cred_with_endpoint) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/novita/novita.py b/api/core/model_runtime/model_providers/novita/novita.py index f1b72246057c6d..76a75b01e27e01 100644 --- a/api/core/model_runtime/model_providers/novita/novita.py +++ b/api/core/model_runtime/model_providers/novita/novita.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `meta-llama/llama-3-8b-instruct` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='meta-llama/llama-3-8b-instruct', - credentials=credentials - ) + model_instance.validate_credentials(model="meta-llama/llama-3-8b-instruct", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py index bc42eaca658bac..1c98c6be6ca72d 100644 --- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py @@ -21,31 +21,36 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel): MODEL_SUFFIX_MAP = { - 'fuyu-8b': 'vlm/adept/fuyu-8b', - 'mistralai/mistral-large': '', - 'mistralai/mixtral-8x7b-instruct-v0.1': '', - 'mistralai/mixtral-8x22b-instruct-v0.1': '', - 'google/gemma-7b': '', - 'google/codegemma-7b': '', - 'snowflake/arctic':'', - 'meta/llama2-70b': '', - 'meta/llama3-8b-instruct': '', - 'meta/llama3-70b-instruct': '', - 'meta/llama-3.1-8b-instruct': '', - 'meta/llama-3.1-70b-instruct': '', - 'meta/llama-3.1-405b-instruct': '', - 'google/recurrentgemma-2b': '', - 'nvidia/nemotron-4-340b-instruct': '', - 'microsoft/phi-3-medium-128k-instruct':'', - 
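Reviewer note: the Novita class above is deliberately thin: every entry point first pins the endpoint and a source header, then defers to the OpenAI-compatible base implementation. The override itself, isolated with values unchanged from the hunk:

def update_endpoint_url(credentials: dict) -> dict:
    # The base class reads endpoint_url and extra_headers when it builds the request.
    credentials["endpoint_url"] = "https://api.novita.ai/v3/openai"
    credentials["extra_headers"] = {"X-Novita-Source": "dify.ai"}
    return credentials

assert update_endpoint_url({})["endpoint_url"] == "https://api.novita.ai/v3/openai"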
'microsoft/phi-3-mini-128k-instruct':'' + "fuyu-8b": "vlm/adept/fuyu-8b", + "mistralai/mistral-large": "", + "mistralai/mixtral-8x7b-instruct-v0.1": "", + "mistralai/mixtral-8x22b-instruct-v0.1": "", + "google/gemma-7b": "", + "google/codegemma-7b": "", + "snowflake/arctic": "", + "meta/llama2-70b": "", + "meta/llama3-8b-instruct": "", + "meta/llama3-70b-instruct": "", + "meta/llama-3.1-8b-instruct": "", + "meta/llama-3.1-70b-instruct": "", + "meta/llama-3.1-405b-instruct": "", + "google/recurrentgemma-2b": "", + "nvidia/nemotron-4-340b-instruct": "", + "microsoft/phi-3-medium-128k-instruct": "", + "microsoft/phi-3-mini-128k-instruct": "", } - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: - + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials, model) prompt_messages = self._transform_prompt_messages(prompt_messages) stop = [] @@ -60,16 +65,14 @@ def _transform_prompt_messages(self, prompt_messages: list[PromptMessage]) -> li for i, p in enumerate(prompt_messages): if isinstance(p, UserPromptMessage) and isinstance(p.content, list): content = p.content - content_text = '' + content_text = "" for prompt_content in content: if prompt_content.type == PromptMessageContentType.TEXT: content_text += prompt_content.data else: content_text += f' ' - prompt_message = UserPromptMessage( - content=content_text - ) + prompt_message = UserPromptMessage(content=content_text) prompt_messages[i] = prompt_message return prompt_messages @@ -78,91 +81,87 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._validate_credentials(model, credentials) def _add_custom_parameters(self, credentials: dict, model: str) -> None: - credentials['mode'] = 'chat' - + credentials["mode"] = "chat" + if self.MODEL_SUFFIX_MAP[model]: - credentials['server_url'] = f'https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}' - credentials.pop('endpoint_url') + credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}" + credentials.pop("endpoint_url") else: - credentials['endpoint_url'] = 'https://integrate.api.nvidia.com/v1' + credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1" - credentials['stream_mode_delimiter'] = '\n' + credentials["stream_mode_delimiter"] = "\n" def _validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard. + Validate model credentials using requests to ensure compatibility with all providers following + OpenAI's API standard. 
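Reviewer note: _add_custom_parameters above routes each NVIDIA model to one of two hosts based on MODEL_SUFFIX_MAP. A trimmed, runnable sketch of that routing (two representative entries only):

MODEL_SUFFIX_MAP = {
    "fuyu-8b": "vlm/adept/fuyu-8b",   # non-empty suffix -> dedicated per-model route
    "meta/llama3-8b-instruct": "",    # empty suffix -> shared integrate endpoint
}

def add_custom_parameters(credentials: dict, model: str) -> dict:
    credentials["mode"] = "chat"
    if MODEL_SUFFIX_MAP[model]:
        credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{MODEL_SUFFIX_MAP[model]}"
        credentials.pop("endpoint_url", None)  # the hunk uses a bare pop(), assuming the key exists
    else:
        credentials["endpoint_url"] = "https://integrate.api.nvidia.com/v1"
    credentials["stream_mode_delimiter"] = "\n"
    return credentials

assert "server_url" in add_custom_parameters({"endpoint_url": "x"}, "fuyu-8b")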
:param model: model name :param credentials: model credentials :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -176,57 +175,51 @@ def _generate(self, model: str, credentials: dict, prompt_messages: 
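Reviewer note: _validate_credentials above probes the endpoint with the smallest request that exercises the configured mode. The payload construction on its own, mirroring the hunk:

def build_ping_payload(mode: str, model: str) -> dict:
    """Minimal request body used to validate credentials."""
    data = {"model": model, "max_tokens": 5}
    if mode == "chat":
        data["messages"] = [{"role": "user", "content": "ping"}]
    elif mode == "completion":
        data["prompt"] = "ping"
    else:
        raise ValueError("Unsupported completion type for model configuration.")
    return data

assert build_ping_payload("chat", "meta/llama3-8b-instruct")["messages"][0]["content"] == "ping"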
list[PromptM :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: - headers['Authorization'] = f'Bearer {api_key}' + headers["Authorization"] = f"Bearer {api_key}" if stream: - headers['Accept'] = 'text/event-stream' + headers["Accept"] = "text/event-stream" - endpoint_url = credentials.get('endpoint_url') - if endpoint_url and not endpoint_url.endswith('/'): - endpoint_url += '/' - server_url = credentials.get('server_url') + endpoint_url = credentials.get("endpoint_url") + if endpoint_url and not endpoint_url.endswith("/"): + endpoint_url += "/" + server_url = credentials.get("server_url") - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'chat' / 'completions') - elif 'server_url' in credentials: + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "chat" / "completions") + elif "server_url" in credentials: endpoint_url = server_url - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - if 'endpoint_url' in credentials: - endpoint_url = str(URL(endpoint_url) / 'completions') - elif 'server_url' in credentials: + data["prompt"] = "ping" + if "endpoint_url" in credentials: + endpoint_url = str(URL(endpoint_url) / "completions") + elif "server_url" in credentials: endpoint_url = server_url else: raise ValueError("Unsupported completion type for model configuration.") - # annotate tools with names, descriptions, etc. 
- function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -240,16 +233,10 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if not response.ok: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") diff --git a/api/core/model_runtime/model_providers/nvidia/nvidia.py b/api/core/model_runtime/model_providers/nvidia/nvidia.py index e83f8badb57242..058fa003462585 100644 --- a/api/core/model_runtime/model_providers/nvidia/nvidia.py +++ b/api/core/model_runtime/model_providers/nvidia/nvidia.py @@ -8,7 +8,6 @@ class MistralAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='mistralai/mixtral-8x7b-instruct-v0.1', - credentials=credentials - ) + model_instance.validate_credentials(model="mistralai/mixtral-8x7b-instruct-v0.1", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py index 9d33f55bc2fb35..fabebc67ab0eeb 100644 --- a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py @@ -22,11 +22,18 @@ class NvidiaRerankModel(RerankModel): """ def _sigmoid(self, logit: float) -> float: - return 1/(1+exp(-logit)) - - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + return 1 / (1 + exp(-logit)) + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -54,16 +61,15 @@ def _invoke(self, model: str, credentials: dict, "query": {"text": query}, "passages": [{"text": doc} for doc in docs], } - session = requests.Session() 
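Reviewer note: the branch above emits legacy functions for function_call providers and tool_choice for tool_call ones. The per-tool wrapping inside the tool_call branch is elided by the hunk context; the sketch below assumes the common OpenAI-style {"type": "function", ...} envelope for that part:

def attach_tools(data: dict, tools: list[dict], function_calling_type: str) -> dict:
    """Attach tool definitions in the dialect the provider understands."""
    if not tools:
        return data
    if function_calling_type == "function_call":   # legacy OpenAI functions API
        data["functions"] = [
            {"name": t["name"], "description": t["description"], "parameters": t["parameters"]}
            for t in tools
        ]
    elif function_calling_type == "tool_call":     # current tools API
        data["tool_choice"] = "auto"
        # assumption: standard OpenAI tool envelope; the hunk elides this shape
        data["tools"] = [{"type": "function", "function": t} for t in tools]
    return data

tools = [{"name": "get_weather", "description": "Look up weather", "parameters": {"type": "object"}}]
assert attach_tools({"model": "m"}, tools, "tool_call")["tool_choice"] == "auto"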
response = session.post(invoke_url, headers=headers, json=payload) response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['rankings']: - index = result['index'] - logit = result['logit'] + for result in results["rankings"]: + index = result["index"] + logit = result["logit"] rerank_document = RerankDocument( index=index, text=docs[index], @@ -71,7 +77,10 @@ def _invoke(self, model: str, credentials: dict, ) rerank_documents.append(rerank_document) - + if rerank_documents: + rerank_documents = sorted(rerank_documents, key=lambda x: x.score, reverse=True) + if top_n: + rerank_documents = rerank_documents[:top_n] return RerankResult(model=model, docs=rerank_documents) except requests.HTTPError as e: raise InvokeServerUnavailableError(str(e)) @@ -108,5 +117,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [requests.HTTPError], InvokeRateLimitError: [], InvokeAuthorizationError: [requests.HTTPError], - InvokeBadRequestError: [requests.RequestException] + InvokeBadRequestError: [requests.RequestException], } diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index a2adef400d404c..04363e11be8ba8 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -4,6 +4,7 @@ from requests import post +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( @@ -22,12 +23,18 @@ class NvidiaTextEmbeddingModel(TextEmbeddingModel): """ Model class for Nvidia text embedding model. 
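Reviewer note: the rerank fix above does two things: converts raw logits to scores via a sigmoid and, newly, sorts descending and truncates to top_n. A runnable condensation of that post-processing:

from math import exp

def sigmoid(logit: float) -> float:
    return 1 / (1 + exp(-logit))

def postprocess_rankings(rankings: list[dict], docs: list[str], top_n: int | None = None):
    """Score, sort descending, and truncate, as the patched _invoke now does."""
    scored = [
        {"index": r["index"], "text": docs[r["index"]], "score": sigmoid(r["logit"])}
        for r in rankings
    ]
    scored.sort(key=lambda d: d["score"], reverse=True)
    return scored[:top_n] if top_n else scored

ranked = postprocess_rankings([{"index": 0, "logit": -1.0}, {"index": 1, "logit": 2.0}],
                              ["a", "b"], top_n=1)
assert ranked[0]["text"] == "b"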
""" - api_base: str = 'https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings' - models: list[str] = ['NV-Embed-QA'] - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + api_base: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings" + models: list[str] = ["NV-Embed-QA"] + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -35,34 +42,28 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - api_key = credentials['api_key'] + api_key = credentials["api_key"] if model not in self.models: - raise InvokeBadRequestError('Invalid model name') + raise InvokeBadRequestError("Invalid model name") if not api_key: - raise CredentialsValidateFailedError('api_key is required') + raise CredentialsValidateFailedError("api_key is required") url = self.api_base - headers = { - 'Authorization': 'Bearer ' + api_key, - 'Content-Type': 'application/json' - } + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = { - 'model': model, - 'input': texts[0], - 'input_type': 'query' - } + data = {"model": model, "input": texts[0], "input_type": "query"} try: response = post(url, headers=headers, data=dumps(data)) except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: try: resp = response.json() - msg = resp['detail'] + msg = resp["detail"] if response.status_code == 401: raise InvokeAuthorizationError(msg) elif response.status_code == 429: @@ -72,23 +73,21 @@ def _invoke(self, model: str, credentials: dict, else: raise InvokeError(msg) except JSONDecodeError as e: - raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) try: resp = response.json() - embeddings = resp['data'] - usage = resp['usage'] + embeddings = resp["data"] + usage = resp["usage"] except Exception as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[[ - float(data) for data in x['embedding'] - ] for x in embeddings], - usage=usage + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage ) return result @@ -117,30 +116,20 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid api key') + raise CredentialsValidateFailedError("Invalid api key") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - 
InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -152,10 +141,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -166,7 +152,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py index f7b849fbe23b7f..6ff380bdd99c8b 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py @@ -9,4 +9,5 @@ class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel): """ Model class for NVIDIA NIM large language model. """ + pass diff --git a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py index 25ab3e8e20f021..ad890ada22abc8 100644 --- a/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py +++ b/api/core/model_runtime/model_providers/nvidia_nim/nvidia_nim.py @@ -6,6 +6,5 @@ class NVIDIANIMProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/oci/__init__.py b/api/core/model_runtime/model_providers/oci/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg new file mode 100644 index 00000000000000..0981dfcff28c78 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg new file mode 100644 index 00000000000000..0981dfcff28c78 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml b/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml new file mode 100644 index 00000000000000..eb60cbcd90f5cd --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml @@ -0,0 +1,52 @@ +model: cohere.command-r-16k +label: + en_US: cohere.command-r-16k v1.2 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: 
temperature + use_template: temperature + default: 1 + max: 1.0 + - name: topP + use_template: top_p + default: 0.75 + min: 0 + max: 1 + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 0 + min: 0 + max: 500 + - name: presencePenalty + use_template: presence_penalty + min: 0 + max: 1 + default: 0 + - name: frequencyPenalty + use_template: frequency_penalty + min: 0 + max: 1 + default: 0 + - name: maxTokens + use_template: max_tokens + default: 600 + max: 4000 +pricing: + input: '0.004' + output: '0.004' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml b/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml new file mode 100644 index 00000000000000..df31b0d0df355d --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml @@ -0,0 +1,52 @@ +model: cohere.command-r-plus +label: + en_US: cohere.command-r-plus v1.2 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + default: 1 + max: 1.0 + - name: topP + use_template: top_p + default: 0.75 + min: 0 + max: 1 + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 0 + min: 0 + max: 500 + - name: presencePenalty + use_template: presence_penalty + min: 0 + max: 1 + default: 0 + - name: frequencyPenalty + use_template: frequency_penalty + min: 0 + max: 1 + default: 0 + - name: maxTokens + use_template: max_tokens + default: 600 + max: 4000 +pricing: + input: '0.0219' + output: '0.0219' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/llm/llm.py b/api/core/model_runtime/model_providers/oci/llm/llm.py new file mode 100644 index 00000000000000..1e1fc5b3ea89aa --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/llm/llm.py @@ -0,0 +1,469 @@ +import base64 +import copy +import json +import logging +from collections.abc import Generator +from typing import Optional, Union + +import oci +from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + +request_template = { + "compartmentId": "", + "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"}, + "chatRequest": { + "apiFormat": "COHERE", + # "preambleOverride": "You are a helpful assistant.", + # "message": "Hello!", + # "chatHistory": [], + "maxTokens": 600, + "isStream": False, + "frequencyPenalty": 
0, + "presencePenalty": 0, + "temperature": 1, + "topP": 0.75, + }, +} +oci_config_template = { + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + + +class OCILargeLanguageModel(LargeLanguageModel): + # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm + _supported_models = { + "meta.llama-3-70b-instruct": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "cohere.command-r-16k": { + "system": True, + "multimodal": False, + "tool_call": True, + "stream_tool_call": False, + }, + "cohere.command-r-plus": { + "system": True, + "multimodal": False, + "tool_call": True, + "stream_tool_call": False, + }, + } + + def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["stream_tool_call"] if stream else feature["tool_call"] + + def _is_multimodal_supported(self, model_id: str) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["multimodal"] + + def _is_system_prompt_supported(self, model_id: str) -> bool: + feature = self._supported_models.get(model_id) + if not feature: + return False + return feature["system"] + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # print("model"+"*"*20) + # print(model) + # print("credentials"+"*"*20) + # print(credentials) + # print("model_parameters"+"*"*20) + # print(model_parameters) + # print("prompt_messages"+"*"*200) + # print(prompt_messages) + # print("tools"+"*"*20) + # print(tools) + + # invoke model + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return:md = genai.GenerativeModel(model) + """ + prompt = self._convert_messages_to_prompt(prompt_messages) + + return self._get_num_tokens_by_gpt2(prompt) + + def get_num_characters( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return:md = genai.GenerativeModel(model) + """ + prompt = self._convert_messages_to_prompt(prompt_messages) + + return len(prompt) + + def _convert_messages_to_prompt(self, messages: 
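Reviewer note: instead of probing the service, the OCI class above answers capability questions from a static table. The lookup pattern on its own, with two entries copied from _supported_models:

_SUPPORTED_MODELS = {
    "cohere.command-r-plus": {"system": True, "multimodal": False,
                              "tool_call": True, "stream_tool_call": False},
    "meta.llama-3-70b-instruct": {"system": True, "multimodal": False,
                                  "tool_call": False, "stream_tool_call": False},
}

def is_tool_call_supported(model_id: str, stream: bool = False) -> bool:
    feature = _SUPPORTED_MODELS.get(model_id)
    if not feature:
        return False  # unknown models get the safe default
    return feature["stream_tool_call"] if stream else feature["tool_call"]

assert is_tool_call_supported("cohere.command-r-plus") is True
assert is_tool_call_supported("cohere.command-r-plus", stream=True) is False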
list[PromptMessage]) -> str: + """ + :param messages: List of PromptMessage to combine. + :return: Combined string with necessary human_prompt and ai_prompt tags. + """ + messages = messages.copy() # don't mutate the original list + + text = "".join(self._convert_one_message_to_text(message) for message in messages) + + return text.rstrip() + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + # Setup basic variables + # Auth Config + try: + ping_message = SystemPromptMessage(content="ping") + self._generate(model, credentials, [ping_message], {"maxTokens": 5}) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: credentials kwargs + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # config_kwargs = model_parameters.copy() + # config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + # if stop: + # config_kwargs["stop_sequences"] = stop + + # initialize client + # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat + oci_config = copy.deepcopy(oci_config_template) + if "oci_config_content" in credentials: + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") + config_items = oci_config_content.split("/") + if len(config_items) != 5: + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode(" + "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) + oci_config["user"] = config_items[0] + oci_config["fingerprint"] = config_items[1] + oci_config["tenancy"] = config_items[2] + oci_config["region"] = config_items[3] + oci_config["compartment_id"] = config_items[4] + else: + raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") + if "oci_key_content" in credentials: + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") + oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") + else: + raise CredentialsValidateFailedError("need to set oci_config_content in credentials ") + + # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile')) + compartment_id = oci_config["compartment_id"] + client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config) + # call embedding model + request_args = copy.deepcopy(request_template) + request_args["compartmentId"] = compartment_id + request_args["servingMode"]["modelId"] = model + + chat_history = [] + system_prompts = [] + # if "meta.llama" in model: + # request_args["chatRequest"]["apiFormat"] = "GENERIC" + request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600) + request_args["chatRequest"].update(model_parameters) + frequency_penalty = model_parameters.get("frequencyPenalty", 0) + 
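Reviewer note: the credential format above packs five OCI identifiers into one slash-delimited, base64-encoded string. Decoding it is just a split, as this sketch shows (the encoded value here is a dummy, not a real OCID):

import base64

_FIELDS = ("user", "fingerprint", "tenancy", "region", "compartment_id")

def parse_oci_config_content(encoded: str) -> dict:
    items = base64.b64decode(encoded).decode("utf-8").split("/")
    if len(items) != len(_FIELDS):
        raise ValueError(
            "oci_config_content should be base64 of "
            "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'"
        )
    return dict(zip(_FIELDS, items))

dummy = base64.b64encode(b"ocid.user/aa:bb/ocid.tenancy/us-ashburn-1/ocid.compartment").decode()
assert parse_oci_config_content(dummy)["region"] == "us-ashburn-1"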
+ + compartment_id = oci_config["compartment_id"] + client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config) + # build the chat request + request_args = copy.deepcopy(request_template) + request_args["compartmentId"] = compartment_id + request_args["servingMode"]["modelId"] = model + + chat_history = [] + system_prompts = [] + request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600) + request_args["chatRequest"].update(model_parameters) + frequency_penalty = model_parameters.get("frequencyPenalty", 0) + presence_penalty = model_parameters.get("presencePenalty", 0) + if frequency_penalty > 0 and presence_penalty > 0: + raise InvokeBadRequestError("Cannot set both frequency penalty and presence penalty") + + # tool calling is not implemented yet + tool_call_supported = self._is_tool_call_supported(model, stream) + if tools is not None and len(tools) > 0 and not tool_call_supported: + raise InvokeBadRequestError("Model does not support function calling") + if model.startswith("cohere"): + for message in prompt_messages[:-1]: + text = "" + if isinstance(message.content, str): + text = message.content + if isinstance(message, UserPromptMessage): + chat_history.append({"role": "USER", "message": text}) + elif isinstance(message, SystemPromptMessage): + system_prompts.append(text) + else: + chat_history.append({"role": "CHATBOT", "message": text}) + args = { + "apiFormat": "COHERE", + "preambleOverride": " ".join(system_prompts), + "message": prompt_messages[-1].content, + "chatHistory": chat_history, + } + request_args["chatRequest"].update(args) + elif model.startswith("meta"): + meta_messages = [] + for message in prompt_messages: + text = message.content + meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]}) + args = {"apiFormat": "GENERIC", "messages": meta_messages, "numGenerations": 1, "topK": -1} + request_args["chatRequest"].update(args) + + if stream: + request_args["chatRequest"]["isStream"] = True + response = client.chat(request_args) + + if stream: + return self._handle_generate_stream_response(model, credentials, response, prompt_messages) + + return self._handle_generate_response(model, credentials, response, prompt_messages) + + def _handle_generate_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: + """ + Handle llm response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response + """ + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage(content=response.data.chat_response.text) + + # calculate usage from character counts + prompt_tokens = self.get_num_characters(model, credentials, prompt_messages) + completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + # transform response + result = LLMResult( + model=model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage, + ) + + return result
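Editor's note: the two `chatRequest` variants assembled by `_generate` above differ substantially in shape. A rough sketch, with field names as used in this provider and all values illustrative:

```python
# COHERE format: preamble + single message + role-tagged history.
cohere_request = {
    "apiFormat": "COHERE",
    "preambleOverride": "You are a concise assistant.",
    "message": "Summarize OCI Generative AI in one line.",
    "chatHistory": [
        {"role": "USER", "message": "hi"},
        {"role": "CHATBOT", "message": "Hello! How can I help?"},
    ],
    "maxTokens": 600,
}

# GENERIC format (meta.llama): a messages list with typed content parts.
generic_request = {
    "apiFormat": "GENERIC",
    "messages": [{"role": "USER", "content": [{"type": "TEXT", "text": "hi"}]}],
    "numGenerations": 1,
    "topK": -1,
    "maxTokens": 600,
}
```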
+ + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: BaseChatResponse, prompt_messages: list[PromptMessage] + ) -> Generator: + """ + Handle llm stream response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :return: llm response chunk generator result + """ + index = -1 + assistant_prompt_message = AssistantPromptMessage(content="") # guard against a stream that opens with finishReason + events = response.data.events() + for stream in events: + # each SSE event carries a JSON payload, e.g. {'apiFormat': 'COHERE', 'text': 'Hello'} + chunk = json.loads(stream.data) + + if "finishReason" not in chunk: + assistant_prompt_message = AssistantPromptMessage(content="") + if model.startswith("cohere"): + if chunk["text"]: + assistant_prompt_message.content += chunk["text"] + elif model.startswith("meta"): + assistant_prompt_message.content += chunk["message"]["content"][0]["text"] + index += 1 + # transform assistant message to prompt message + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), + ) + else: + # calculate usage from character counts + prompt_tokens = self.get_num_characters(model, credentials, prompt_messages) + completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message]) + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=assistant_prompt_message, + finish_reason=str(chunk["finishReason"]), + usage=usage, + ), + ) + + def _convert_one_message_to_text(self, message: PromptMessage) -> str: + """ + Convert a single message to a string. + + :param message: PromptMessage to convert. + :return: String representation of the message. + """ + human_prompt = "\n\nuser:" + ai_prompt = "\n\nmodel:" + + content = message.content + if isinstance(content, list): + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) + + if isinstance(message, UserPromptMessage): + message_text = f"{human_prompt} {content}" + elif isinstance(message, AssistantPromptMessage): + message_text = f"{ai_prompt} {content}" + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): + message_text = f"{human_prompt} {content}" + else: + raise ValueError(f"Got unknown type {message}") + + return message_text
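Editor's note: the stream handler above consumes server-sent events whose `data` field is a JSON document; Cohere-format chunks carry a `text` delta and the closing chunk carries `finishReason`. A self-contained sketch of that parse loop over canned payloads:

```python
import json

# Simulated Cohere-format SSE payloads (illustrative; the real ones come
# from response.data.events() and carry the same keys handled above).
events = [
    '{"apiFormat": "COHERE", "text": "Hel"}',
    '{"apiFormat": "COHERE", "text": "lo"}',
    '{"apiFormat": "COHERE", "finishReason": "COMPLETE"}',
]

full_text = ""
for raw in events:
    chunk = json.loads(raw)
    if "finishReason" not in chunk:
        full_text += chunk["text"]  # delta chunk: accumulate the fragment
    else:
        print(full_text, chunk["finishReason"])  # prints: Hello COMPLETE
```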
+ + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [], + InvokeServerUnavailableError: [], + InvokeRateLimitError: [], + InvokeAuthorizationError: [], + InvokeBadRequestError: [], + } diff --git a/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml new file mode 100644 index 00000000000000..dd5be107c07570 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml @@ -0,0 +1,51 @@ +model: meta.llama-3-70b-instruct +label: + zh_Hans: meta.llama-3-70b-instruct + en_US: meta.llama-3-70b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + default: 1 + max: 2.0 + - name: topP + use_template: top_p + default: 0.75 + min: 0 + max: 1 + - name: topK + label: + zh_Hans: 取样数量 + en_US: Top K + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + default: 0 + min: 0 + max: 500 + - name: presencePenalty + use_template: presence_penalty + min: -2 + max: 2 + default: 0 + - name: frequencyPenalty + use_template: frequency_penalty + min: -2 + max: 2 + default: 0 + - name: maxTokens + use_template: max_tokens + default: 600 + max: 8000 +pricing: + input: '0.015' + output: '0.015' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/oci.py b/api/core/model_runtime/model_providers/oci/oci.py new file mode 100644 index 00000000000000..e182d2d0439d77 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/oci.py @@ -0,0 +1,28 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class OCIGENAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validation fails, raise an exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + # Use the `cohere.command-r-plus` model for validation + model_instance.validate_credentials(model="cohere.command-r-plus", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/oci/oci.yaml b/api/core/model_runtime/model_providers/oci/oci.yaml new file mode 100644 index 00000000000000..f2f23e18f12073 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/oci.yaml @@ -0,0 +1,42 @@ +provider: oci +label: + en_US: OCIGenerativeAI +description: + en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+.
+ zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。 +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FFFFFF" +help: + title: + en_US: Get your API Key from OCI + zh_Hans: 从 OCI 获取 API Key + url: + en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm +supported_model_types: + - llm + - text-embedding + #- rerank +configurate_methods: + - predefined-model + #- customizable-model +provider_credential_schema: + credential_form_schemas: + - variable: oci_config_content + label: + en_US: OCI API key config file content + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的 oci api key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8'))) + en_US: Enter the content of your OCI API key config file (base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8'))) + - variable: oci_key_content + label: + en_US: OCI API key file content + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8'))) + en_US: Enter the content of your OCI API key file (base64.b64encode("pem file content".encode('utf-8'))) diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/__init__.py b/api/core/model_runtime/model_providers/oci/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml new file mode 100644 index 00000000000000..149f1e3797850f --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml @@ -0,0 +1,5 @@ +- cohere.embed-english-light-v2.0 +- cohere.embed-english-light-v3.0 +- cohere.embed-english-v3.0 +- cohere.embed-multilingual-light-v3.0 +- cohere.embed-multilingual-v3.0 diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml new file mode 100644 index 00000000000000..259d5b45b7a2f1 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-english-light-v2.0 +model_type: text-embedding +model_properties: + context_size: 1024 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml new file mode 100644 index 00000000000000..065e7474c0bb97 --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-english-light-v3.0 +model_type: text-embedding +model_properties: + context_size: 384 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml new file mode 100644 index 00000000000000..3e2deea16a1d0b --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-english-v3.0
+model_type: text-embedding +model_properties: + context_size: 1024 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml new file mode 100644 index 00000000000000..0d2b892c64290e --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-multilingual-light-v3.0 +model_type: text-embedding +model_properties: + context_size: 384 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml new file mode 100644 index 00000000000000..9ebe260b32875b --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml @@ -0,0 +1,9 @@ +model: cohere.embed-multilingual-v3.0 +model_type: text-embedding +model_properties: + context_size: 1024 + max_chunks: 48 +pricing: + input: '0.001' + unit: '0.0001' + currency: USD diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..50fa63768c241b --- /dev/null +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -0,0 +1,224 @@ +import base64 +import copy +import time +from typing import Optional + +import numpy as np +import oci + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + +request_template = { + "compartmentId": "", + "servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"}, + "truncate": "NONE", + "inputs": [""], +} +oci_config_template = { + "user": "", + "fingerprint": "", + "tenancy": "", + "region": "", + "compartment_id": "", + "key_content": "", +} + + +class OCITextEmbeddingModel(TextEmbeddingModel): + """ + Model class for OCI text embedding model. + """
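Editor's note: for orientation, this is the wire shape `_embedding_invoke` (further below) sends to `embed_text`; it is `request_template` with the compartment, model, and inputs filled in, and all values here are illustrative placeholders.

```python
# Illustrative payload; field names follow request_template above.
embed_request = {
    "compartmentId": "ocid1.compartment.oc1..aaaaexample",
    "servingMode": {"modelId": "cohere.embed-english-light-v3.0", "servingType": "ON_DEMAND"},
    "truncate": "NONE",  # the provider truncates client-side instead
    "inputs": ["first passage to embed", "second passage to embed"],
}
```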
+ """ + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + # get model properties + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + + inputs = [] + indices = [] + used_tokens = 0 + + for i, text in enumerate(texts): + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(len(text) * (np.floor(context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + + for i in _iter: + # call embedding model + embeddings_batch, embedding_used_tokens = self._embedding_invoke( + model=model, credentials=credentials, texts=inputs[i : i + max_chunks] + ) + + used_tokens += embedding_used_tokens + batched_embeddings += embeddings_batch + + # calc usage + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + characters = 0 + for text in texts: + characters += len(text) + return characters + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + # call embedding model + self._embedding_invoke(model=model, credentials=credentials, texts=["ping"]) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]: + """ + Invoke embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: embeddings and used tokens + """ + + # oci + # initialize client + oci_config = copy.deepcopy(oci_config_template) + if "oci_config_content" in credentials: + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") + config_items = oci_config_content.split("/") + if len(config_items) != 5: + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode(" + "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) + oci_config["user"] = config_items[0] + oci_config["fingerprint"] = config_items[1] + oci_config["tenancy"] = config_items[2] + oci_config["region"] = 
+ def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]: + """ + Invoke embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: embeddings and used character count + """ + # initialize client + oci_config = copy.deepcopy(oci_config_template) + if "oci_config_content" in credentials: + oci_config_content = base64.b64decode(credentials.get("oci_config_content")).decode("utf-8") + config_items = oci_config_content.split("/") + if len(config_items) != 5: + raise CredentialsValidateFailedError( + "oci_config_content should be base64.b64encode(" + "'user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))" + ) + oci_config["user"] = config_items[0] + oci_config["fingerprint"] = config_items[1] + oci_config["tenancy"] = config_items[2] + oci_config["region"] = config_items[3] + oci_config["compartment_id"] = config_items[4] + else: + raise CredentialsValidateFailedError("oci_config_content must be set in credentials") + if "oci_key_content" in credentials: + oci_key_content = base64.b64decode(credentials.get("oci_key_content")).decode("utf-8") + oci_config["key_content"] = oci_key_content.encode(encoding="utf-8") + else: + raise CredentialsValidateFailedError("oci_key_content must be set in credentials") + compartment_id = oci_config["compartment_id"] + client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config) + # call embedding model + request_args = copy.deepcopy(request_template) + request_args["compartmentId"] = compartment_id + request_args["servingMode"]["modelId"] = model + request_args["inputs"] = texts + response = client.embed_text(request_args) + return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts) + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller.
+ :return: Invoke error mapping + """ + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], + } diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 42a588e3dd5df1..a7ea53e0e99c5f 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -121,9 +121,7 @@ def get_num_tokens( text = "" for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data break return self._get_num_tokens_by_gpt2(text) @@ -145,13 +143,9 @@ def validate_credentials(self, model: str, credentials: dict) -> None: stream=False, ) except InvokeError as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {ex.description}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError( - f"An error occurred during credentials validation: {str(ex)}" - ) + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def _generate( self, @@ -201,9 +195,7 @@ def _generate( if completion_type is LLMMode.CHAT: endpoint_url = urljoin(endpoint_url, "api/chat") - data["messages"] = [ - self._convert_prompt_message_to_dict(m) for m in prompt_messages - ] + data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] else: endpoint_url = urljoin(endpoint_url, "api/generate") first_prompt_message = prompt_messages[0] @@ -216,14 +208,10 @@ def _generate( images = [] for message_content in first_prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) + message_content = cast(ImagePromptMessageContent, message_content) image_data = re.sub( r"^data:image\/[a-zA-Z]+;base64,", "", @@ -235,24 +223,16 @@ def _generate( data["images"] = images # send a post request to validate the credentials - response = requests.post( - endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) response.encoding = "utf-8" if response.status_code != 200: - raise InvokeError( - f"API request failed with status code {response.status_code}: {response.text}" - ) + raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") if stream: - return self._handle_generate_stream_response( - model, credentials, completion_type, response, prompt_messages - ) + return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) - return self._handle_generate_response( - model, credentials, completion_type, response, prompt_messages - ) + return 
self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) def _handle_generate_response( self, @@ -292,9 +272,7 @@ def _handle_generate_response( completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) # transform response result = LLMResult( @@ -335,9 +313,7 @@ def create_final_llm_result_chunk( completion_tokens = self._get_num_tokens_by_gpt2(full_text) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) return LLMResultChunk( model=model, @@ -388,21 +364,24 @@ def create_final_llm_result_chunk( if chunk_json["done"]: # calculate num tokens - if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json: - # transform usage + if "prompt_eval_count" in chunk_json: prompt_tokens = chunk_json["prompt_eval_count"] - completion_tokens = chunk_json["eval_count"] else: - # calculate num tokens - prompt_tokens = self._get_num_tokens_by_gpt2( - prompt_messages[0].content - ) - completion_tokens = self._get_num_tokens_by_gpt2(full_text) + prompt_message_content = prompt_messages[0].content + if isinstance(prompt_message_content, str): + prompt_tokens = self._get_num_tokens_by_gpt2(prompt_message_content) + else: + content_text = "" + for message_content in prompt_message_content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(TextPromptMessageContent, message_content) + content_text += message_content.data + prompt_tokens = self._get_num_tokens_by_gpt2(content_text) + + completion_tokens = chunk_json.get("eval_count", self._get_num_tokens_by_gpt2(full_text)) # transform usage - usage = self._calc_response_usage( - model, credentials, prompt_tokens, completion_tokens - ) + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( model=chunk_json["model"], @@ -439,17 +418,11 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: images = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - message_content = cast( - TextPromptMessageContent, message_content - ) + message_content = cast(TextPromptMessageContent, message_content) text = message_content.data elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content - ) - image_data = re.sub( - r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data - ) + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) images.append(image_data) message_dict = {"role": "user", "content": text, "images": images} @@ -479,9 +452,7 @@ def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: return num_tokens - def get_customizable_model_schema( - self, model: str, credentials: dict - ) -> AIModelEntity: + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ Get customizable model schema. 
@@ -502,20 +473,19 @@ def get_customizable_model_schema( fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ ModelPropertyKey.MODE: credentials.get("mode"), - ModelPropertyKey.CONTEXT_SIZE: int( - credentials.get("context_size", 4096) - ), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, use_template=DefaultParameterName.TEMPERATURE.value, - label=I18nObject(en_US="Temperature"), + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, help=I18nObject( en_US="The temperature of the model. " "Increasing the temperature will make the model answer " - "more creatively. (Default: 0.8)" + "more creatively. (Default: 0.8)", + zh_Hans="模型的温度。增加温度将使模型的回答更具创造性。(默认值:0.8)", ), default=0.1, min=0, @@ -524,12 +494,13 @@ def get_customizable_model_schema( ParameterRule( name=DefaultParameterName.TOP_P.value, use_template=DefaultParameterName.TOP_P.value, - label=I18nObject(en_US="Top P"), + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, help=I18nObject( en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to " "more diverse text, while a lower value (e.g., 0.5) will generate more " - "focused and conservative text. (Default: 0.9)" + "focused and conservative text. (Default: 0.9)", + zh_Hans="与top-k一起工作。较高的值(例如,0.95)会导致生成更多样化的文本,而较低的值(例如,0.5)会生成更专注和保守的文本。(默认值:0.9)", ), default=0.9, min=0, @@ -537,12 +508,13 @@ def get_customizable_model_schema( ), ParameterRule( name="top_k", - label=I18nObject(en_US="Top K"), + label=I18nObject(en_US="Top K", zh_Hans="Top K"), type=ParameterType.INT, help=I18nObject( en_US="Reduces the probability of generating nonsense. " "A higher value (e.g. 100) will give more diverse answers, " - "while a lower value (e.g. 10) will be more conservative. (Default: 40)" + "while a lower value (e.g. 10) will be more conservative. (Default: 40)", + zh_Hans="减少生成无意义内容的可能性。较高的值(例如100)将提供更多样化的答案,而较低的值(例如10)将更为保守。(默认值:40)", ), min=1, max=100, @@ -554,7 +526,8 @@ def get_customizable_model_schema( help=I18nObject( en_US="Sets how strongly to penalize repetitions. " "A higher value (e.g., 1.5) will penalize repetitions more strongly, " - "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)" + "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)", + zh_Hans="设置对重复内容的惩罚强度。一个较高的值(例如,1.5)会更强地惩罚重复内容,而一个较低的值(例如,0.9)则会相对宽容。(默认值:1.1)", ), min=-2, max=2, @@ -562,134 +535,150 @@ def get_customizable_model_schema( ParameterRule( name="num_predict", use_template="max_tokens", - label=I18nObject(en_US="Num Predict"), + label=I18nObject(en_US="Num Predict", zh_Hans="最大令牌数预测"), type=ParameterType.INT, help=I18nObject( en_US="Maximum number of tokens to predict when generating text. " - "(Default: 128, -1 = infinite generation, -2 = fill context)" - ), - default=( - 512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128 + "(Default: 128, -1 = infinite generation, -2 = fill context)", + zh_Hans="生成文本时预测的最大令牌数。(默认值:128,-1 = 无限生成,-2 = 填充上下文)", ), + default=(512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128), min=-2, max=int(credentials.get("max_tokens", 4096)), ), ParameterRule( name="mirostat", - label=I18nObject(en_US="Mirostat sampling"), + label=I18nObject(en_US="Mirostat sampling", zh_Hans="Mirostat 采样"), type=ParameterType.INT, help=I18nObject( en_US="Enable Mirostat sampling for controlling perplexity. 
" - "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)" + "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", + zh_Hans="启用 Mirostat 采样以控制困惑度。" + "(默认值:0,0 = 禁用,1 = Mirostat,2 = Mirostat 2.0)", ), min=0, max=2, ), ParameterRule( name="mirostat_eta", - label=I18nObject(en_US="Mirostat Eta"), + label=I18nObject(en_US="Mirostat Eta", zh_Hans="学习率"), type=ParameterType.FLOAT, help=I18nObject( en_US="Influences how quickly the algorithm responds to feedback from " "the generated text. A lower learning rate will result in slower adjustments, " "while a higher learning rate will make the algorithm more responsive. " - "(Default: 0.1)" + "(Default: 0.1)", + zh_Hans="影响算法对生成文本反馈响应的速度。较低的学习率会导致调整速度变慢,而较高的学习率会使得算法更加灵敏。(默认值:0.1)", ), precision=1, ), ParameterRule( name="mirostat_tau", - label=I18nObject(en_US="Mirostat Tau"), + label=I18nObject(en_US="Mirostat Tau", zh_Hans="文本连贯度"), type=ParameterType.FLOAT, help=I18nObject( en_US="Controls the balance between coherence and diversity of the output. " - "A lower value will result in more focused and coherent text. (Default: 5.0)" + "A lower value will result in more focused and coherent text. (Default: 5.0)", + zh_Hans="控制输出的连贯性和多样性之间的平衡。较低的值会导致更专注和连贯的文本。(默认值:5.0)", ), precision=1, ), ParameterRule( name="num_ctx", - label=I18nObject(en_US="Size of context window"), + label=I18nObject(en_US="Size of context window", zh_Hans="上下文窗口大小"), type=ParameterType.INT, help=I18nObject( - en_US="Sets the size of the context window used to generate the next token. " - "(Default: 2048)" + en_US="Sets the size of the context window used to generate the next token. (Default: 2048)", + zh_Hans="设置用于生成下一个标记的上下文窗口大小。(默认值:2048)", ), default=2048, min=1, ), ParameterRule( - name='num_gpu', - label=I18nObject(en_US="GPU Layers"), + name="num_gpu", + label=I18nObject(en_US="GPU Layers", zh_Hans="GPU 层数"), type=ParameterType.INT, - help=I18nObject(en_US="The number of layers to offload to the GPU(s). " - "On macOS it defaults to 1 to enable metal support, 0 to disable." - "As long as a model fits into one gpu it stays in one. " - "It does not set the number of GPU(s). "), + help=I18nObject( + en_US="The number of layers to offload to the GPU(s). " + "On macOS it defaults to 1 to enable metal support, 0 to disable." + "As long as a model fits into one gpu it stays in one. " + "It does not set the number of GPU(s). ", + zh_Hans="加载到 GPU 的层数。在 macOS 上,默认为 1 以启用 Metal 支持,设置为 0 则禁用。" + "只要模型适合一个 GPU,它就保留在其中。它不设置 GPU 的数量。", + ), min=-1, - default=1 + default=1, ), ParameterRule( name="num_thread", - label=I18nObject(en_US="Num Thread"), + label=I18nObject(en_US="Num Thread", zh_Hans="线程数"), type=ParameterType.INT, help=I18nObject( en_US="Sets the number of threads to use during computation. " "By default, Ollama will detect this for optimal performance. " "It is recommended to set this value to the number of physical CPU cores " - "your system has (as opposed to the logical number of cores)." + "your system has (as opposed to the logical number of cores).", + zh_Hans="设置计算过程中使用的线程数。默认情况下,Ollama会检测以获得最佳性能。建议将此值设置为系统拥有的物理CPU核心数(而不是逻辑核心数)。", ), min=1, ), ParameterRule( name="repeat_last_n", - label=I18nObject(en_US="Repeat last N"), + label=I18nObject(en_US="Repeat last N", zh_Hans="回溯内容"), type=ParameterType.INT, help=I18nObject( en_US="Sets how far back for the model to look back to prevent repetition. 
" - "(Default: 64, 0 = disabled, -1 = num_ctx)" + "(Default: 64, 0 = disabled, -1 = num_ctx)", + zh_Hans="设置模型回溯多远的内容以防止重复。(默认值:64,0 = 禁用,-1 = num_ctx)", ), min=-1, ), ParameterRule( name="tfs_z", - label=I18nObject(en_US="TFS Z"), + label=I18nObject(en_US="TFS Z", zh_Hans="减少标记影响"), type=ParameterType.FLOAT, help=I18nObject( en_US="Tail free sampling is used to reduce the impact of less probable tokens " "from the output. A higher value (e.g., 2.0) will reduce the impact more, " - "while a value of 1.0 disables this setting. (default: 1)" + "while a value of 1.0 disables this setting. (default: 1)", + zh_Hans="用于减少输出中不太可能的标记的影响。较高的值(例如,2.0)会更多地减少这种影响,而1.0的值则会禁用此设置。(默认值:1)", ), precision=1, ), ParameterRule( name="seed", - label=I18nObject(en_US="Seed"), + label=I18nObject(en_US="Seed", zh_Hans="随机数种子"), type=ParameterType.INT, help=I18nObject( en_US="Sets the random number seed to use for generation. Setting this to " "a specific number will make the model generate the same text for " - "the same prompt. (Default: 0)" + "the same prompt. (Default: 0)", + zh_Hans="设置用于生成的随机数种子。将此设置为特定数字将使模型对相同的提示生成相同的文本。(默认值:0)", ), ), ParameterRule( name="keep_alive", - label=I18nObject(en_US="Keep Alive"), + label=I18nObject(en_US="Keep Alive", zh_Hans="模型存活时间"), type=ParameterType.STRING, help=I18nObject( en_US="Sets how long the model is kept in memory after generating a response. " - "This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours). " - "A negative number keeps the model loaded indefinitely, and '0' unloads the model immediately after generating a response. " - "Valid time units are 's','m','h'. (Default: 5m)" + "This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours)." + " A negative number keeps the model loaded indefinitely, and '0' unloads the model" + " immediately after generating a response." + " Valid time units are 's','m','h'. (Default: 5m)", + zh_Hans="设置模型在生成响应后在内存中保留的时间。" + "这必须是一个带有单位的持续时间字符串(例如,'10m' 表示10分钟,'24h' 表示24小时)。" + "负数表示无限期地保留模型,'0'表示在生成响应后立即卸载模型。" + "有效的时间单位有 's'(秒)、'm'(分钟)、'h'(小时)。(默认值:5m)", ), ), ParameterRule( name="format", - label=I18nObject(en_US="Format"), + label=I18nObject(en_US="Format", zh_Hans="返回格式"), type=ParameterType.STRING, help=I18nObject( - en_US="the format to return a response in." - " Currently the only accepted value is json." + en_US="the format to return a response in. 
Currently the only accepted value is json.", + zh_Hans="返回响应的格式。目前唯一接受的值是json。", ), options=["json"], ), diff --git a/api/core/model_runtime/model_providers/ollama/ollama.py b/api/core/model_runtime/model_providers/ollama/ollama.py index f8a17b98a0d677..115280193a5ed6 100644 --- a/api/core/model_runtime/model_providers/ollama/ollama.py +++ b/api/core/model_runtime/model_providers/ollama/ollama.py @@ -6,7 +6,6 @@ class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 9e26d35afc9437..a16c91cd7ef81e 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -8,6 +8,7 @@ import numpy as np import requests +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -37,9 +38,14 @@ class OllamaEmbeddingModel(TextEmbeddingModel): Model class for an Ollama text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -47,19 +53,18 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - endpoint_url = credentials.get('base_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("base_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'api/embed') + endpoint_url = urljoin(endpoint_url, "api/embed") # get model properties context_size = self._get_context_size(model, credentials) @@ -67,52 +72,36 @@ def _invoke(self, model: str, credentials: dict, inputs = [] used_tokens = 0 - for i, text in enumerate(texts): + for text in texts: # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) # Prepare the payload for the request - payload = { - 'input': inputs, - 'model': model, - } + payload = {"input": inputs, "model": model, "options": {"use_mmap": True}} - # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + # Make the request to the Ollama API + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings 
and used tokens from the response - embeddings = response_data['embeddings'] + embeddings = response_data["embeddings"] embedding_used_tokens = self.get_num_tokens(model, credentials, inputs) used_tokens += embedding_used_tokens # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -134,19 +123,15 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke( - model=model, - credentials=credentials, - texts=['ping'] - ) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {ex.description}") except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -154,15 +139,15 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -178,10 +163,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -192,7 +174,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -220,10 +202,10 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] + requests.exceptions.ReadTimeout, # Timeout + ], } diff --git 
a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 467a51daf2a278..2181bb4f08fd8f 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -22,7 +22,7 @@ def _to_credential_kwargs(self, credentials: Mapping) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['openai_api_key'], + "api_key": credentials["openai_api_key"], "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "max_retries": 1, } @@ -31,8 +31,8 @@ def _to_credential_kwargs(self, credentials: Mapping) -> dict: openai_api_base = credentials["openai_api_base"].rstrip("/") credentials_kwargs["base_url"] = openai_api_base + "/v1" - if 'openai_organization' in credentials: - credentials_kwargs['organization'] = credentials['openai_organization'] + if "openai_organization" in credentials: + credentials_kwargs["organization"] = credentials["openai_organization"] return credentials_kwargs diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index ac7313aaa1bf0b..b7c25ecb1602b6 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -1,3 +1,4 @@ +- gpt-4o-audio-preview - gpt-4 - gpt-4o - gpt-4o-2024-05-13 @@ -5,6 +6,10 @@ - chatgpt-4o-latest - gpt-4o-mini - gpt-4o-mini-2024-07-18 +- o1-preview +- o1-preview-2024-09-12 +- o1-mini +- o1-mini-2024-09-12 - gpt-4-turbo - gpt-4-turbo-2024-04-09 - gpt-4-turbo-preview diff --git a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml index 98e236650c9e73..b47449a49abc2e 100644 --- a/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/chatgpt-4o-latest.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml index c1602b2efcb8d3..ffa725ec40f4f5 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml @@ -27,7 +27,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml index 31dc53e89f188a..a1ad07d7129568 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml @@ -31,3 +31,4 @@ pricing: output: '0.002' unit: '0.001' currency: USD +deprecated: true diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml index 56ab965c39c4de..21150fc3a6df61 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml @@ -27,7 
+27,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml index 4a0e2ef1911a52..4e302792842415 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml @@ -31,3 +31,4 @@ pricing: output: '0.004' unit: '0.001' currency: USD +deprecated: true diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml index 6eb15e6c0df396..d3a8ee535a9d16 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml @@ -27,7 +27,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml index 007cfed0f3a6f4..ac4ec5840bd771 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml index f4fa6317af8763..d7752397701f9a 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml index f92173ccfd9968..8358425e6d2909 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml index 6b36361efe80d3..0234499164abf4 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-2024-04-09.yaml @@ -41,7 +41,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml index c0350ae2c6d750..8d29cf0c04a1df 100644 --- 
a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml index 575acb7fa294b7..b25ff6a81269fa 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo.yaml @@ -41,7 +41,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-vision-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-vision-preview.yaml index a63b60842396c6..07037c66438dd2 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-vision-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-vision-preview.yaml @@ -38,7 +38,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml index a7a5bf3c864a6e..f7b5138b7df366 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml @@ -40,7 +40,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml index f0d835cba217df..b630d6f63075c2 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-05-13.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml index 7e430c51a710fc..73b7f6970076c0 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-08-06.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml new file mode 100644 index 00000000000000..256e87edbe38c3 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml @@ -0,0 +1,44 @@ +model: gpt-4o-audio-preview +label: + zh_Hans: gpt-4o-audio-preview + en_US: gpt-4o-audio-preview +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: 
temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '5.00' + output: '15.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml index 6f23e0647d6eec..df38270f79b1c3 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini-2024-07-18.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 @@ -37,6 +37,9 @@ parameter_rules: options: - text - json_object + - json_schema + - name: json_schema + use_template: json_schema pricing: input: '0.15' output: '0.60' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml index 23dcf85085e123..5e3c94fbe255c0 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-mini.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml index 4f141f772fd14c..3090a9e090c2c5 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o.yaml @@ -28,7 +28,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 06135c958463e8..922e5e131417ee 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast import tiktoken from openai import OpenAI, Stream @@ -11,9 +11,9 @@ from openai.types.chat.chat_completion_message import FunctionCall from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -23,6 +23,7 @@ ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.model_entities import 
AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -37,18 +38,25 @@ {{instructions}} -""" +""" # noqa: E501 + class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ Model class for OpenAI large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -64,8 +72,8 @@ def _invoke(self, model: str, credentials: dict, """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -80,7 +88,7 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) else: # text completion model @@ -91,26 +99,34 @@ def _invoke(self, model: str, credentials: dict, model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ # handle fine tune remote models base_model = model - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # get model mode model_mode = self.get_model_mode(base_model, credentials) # transform response format - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model @@ -123,7 +139,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) else: self._transform_completion_json_prompts( @@ -135,9 +151,9 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - 
model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -147,14 +163,21 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -167,25 +190,35 @@ def _transform_chat_json_prompts(self, model: str, credentials: dict, if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." 
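The reformatted `_transform_chat_json_prompts` hunk above is easier to follow outside the diff: it rewrites (or inserts) a system message built from a block-mode template, then primes an assistant turn with an opening code fence. A minimal sketch of that flow, using plain dicts and an invented stand-in template rather than Dify's PromptMessage classes and the real OPENAI_BLOCK_MODE_PROMPT:

```python
# Illustrative sketch of the JSON block-mode prompt rewrite. Plain dicts stand
# in for Dify's PromptMessage classes, and BLOCK_MODE_PROMPT is an invented
# stand-in for OPENAI_BLOCK_MODE_PROMPT (only its {{instructions}} tail is
# visible in this diff).
BLOCK_MODE_PROMPT = "You should always output a valid {block} object.\n{instructions}"

def transform_chat_json_prompts(messages: list[dict], response_format: str = "JSON") -> list[dict]:
    messages = list(messages)  # avoid mutating the caller's list
    if messages and messages[0]["role"] == "system":
        # override the existing system message with the block-mode template
        messages[0] = {
            "role": "system",
            "content": BLOCK_MODE_PROMPT.format(
                instructions=messages[0]["content"], block=response_format
            ),
        }
    else:
        # no system message yet: insert one asking for a valid JSON object
        messages.insert(0, {
            "role": "system",
            "content": BLOCK_MODE_PROMPT.format(
                instructions=f"Please output a valid {response_format} object.",
                block=response_format,
            ),
        })
    # prime the assistant turn so the model continues inside a fenced block
    messages.append({"role": "assistant", "content": f"\n```{response_format}\n"})
    return messages

print(transform_chat_json_prompts([{"role": "user", "content": "List three colors."}]))
```

Priming the assistant message with an opening fence nudges the model to begin its reply inside the block, which the caller can then strip and parse.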
+ ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - - def _transform_completion_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ - -> None: + + def _transform_completion_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts """ @@ -202,25 +235,30 @@ def _transform_completion_json_prompts(self, model: str, credentials: dict, break if user_message: - if prompt_messages[i].content[-11:] == 'Assistant: ': + if prompt_messages[i].content[-11:] == "Assistant: ": # now we are in the chat app, remove the last assistant message prompt_messages[i].content = prompt_messages[i].content[:-11] prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"Assistant:\n```{response_format}\n" else: prompt_messages[i] = UserPromptMessage( - content=OPENAI_BLOCK_MODE_PROMPT - .replace("{{instructions}}", user_message.content) - .replace("{{block}}", response_format) + content=OPENAI_BLOCK_MODE_PROMPT.replace("{{instructions}}", user_message.content).replace( + "{{block}}", response_format + ) ) prompt_messages[i].content += f"\n```{response_format}\n" - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -231,8 +269,8 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr :return: """ # handle fine tune remote models - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] else: base_model = model @@ -262,14 +300,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None: # handle fine tune remote models base_model = model # fine-tuned model name likes ft:gpt-3.5-turbo-0613:personal::xxxxx - if model.startswith('ft:'): - base_model = model.split(':')[1] + if model.startswith("ft:"): + base_model = model.split(":")[1] # check if model exists remote_models = self.remote_models(credentials) remote_model_map = {model.model: model for model in remote_models} if model not in remote_model_map: - raise CredentialsValidateFailedError(f'Fine-tuned model {model} not found') + raise CredentialsValidateFailedError(f"Fine-tuned model {model} not found") # get model mode model_mode = self.get_model_mode(base_model, credentials) @@ -277,7 +315,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if model_mode == LLMMode.CHAT: # chat model client.chat.completions.create( - messages=[{"role": "user", "content": 'ping'}], + 
messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=20, @@ -286,7 +324,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: else: # text completion model client.completions.create( - prompt='ping', + prompt="ping", model=model, temperature=0, max_tokens=20, @@ -313,11 +351,11 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]: # get all remote models remote_models = client.models.list() - fine_tune_models = [model for model in remote_models if model.id.startswith('ft:')] + fine_tune_models = [model for model in remote_models if model.id.startswith("ft:")] ai_model_entities = [] for model in fine_tune_models: - base_model = model.id.split(':')[1] + base_model = model.id.split(":")[1] base_model_schema = None for predefined_model_name, predefined_model in predefined_models_map.items(): @@ -329,30 +367,29 @@ def remote_models(self, credentials: dict) -> list[AIModelEntity]: ai_model_entity = AIModelEntity( model=model.id, - label=I18nObject( - zh_Hans=model.id, - en_US=model.id - ), + label=I18nObject(zh_Hans=model.id, en_US=model.id), model_type=ModelType.LLM, features=base_model_schema.features, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=base_model_schema.model_properties, parameter_rules=base_model_schema.parameter_rules, - pricing=PriceConfig( - input=0.003, - output=0.006, - unit=0.001, - currency='USD' - ) + pricing=PriceConfig(input=0.003, output=0.006, unit=0.001, currency="USD"), ) ai_model_entities.append(ai_model_entity) return ai_model_entities - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -374,23 +411,17 @@ def _generate(self, model: str, credentials: dict, extra_model_kwargs = {} if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - "include_usage": True - } - + extra_model_kwargs["stream_options"] = {"include_usage": True} + # text completion model response = client.completions.create( - prompt=prompt_messages[0].content, - model=model, - stream=stream, - **model_parameters, - **extra_model_kwargs + prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -398,8 +429,9 @@ def _generate(self, model: str, credentials: dict, return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm completion response @@ -412,9 +444,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Com assistant_text = response.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text - ) + 
assistant_prompt_message = AssistantPromptMessage(content=assistant_text) # calculate num tokens if response.usage: @@ -440,8 +470,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Com return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm completion stream response @@ -451,7 +482,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" prompt_tokens = 0 completion_tokens = 0 @@ -460,8 +491,8 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -474,14 +505,12 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon delta = chunk.choices[0] - if delta.finish_reason is None and (delta.text is None or delta.text == ''): + if delta.finish_reason is None and (delta.text is None or delta.text == ""): continue # transform assistant message to prompt message - text = delta.text if delta.text else '' - assistant_prompt_message = AssistantPromptMessage( - content=text - ) + text = delta.text or "" + assistant_prompt_message = AssistantPromptMessage(content=text) full_text += text @@ -494,7 +523,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -504,7 +533,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -520,10 +549,17 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon yield final_chunk - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -552,7 +588,7 @@ def _chat_generate(self, model: str, credentials: dict, try: schema = json.loads(json_schema) except: - raise ValueError(f"not currect json_schema format: {json_schema}") + raise ValueError(f"not correct json_schema format: {json_schema}") model_parameters.pop("json_schema") model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema} else: @@ -562,29 +598,39 @@ def _chat_generate(self, model: str, credentials: dict, if tools: # extra_model_kwargs['tools'] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools] - extra_model_kwargs['functions'] = [{ - "name": tool.name, - 
"description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user if stream: - extra_model_kwargs['stream_options'] = { - 'include_usage': True - } + extra_model_kwargs["stream_options"] = {"include_usage": True} # clear illegal prompt messages prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) + # o1 compatibility + block_as_stream = False + if model.startswith("o1"): + if stream: + block_as_stream = True + stream = False + + if "stream_options" in extra_model_kwargs: + del extra_model_kwargs["stream_options"] + + if "stop" in extra_model_kwargs: + del extra_model_kwargs["stop"] + # chat model + messages: Any = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] response = client.chat.completions.create( - messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages], + messages=messages, model=model, stream=stream, **model_parameters, @@ -594,11 +640,56 @@ def _chat_generate(self, model: str, credentials: dict, if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) - return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) + + if block_as_stream: + return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop) + + return block_result + + def _handle_chat_block_as_stream_response( + self, + block_result: LLMResult, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[LLMResultChunk, None, None]: + """ + Handle llm chat response + + :param model: model name + :param credentials: credentials + :param response: response + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :param stop: stop words + :return: llm response chunk generator + """ + text = block_result.message.content + text = cast(str, text) + + if stop: + text = self.enforce_stop_tokens(text, stop) + + yield LLMResultChunk( + model=block_result.model, + prompt_messages=prompt_messages, + system_fingerprint=block_result.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=text), + finish_reason="stop", + usage=block_result.usage, + ), + ) - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -619,10 +710,7 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # 
calculate num tokens if response.usage: @@ -648,9 +736,14 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -660,7 +753,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None prompt_tokens = 0 completion_tokens = 0 @@ -670,8 +763,8 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -685,8 +778,11 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -708,7 +804,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -719,12 +815,9 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r final_tool_calls.extend(tool_calls) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content or "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -735,7 +828,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -745,7 +838,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -753,8 +846,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, 
- tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -764,9 +856,9 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -777,21 +869,19 @@ def _extract_response_tool_calls(self, if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -801,14 +891,11 @@ def _extract_response_function_call(self, response_function_call: FunctionCall | tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -821,7 +908,7 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp :param prompt_messages: prompt messages :return: cleaned prompt messages """ - checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09'] + checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"] if model in checklist: # count how many user messages are there @@ -830,11 +917,30 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp for prompt_message in prompt_messages: if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, list): - prompt_message.content = '\n'.join([ - item.data if item.type == PromptMessageContentType.TEXT else - '[IMAGE]' if item.type == PromptMessageContentType.IMAGE else '' - for item in prompt_message.content - ]) + prompt_message.content = "\n".join( + [ + item.data + if item.type == PromptMessageContentType.TEXT + else "[IMAGE]" + if item.type == PromptMessageContentType.IMAGE + else "" + for item in prompt_message.content + ] + ) + + if model.startswith("o1"): + system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) + if system_message_count > 0: + new_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message, SystemPromptMessage): + prompt_message = 
UserPromptMessage( + content=prompt_message.content, + name=prompt_message.name, + ) + + new_prompt_messages.append(prompt_message) + prompt_messages = new_prompt_messages return prompt_messages @@ -843,27 +949,27 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: Convert PromptMessage to dict for OpenAI API """ if isinstance(message, UserPromptMessage): - message = cast(UserPromptMessage, message) if isinstance(message.content, str): message_dict = {"role": "user", "content": message.content} - else: + elif isinstance(message.content, list): sub_messages = [] for message_content in message.content: - if message_content.type == PromptMessageContentType.TEXT: - message_content = cast(TextPromptMessageContent, message_content) + if isinstance(message_content, TextPromptMessageContent): + sub_message_dict = {"type": "text", "text": message_content.data} + sub_messages.append(sub_message_dict) + elif isinstance(message_content, ImagePromptMessageContent): sub_message_dict = { - "type": "text", - "text": message_content.data + "type": "image_url", + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) - elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast(ImagePromptMessageContent, message_content) + elif isinstance(message_content, AudioPromptMessageContent): sub_message_dict = { - "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "type": "input_audio", + "input_audio": { + "data": message_content.data, + "format": message_content.format, + }, } sub_messages.append(sub_message_dict) @@ -889,11 +995,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -902,8 +1004,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: return message_dict - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -924,16 +1025,17 @@ def _num_tokens_from_string(self, model: str, text: str, return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - if model.startswith('ft:'): - model = model.split(':')[1] + if model.startswith("ft:"): + model = model.split(":")[1] # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token. 
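Taken together with the earlier `block_as_stream` fallback in `_chat_generate`, the system-message demotion shown just above is the whole o1 accommodation: no system role, no server-side streaming, no stop words. A rough standalone sketch of the same pattern (the helper names and dict-shaped messages are illustrative, not Dify's API):

```python
# Sketch of the o1 accommodations in this file: demote system messages to
# user messages, strip unsupported request options, and replay a blocking
# result as a single-chunk stream. Names and dict shapes are illustrative.
from collections.abc import Iterator

def prepare_o1_request(model: str, messages: list[dict], params: dict) -> tuple[list[dict], dict, bool]:
    block_as_stream = False
    if model.startswith("o1"):
        # o1 rejects the system role, so resend those turns as user messages
        messages = [{**m, "role": "user"} if m["role"] == "system" else m for m in messages]
        # o1 also rejects stream, stream_options, and stop
        block_as_stream = params.pop("stream", False)
        params.pop("stream_options", None)
        params.pop("stop", None)
    return messages, params, block_as_stream

def block_result_as_stream(text: str) -> Iterator[dict]:
    # emit the whole blocking answer as one terminal chunk
    yield {"index": 0, "delta": text, "finish_reason": "stop"}

msgs, params, block_as_stream = prepare_o1_request(
    "o1-preview",
    [{"role": "system", "content": "Be terse."}, {"role": "user", "content": "ping"}],
    {"stream": True, "stop": ["\n"], "max_tokens": 16},
)
print(msgs[0]["role"], block_as_stream, params)  # user True {'max_tokens': 16}
```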
- if model == "chatgpt-4o-latest": + if model == "chatgpt-4o-latest" or model.startswith("o1"): model = "gpt-4o" try: @@ -948,7 +1050,7 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], tokens_per_message = 4 # if there's a name, the role is omitted tokens_per_name = -1 - elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"): + elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith("o1"): tokens_per_message = 3 tokens_per_name = 1 else: @@ -969,10 +1071,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -1011,37 +1113,37 @@ def _num_tokens_for_tools(self, encoding: tiktoken.Encoding, tools: list[PromptM """ num_tokens = 0 for tool in tools: - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode('function')) + num_tokens += len(encoding.encode("type")) + num_tokens += len(encoding.encode("function")) # calculate num tokens for function object - num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode("name")) num_tokens += len(encoding.encode(tool.name)) - num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode("description")) num_tokens += len(encoding.encode(tool.description)) parameters = tool.parameters - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) + num_tokens += len(encoding.encode("parameters")) + if "title" in parameters: + num_tokens += len(encoding.encode("title")) num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) + num_tokens += len(encoding.encode("type")) num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(encoding.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(encoding.encode(key)) for field_key, field_value in value.items(): num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(encoding.encode(enum_field)) else: num_tokens += len(encoding.encode(field_key)) num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(encoding.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(encoding.encode(required_field)) @@ -1049,26 +1151,26 @@ def _num_tokens_for_tools(self, encoding: tiktoken.Encoding, tools: list[PromptM def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - OpenAI supports fine-tuning of their models. This method returns the schema of the base model - but renamed to the fine-tuned model name. + OpenAI supports fine-tuning of their models. 
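The token arithmetic in `_num_tokens_from_messages` follows the OpenAI cookbook recipe referenced in its docstring; the added `or model.startswith("o1")` branches simply route new model names onto the gpt-4o encoding. A compact sketch of the counting loop for dict-shaped messages, assuming tiktoken is installed (list-valued content and tool schemas are ignored here for brevity):

```python
# Minimal sketch of the message token-counting recipe used above. Falls back
# to cl100k_base for models tiktoken does not know, as the surrounding code does.
import tiktoken

def num_tokens_from_messages(messages: list[dict], model: str = "gpt-4o") -> int:
    if model.startswith("ft:"):
        model = model.split(":")[1]  # ft:gpt-3.5-turbo-0613:org::id -> base model
    if model == "chatgpt-4o-latest" or model.startswith("o1"):
        model = "gpt-4o"  # proxy encoding, mirroring the change above
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    tokens_per_message, tokens_per_name = 3, 1  # gpt-3.5-turbo / gpt-4 family
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(str(value)))
            if key == "name":
                num_tokens += tokens_per_name
    return num_tokens + 3  # every reply is primed with <|start|>assistant<|message|>

print(num_tokens_from_messages([{"role": "user", "content": "ping"}]))
```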
This method returns the schema of the base model + but renamed to the fine-tuned model name. - :param model: model name - :param credentials: credentials + :param model: model name + :param credentials: credentials - :return: model schema + :return: model schema """ - if not model.startswith('ft:'): + if not model.startswith("ft:"): base_model = model else: # get base_model - base_model = model.split(':')[1] + base_model = model.split(":")[1] # get model schema models = self.predefined_models() model_map = {model.model: model for model in models} if base_model not in model_map: - raise ValueError(f'Base model {base_model} not found') - + raise ValueError(f"Base model {base_model} not found") + base_model_schema = model_map[base_model] base_model_schema_features = base_model_schema.features or [] @@ -1077,16 +1179,13 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode entity = AIModelEntity( model=model, - label=I18nObject( - zh_Hans=model, - en_US=model - ), + label=I18nObject(zh_Hans=model, en_US=model), model_type=ModelType.LLM, features=list(base_model_schema_features), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties=dict(base_model_schema_model_properties.items()), parameter_rules=list(base_model_schema_parameters_rules), - pricing=base_model_schema.pricing + pricing=base_model_schema.pricing, ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml new file mode 100644 index 00000000000000..0ade7f8ded9d4b --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml @@ -0,0 +1,33 @@ +model: o1-mini-2024-09-12 +label: + zh_Hans: o1-mini-2024-09-12 + en_US: o1-mini-2024-09-12 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 65536 + min: 1 + max: 65536 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '3.00' + output: '12.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml new file mode 100644 index 00000000000000..60816c5d1e4d93 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml @@ -0,0 +1,33 @@ +model: o1-mini +label: + zh_Hans: o1-mini + en_US: o1-mini +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 65536 + min: 1 + max: 65536 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '3.00' + output: '12.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-preview-2024-09-12.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-preview-2024-09-12.yaml new file mode 100644 index 00000000000000..c9da96f611bc91 --- /dev/null +++
b/api/core/model_runtime/model_providers/openai/llm/o1-preview-2024-09-12.yaml @@ -0,0 +1,33 @@ +model: o1-preview-2024-09-12 +label: + zh_Hans: o1-preview-2024-09-12 + en_US: o1-preview-2024-09-12 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 32768 + min: 1 + max: 32768 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '15.00' + output: '60.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/o1-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/o1-preview.yaml new file mode 100644 index 00000000000000..c83874b76582e8 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/o1-preview.yaml @@ -0,0 +1,33 @@ +model: o1-preview +label: + zh_Hans: o1-preview + en_US: o1-preview +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + default: 32768 + min: 1 + max: 32768 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '15.00' + output: '60.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index b1d0e57ad26920..619044d808cdf6 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -14,9 +14,7 @@ class OpenAIModerationModel(_CommonOpenAI, ModerationModel): Model class for OpenAI text moderation model.
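A note on the pricing blocks in these new YAML files: reading the fields as total cost = tokens × unit × price (an assumption from the schema shape, since the diff never states it), `unit: '0.000001'` makes the quoted prices per-million-token rates. A worked check with the o1-preview numbers:

```python
# Worked example of the pricing convention in the model YAMLs above, under
# the assumption total cost = tokens * unit * price.
from decimal import Decimal

def token_cost(tokens: int, price: str, unit: str = "0.000001") -> Decimal:
    return Decimal(tokens) * Decimal(unit) * Decimal(price)

# o1-preview: input '15.00', output '60.00', unit '0.000001', currency USD
print(token_cost(1_000_000, "15.00"))  # -> 15.00 USD for 1M prompt tokens
print(token_cost(2_000, "60.00"))      # -> 0.12 USD for 2k completion tokens
```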
""" - def _invoke(self, model: str, credentials: dict, - text: str, user: Optional[str] = None) \ - -> bool: + def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool: """ Invoke moderation model @@ -34,10 +32,10 @@ def _invoke(self, model: str, credentials: dict, # chars per chunk length = self._get_max_characters_per_chunk(model, credentials) - text_chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i : i + length] for i in range(0, len(text), length)] max_text_chunks = self._get_max_chunks(model, credentials) - chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] + chunks = [text_chunks[i : i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] for text_chunk in chunks: moderation_result = self._moderation_invoke(model=model, client=client, texts=text_chunk) @@ -65,7 +63,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._moderation_invoke( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openai/openai.py b/api/core/model_runtime/model_providers/openai/openai.py index 66efd4797f621a..aa6f38ce9fae5a 100644 --- a/api/core/model_runtime/model_providers/openai/openai.py +++ b/api/core/model_runtime/model_providers/openai/openai.py @@ -9,7 +9,6 @@ class OpenAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: Mapping) -> None: """ Validate provider credentials @@ -20,14 +19,11 @@ def validate_provider_credentials(self, credentials: Mapping) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - # Use `gpt-3.5-turbo` model for validate, + # Use `gpt-4o-mini` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="gpt-4o-mini", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index efbdd054f9457c..ecf455c6fdb7f8 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -2,6 +2,8 @@ from openai import OpenAI +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI @@ -12,9 +14,7 @@ class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): Model class for OpenAI Speech to text model. 
""" - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -37,7 +37,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -60,3 +60,18 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> response = client.audio.transcriptions.create(model=model, file=file) return response.text + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.SPEECH2TEXT, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index e23a2edf87aefb..bec01fe6797f52 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -6,6 +6,7 @@ import tiktoken from openai import OpenAI +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -18,9 +19,14 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): Model class for OpenAI text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -28,6 +34,7 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ # transform credentials to kwargs for model instance @@ -37,9 +44,9 @@ def _invoke(self, model: str, credentials: dict, extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'base64' + extra_model_kwargs["encoding_format"] = "base64" # get model properties context_size = self._get_context_size(model, credentials) @@ -56,11 +63,9 @@ def _invoke(self, model: str, credentials: dict, enc = tiktoken.get_encoding("cl100k_base") for i, text in enumerate(texts): - token = enc.encode( - text - ) + token = enc.encode(text) for j in range(0, len(token), context_size): - tokens += [token[j: j + context_size]] + tokens += [token[j : j + context_size]] indices += [i] batched_embeddings = [] @@ -69,10 +74,7 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # call embedding model embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts=tokens[i: i + max_chunks], - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -88,10 +90,7 @@ def _invoke(self, model: str, credentials: dict, _result = results[i] if len(_result) == 0: embeddings_batch, embedding_used_tokens = self._embedding_invoke( - model=model, - client=client, - texts="", - extra_model_kwargs=extra_model_kwargs + model=model, client=client, texts="", extra_model_kwargs=extra_model_kwargs ) used_tokens += embedding_used_tokens @@ -101,17 +100,9 @@ def _invoke(self, model: str, credentials: dict, embeddings[i] = (average / np.linalg.norm(average)).tolist() # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - return TextEmbeddingResult( - embeddings=embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -152,17 +143,13 @@ def validate_credentials(self, model: str, credentials: dict) -> None: client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model @@ -179,10 +166,12 @@ def 
_embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], **extra_model_kwargs, ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": # decode base64 embedding - return ([list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], - response.usage.total_tokens) + return ( + [list(np.frombuffer(base64.b64decode(data.embedding), dtype="float32")) for data in response.data], + response.usage.total_tokens, + ) return [data.embedding for data in response.data], response.usage.total_tokens @@ -197,10 +186,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -211,7 +197,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index afa5d4b88adecb..2e57b95944c6b4 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -1,5 +1,5 @@ import concurrent.futures -from typing import Optional +from typing import Any, Optional from openai import OpenAI @@ -14,8 +14,9 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): Model class for OpenAI Speech to text model. 
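Requesting `encoding_format: base64` above shrinks the embedding payload: each vector arrives as a base64 string of packed float32 values, which the code decodes with `np.frombuffer`. A self-contained round-trip of that decode on synthetic data (in the real path the encoded string comes from `response.data[i].embedding`):

```python
# Round-trip demonstrating the base64 float32 decode used above.
import base64

import numpy as np

vector = np.array([0.1, -0.2, 0.3], dtype="float32")
encoded = base64.b64encode(vector.tobytes()).decode("ascii")  # what the API would return

decoded = list(np.frombuffer(base64.b64decode(encoded), dtype="float32"))
assert np.allclose(decoded, vector)
print(decoded)
```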
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> Any: """ _invoke text2speech model @@ -28,14 +29,12 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) # if streaming: - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -50,14 +49,13 @@ def validate_credentials(self, model: str, credentials: dict, user: Optional[str self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: """ _tts_invoke_streaming text2speech model @@ -71,31 +69,38 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str # doc: https://platform.openai.com/docs/guides/text-to-speech credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, - response_format="mp3", - input=sentences[i], voice=voice) for i in range(len(sentences))] - for index, future in enumerate(futures): - yield from future.result().__enter__().iter_bytes(1024) + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] + for future in futures: + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: - response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, - response_format="mp3", - input=content_text.strip()) + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) - yield from response.__enter__().iter_bytes(1024) + yield from 
response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, - voice, credentials: dict): + def _process_sentence(self, sentence: str, model: str, voice, credentials: dict): """ _tts_invoke openai text2speech model api diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 51950ca3778424..1234e44f80c40c 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,4 +1,3 @@ - import requests from core.model_runtime.errors.invoke import ( @@ -11,7 +10,7 @@ ) -class _CommonOAI_API_Compat: +class _CommonOaiApiCompat: @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -35,10 +34,10 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ], InvokeServerUnavailableError: [ requests.exceptions.ConnectionError, # Engine Overloaded - requests.exceptions.HTTPError # Server Error + requests.exceptions.HTTPError, # Server Error ], InvokeConnectionError: [ requests.exceptions.ConnectTimeout, # Timeout - requests.exceptions.ReadTimeout # Timeout - ] - } \ No newline at end of file + requests.exceptions.ReadTimeout, # Timeout + ], + } diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index e5cc884b6de315..e1342fe985ac13 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -35,22 +35,28 @@ from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat from core.model_runtime.utils import helper logger = logging.getLogger(__name__) -class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): +class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): """ Model class for OpenAI large language model. 
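The restructured `_tts_invoke_streaming` fans a long input out to a thread pool, one speech request per sentence, then consumes the futures in submission order so the audio comes back in text order even if later sentences finish first. The ordering trick on its own, with a dummy synthesizer standing in for `client.audio.speech.with_streaming_response.create`:

```python
# Sketch of the ordered fan-out used by _tts_invoke_streaming above: submit
# one job per sentence, then consume futures in submission order so audio
# chunks are yielded in the original text order. fake_synth is a stand-in.
import concurrent.futures
from collections.abc import Iterator

def fake_synth(sentence: str) -> bytes:
    return sentence.encode("utf-8")  # pretend this is an mp3 payload

def stream_tts(sentences: list[str]) -> Iterator[bytes]:
    with concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) as pool:
        futures = [pool.submit(fake_synth, s) for s in sentences]
        for future in futures:  # in order, regardless of completion order
            yield future.result()

print(b" ".join(stream_tts(["Hello", "Dify!"])))
```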
""" - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -77,8 +83,13 @@ def _invoke(self, model: str, credentials: dict, user=user, ) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -92,100 +103,93 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard. + Validate model credentials using requests to ensure compatibility with all providers following + OpenAI's API standard. :param model: model name :param credentials: model credentials :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials['endpoint_url'] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials["endpoint_url"] + if not endpoint_url.endswith("/"): + endpoint_url += "/" # prepare the payload for a simple ping to the model - data = { - 'model': model, - 'max_tokens': 5 - } + data = {"model": model, "max_tokens": 5} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - data['messages'] = [ - { - "role": "user", - "content": "ping" - }, + data["messages"] = [ + {"role": "user", "content": "ping"}, ] - endpoint_url = urljoin(endpoint_url, 'chat/completions') + endpoint_url = urljoin(endpoint_url, "chat/completions") elif completion_type is LLMMode.COMPLETION: - data['prompt'] = 'ping' - endpoint_url = urljoin(endpoint_url, 'completions') + data["prompt"] = "ping" + endpoint_url = urljoin(endpoint_url, "completions") else: raise ValueError("Unsupported completion type for model configuration.") # send a post request to validate the credentials - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation 
failed: JSON decode error") - if (completion_type is LLMMode.CHAT and json_result['object'] == ''): - json_result['object'] = 'chat.completion' - elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''): - json_result['object'] = 'text_completion' + if completion_type is LLMMode.CHAT and json_result.get("object", "") == "": + json_result["object"] = "chat.completion" + elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "": + json_result["object"] = "text_completion" - if (completion_type is LLMMode.CHAT - and ('object' not in json_result or json_result['object'] != 'chat.completion')): + if completion_type is LLMMode.CHAT and ( + "object" not in json_result or json_result["object"] != "chat.completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'chat.completion\'') - elif (completion_type is LLMMode.COMPLETION - and ('object' not in json_result or json_result['object'] != 'text_completion')): + "Credentials validation failed: invalid response object, must be 'chat.completion'" + ) + elif completion_type is LLMMode.COMPLETION and ( + "object" not in json_result or json_result["object"] != "text_completion" + ): raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response object, must be \'text_completion\'') + "Credentials validation failed: invalid response object, must be 'text_completion'" + ) except CredentialsValidateFailedError: raise except Exception as ex: - raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}") def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ features = [] - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type in ['function_call']: + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "function_call": features.append(ModelFeature.TOOL_CALL) - elif function_calling_type in ['tool_call']: + elif function_calling_type == "tool_call": features.append(ModelFeature.MULTI_TOOL_CALL) - stream_function_calling = credentials.get('stream_function_calling', 'supported') - if stream_function_calling == 'supported': + stream_function_calling = credentials.get("stream_function_calling", "supported") + if stream_function_calling == "supported": features.append(ModelFeature.STREAM_TOOL_CALL) - vision_support = credentials.get('vision_support', 'not_support') - if vision_support == 'support': + vision_support = credentials.get("vision_support", "not_support") + if vision_support == "support": features.append(ModelFeature.VISION) entity = AIModelEntity( @@ -195,75 +199,108 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), - ModelPropertyKey.MODE: credentials.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")), + ModelPropertyKey.MODE: credentials.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, - label=I18nObject(en_US="Temperature"), + label=I18nObject(en_US="Temperature", 
zh_Hans="温度"), + help=I18nObject( + en_US="Kernel sampling threshold. Used to determine the randomness of the results." + "The higher the value, the stronger the randomness." + "The higher the possibility of getting different answers to the same question.", + zh_Hans="核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。", + ), type=ParameterType.FLOAT, - default=float(credentials.get('temperature', 0.7)), + default=float(credentials.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, - label=I18nObject(en_US="Top P"), + label=I18nObject(en_US="Top P", zh_Hans="Top P"), + help=I18nObject( + en_US="The probability threshold of the nucleus sampling method during the generation process." + "The larger the value is, the higher the randomness of generation will be." + "The smaller the value is, the higher the certainty of generation will be.", + zh_Hans="生成过程中核采样方法概率阈值。取值越大,生成的随机性越高;取值越小,生成的确定性越高。", + ), type=ParameterType.FLOAT, - default=float(credentials.get('top_p', 1)), + default=float(credentials.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, - label=I18nObject(en_US="Frequency Penalty"), + label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"), + help=I18nObject( + en_US="For controlling the repetition rate of words used by the model." + "Increasing this can reduce the repetition of the same words in the model's output.", + zh_Hans="用于控制模型已使用字词的重复率。 提高此项可以降低模型在输出中重复相同字词的重复度。", + ), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY.value, - label=I18nObject(en_US="Presence Penalty"), + label=I18nObject(en_US="Presence Penalty", zh_Hans="存在惩罚"), + help=I18nObject( + en_US="Used to control the repetition rate when generating models." 
+ "Increasing this can reduce the repetition rate of model generation.", + zh_Hans="用于控制模型生成时的重复度。提高此项可以降低模型生成的重复度。", + ), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, - label=I18nObject(en_US="Max Tokens"), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), + help=I18nObject( + en_US="Maximum length of tokens for the model response.", zh_Hans="模型回答的tokens的最大长度。" + ), type=ParameterType.INT, default=512, min=1, - max=int(credentials.get('max_tokens_to_sample', 4096)), - ) + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), ], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - output=Decimal(credentials.get('output_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), ), ) - if credentials['mode'] == 'chat': + if credentials["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif credentials['mode'] == 'completion': + elif credentials["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {credentials['completion_type']}") return entity - # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, \ - user: Optional[str] = None) -> Union[LLMResult, Generator]: + # validate_credentials method has been rewritten to use the requests library for compatibility with all providers + # following OpenAI's API standard. 
+ def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -277,52 +314,47 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM :return: full response or stream response chunk generator result """ headers = { - 'Content-Type': 'application/json', - 'Accept-Charset': 'utf-8', + "Content-Type": "application/json", + "Accept-Charset": "utf-8", } - extra_headers = credentials.get('extra_headers') + extra_headers = credentials.get("extra_headers") if extra_headers is not None: headers = { - **headers, - **extra_headers, + **headers, + **extra_headers, } - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" endpoint_url = credentials["endpoint_url"] - if not endpoint_url.endswith('/'): - endpoint_url += '/' + if not endpoint_url.endswith("/"): + endpoint_url += "/" - data = { - "model": model, - "stream": stream, - **model_parameters - } + data = {"model": model, "stream": stream, **model_parameters} - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) if completion_type is LLMMode.CHAT: - endpoint_url = urljoin(endpoint_url, 'chat/completions') - data['messages'] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] + endpoint_url = urljoin(endpoint_url, "chat/completions") + data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages] elif completion_type is LLMMode.COMPLETION: - endpoint_url = urljoin(endpoint_url, 'completions') - data['prompt'] = prompt_messages[0].content + endpoint_url = urljoin(endpoint_url, "completions") + data["prompt"] = prompt_messages[0].content else: raise ValueError("Unsupported completion type for model configuration.") # annotate tools with names, descriptions, etc. 
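The hunk below branches on `function_calling_type` when serializing tools into the request body. For reference, the two wire formats differ roughly as sketched here; the weather tool is a made-up example, and the wrapped `tools` entry shape follows the usual OpenAI convention rather than anything shown verbatim in this diff:

```python
# A made-up tool definition for illustration
tool = {
    "name": "get_weather",
    "description": "Look up the weather for a city",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}

# function_calling_type == "function_call": legacy single-function style
legacy_payload = {"functions": [tool]}

# function_calling_type == "tool_call": multi-tool style; entries are assumed to be
# wrapped per the OpenAI convention (the wrapping code is elided in this diff)
tool_payload = {"tool_choice": "auto", "tools": [{"type": "function", "function": tool}]}
```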
- function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") formatted_tools = [] if tools: - if function_calling_type == 'function_call': - data['functions'] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] - elif function_calling_type == 'tool_call': + if function_calling_type == "function_call": + data["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} + for tool in tools + ] + elif function_calling_type == "tool_call": data["tool_choice"] = "auto" for tool in tools: @@ -336,16 +368,10 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM if user: data["user"] = user - response = requests.post( - endpoint_url, - headers=headers, - json=data, - timeout=(10, 300), - stream=stream - ) + response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream) - if response.encoding is None or response.encoding == 'ISO-8859-1': - response.encoding = 'utf-8' + if response.encoding is None or response.encoding == "ISO-8859-1": + response.encoding = "utf-8" if response.status_code != 200: raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}") @@ -355,8 +381,9 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -366,31 +393,33 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict + ) -> LLMResultChunk: # calculate num tokens - prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) - completion_tokens = self._num_tokens_from_string(model, full_assistant_content) + prompt_tokens = usage and usage.get("prompt_tokens") + if prompt_tokens is None: + prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) + completion_tokens = usage and usage.get("completion_tokens") + if completion_tokens is None: + completion_tokens = self._num_tokens_from_string(model, full_assistant_content) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) return LLMResultChunk( + id=id, model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) # delimiter for stream response, need unicode_escape import codecs + delimiter = credentials.get("stream_mode_delimiter", "\n\n") delimiter = codecs.decode(delimiter, 
"unicode_escape") @@ -406,10 +435,7 @@ def get_tool_call(tool_call_id: str): tool_call = AssistantPromptMessage.ToolCall( id=tool_call_id, type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name="", - arguments="" - ) + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), ) tools_calls.append(tool_call) @@ -428,47 +454,56 @@ def get_tool_call(tool_call_id: str): if new_tool_call.function.arguments: tool_call.function.arguments += new_tool_call.function.arguments - finish_reason = 'Unknown' - + finish_reason = None # The default value of finish_reason is None + message_id, usage = None, None for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): chunk = chunk.strip() if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): + continue + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() + if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]" continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() try: - chunk_json = json.loads(decoded_chunk) + chunk_json: dict = json.loads(decoded_chunk) # stream ended except json.JSONDecodeError as e: yield create_final_llm_result_chunk( + id=message_id, index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." + finish_reason="Non-JSON encountered.", + usage=usage, ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if chunk_json: + if u := chunk_json.get("usage"): + usage = u + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") + message_id = chunk_json.get("id") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") assistant_message_tool_calls = None - if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call': - assistant_message_tool_calls = delta.get('tool_calls', None) - elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call': - assistant_message_tool_calls = [{ - 'id': 'tool_call_id', - 'type': 'function', - 'function': delta.get('function_call', {}) - }] + if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call": + assistant_message_tool_calls = delta.get("tool_calls", None) + elif ( + "function_call" in delta + and credentials.get("function_calling_type", "no_call") == "function_call" + ): + assistant_message_tool_calls = [ + {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})} + ] # assistant_message_function_call = delta.delta.function_call @@ -477,7 +512,7 @@ def get_tool_call(tool_call_id: str): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message @@ -488,9 +523,9 @@ def get_tool_call(tool_call_id: str): # reset tool calls tool_calls = [] full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if 
choice_text == "": continue # transform assistant message to prompt message @@ -500,63 +535,65 @@ def get_tool_call(tool_call_id: str): continue yield LLMResultChunk( + id=message_id, model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 if tools_calls: yield LLMResultChunk( + id=message_id, model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( + id=message_id, index=chunk_index, message=AssistantPromptMessage(content=""), - finish_reason=finish_reason + finish_reason=finish_reason, + usage=usage, ) - def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> LLMResult: - - response_json = response.json() + def _handle_generate_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> LLMResult: + response_json: dict = response.json() - completion_type = LLMMode.value_of(credentials['mode']) + completion_type = LLMMode.value_of(credentials["mode"]) - output = response_json['choices'][0] + output = response_json["choices"][0] + message_id = response_json.get("id") - response_content = '' + response_content = "" tool_calls = None - function_calling_type = credentials.get('function_calling_type', 'no_call') + function_calling_type = credentials.get("function_calling_type", "no_call") if completion_type is LLMMode.CHAT: - response_content = output.get('message', {})['content'] - if function_calling_type == 'tool_call': - tool_calls = output.get('message', {}).get('tool_calls') - elif function_calling_type == 'function_call': - tool_calls = output.get('message', {}).get('function_call') + response_content = output.get("message", {})["content"] + if function_calling_type == "tool_call": + tool_calls = output.get("message", {}).get("tool_calls") + elif function_calling_type == "function_call": + tool_calls = output.get("message", {}).get("function_call") elif completion_type is LLMMode.COMPLETION: - response_content = output['text'] + response_content = output["text"] assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) if tool_calls: - if function_calling_type == 'tool_call': + if function_calling_type == "tool_call": assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) - elif function_calling_type == 'function_call': + elif function_calling_type == "function_call": assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] usage = response_json.get("usage") @@ -574,6 +611,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: req # transform response result = LLMResult( + id=message_id, model=response_json["model"], prompt_messages=prompt_messages, message=assistant_message, @@ -595,19 +633,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif 
message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) @@ -616,11 +648,10 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict["tool_calls"] = [tool_call.dict() for tool_call in - message.tool_calls] - elif function_calling_type == 'function_call': + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls] + elif function_calling_type == "function_call": function_call = message.tool_calls[0] message_dict["function_call"] = { "name": function_call.function.name, @@ -631,29 +662,22 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: O message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - function_calling_type = credentials.get('function_calling_type', 'no_call') - if function_calling_type == 'tool_call': - message_dict = { - "role": "tool", - "content": message.content, - "tool_call_id": message.tool_call_id - } - elif function_calling_type == 'function_call': - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + function_calling_type = credentials.get("function_calling_type", "no_call") + if function_calling_type == "tool_call": + message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} + elif function_calling_type == "function_call": + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") - if message.name: + if message.name and message_dict.get("role", "") != "tool": message_dict["name"] = message.name return message_dict - def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string( + self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Approximate num tokens for model with gpt2 tokenizer. 
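The `_handle_generate_stream_response` hunk above splits the raw HTTP body on a provider-configurable delimiter (decoded via `unicode_escape`) and skips SSE comments and the `data: [DONE]` sentinel before JSON-decoding each chunk. A minimal sketch of that parsing loop, assuming a `requests.Response` and an illustrative helper name:

```python
import codecs
import json


def iter_stream_chunks(response, delimiter: str = "\\n\\n"):
    """Yield parsed JSON chunks from an OpenAI-compatible stream (illustrative helper)."""
    # providers may override the delimiter; it arrives escaped, hence unicode_escape
    delimiter = codecs.decode(delimiter, "unicode_escape")
    for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
        chunk = chunk.strip()
        if not chunk or chunk.startswith(":"):  # skip keep-alives and SSE comments
            continue
        decoded = chunk.removeprefix("data:").strip()
        if decoded == "[DONE]":  # some providers send "data: [DONE]"
            return
        try:
            yield json.loads(decoded)
        except json.JSONDecodeError:
            return  # a non-JSON chunk ends the stream
```

Note that `str.removeprefix` sidesteps a subtle pitfall of the `lstrip("data: ")` call in the hunk: `lstrip` strips any run of the characters `d`, `a`, `t`, `:`, and space, so a payload that happened to begin with one of those characters could be over-stripped.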
@@ -665,7 +689,7 @@ def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessag if isinstance(text, str): full_text = text else: - full_text = '' + full_text = "" for message_content in text: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) @@ -678,8 +702,13 @@ def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessag return num_tokens - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None, credentials: dict = None) -> int: + def _num_tokens_from_messages( + self, + model: str, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + credentials: Optional[dict] = None, + ) -> int: """ Approximate num tokens with GPT2 tokenizer. """ @@ -698,10 +727,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -739,46 +768,44 @@ def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int: """ num_tokens = 0 for tool in tools: - num_tokens += self._get_num_tokens_by_gpt2('type') - num_tokens += self._get_num_tokens_by_gpt2('function') - num_tokens += self._get_num_tokens_by_gpt2('function') + num_tokens += self._get_num_tokens_by_gpt2("type") + num_tokens += self._get_num_tokens_by_gpt2("function") + num_tokens += self._get_num_tokens_by_gpt2("function") # calculate num tokens for function object - num_tokens += self._get_num_tokens_by_gpt2('name') + num_tokens += self._get_num_tokens_by_gpt2("name") num_tokens += self._get_num_tokens_by_gpt2(tool.name) - num_tokens += self._get_num_tokens_by_gpt2('description') + num_tokens += self._get_num_tokens_by_gpt2("description") num_tokens += self._get_num_tokens_by_gpt2(tool.description) parameters = tool.parameters - num_tokens += self._get_num_tokens_by_gpt2('parameters') - if 'title' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('title') + num_tokens += self._get_num_tokens_by_gpt2("parameters") + if "title" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("title") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title")) - num_tokens += self._get_num_tokens_by_gpt2('type') + num_tokens += self._get_num_tokens_by_gpt2("type") num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type")) - if 'properties' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("properties") + for key, value in parameters.get("properties").items(): num_tokens += self._get_num_tokens_by_gpt2(key) for field_key, field_value in value.items(): num_tokens += self._get_num_tokens_by_gpt2(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(enum_field) else: num_tokens += self._get_num_tokens_by_gpt2(field_key) num_tokens += self._get_num_tokens_by_gpt2(str(field_value)) - if 'required' in parameters: - num_tokens += self._get_num_tokens_by_gpt2('required') - for required_field 
in parameters['required']: + if "required" in parameters: + num_tokens += self._get_num_tokens_by_gpt2("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += self._get_num_tokens_by_gpt2(required_field) return num_tokens - def _extract_response_tool_calls(self, - response_tool_calls: list[dict]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -790,20 +817,17 @@ def _extract_response_tool_calls(self, for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( name=response_tool_call.get("function", {}).get("name", ""), - arguments=response_tool_call.get("function", {}).get("arguments", "") + arguments=response_tool_call.get("function", {}).get("arguments", ""), ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.get("id", ""), - type=response_tool_call.get("type", ""), - function=function + id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -813,14 +837,11 @@ def _extract_response_function_call(self, response_function_call) \ tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.get('name', ''), - arguments=response_function_call.get('arguments', '') + name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "") ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.get('id', ''), - type="function", - function=function + id=response_function_call.get("id", ""), type="function", function=function ) return tool_call diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py index 3445ebbaf752b3..ca6f1852872fed 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py @@ -6,6 +6,5 @@ class OAICompatProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index 88c76fe16ef733..69a081f35c10a6 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -8,6 +8,7 @@ supported_model_types: - llm - text-embedding - speech2text + - rerank configurate_methods: - customizable-model model_credential_schema: @@ -83,6 +84,19 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的模型上下文长度 en_US: Enter your Model context size + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + show_on: + - variable: __model_type + value: rerank + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + 
en_US: Enter your Model context size - variable: max_tokens_to_sample label: zh_Hans: 最大 token 上限 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/rerank/__init__.py b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py new file mode 100644 index 00000000000000..508da4bf209210 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py @@ -0,0 +1,159 @@ +from json import dumps +from typing import Optional + +import httpx +from requests import post +from yarl import URL + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class OAICompatRerankModel(RerankModel): + """ + This rerank model API is compatible with the Jina rerank API, so the JinaRerankModel code is reused here. + It is extended for llama.cpp, which returns raw scores rather than the 0~1 normalized scores Dify expects. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + server_url = credentials["endpoint_url"] + model_name = model + + if not server_url: + raise CredentialsValidateFailedError("server_url is required") + if not model_name: + raise CredentialsValidateFailedError("model_name is required") + + url = server_url + headers = {"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"} + + # TODO: Do we need to truncate docs to avoid llama.cpp returning an error?
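As the class docstring notes, llama.cpp-style backends return raw relevance scores, so the loop a few lines below min-max normalizes them into the 0~1 range before applying `score_threshold`. A standalone sketch of that normalization; the `results` payload is invented for illustration:

```python
# Invented response payload, mimicking the "results" list returned by a rerank endpoint
results = [{"index": 0, "relevance_score": 7.2}, {"index": 1, "relevance_score": -1.5}]

scores = [r["relevance_score"] for r in results]
lo, hi = min(scores), max(scores)
span = (hi - lo) or 1.0  # avoid division by zero when all scores are identical
normalized = [(s - lo) / span for s in scores]
# -> [1.0, 0.0]: raw scores are mapped into the 0~1 range the threshold expects
```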
+ + data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n} + + try: + response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=60) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + scores = [result["relevance_score"] for result in results["results"]] + + # Min-Max Normalization: Normalize scores to 0 ~ 1.0 range + min_score = min(scores) + max_score = max(scores) + score_range = max_score - min_score if max_score != min_score else 1.0 # Avoid division by zero + + for result in results["results"]: + index = result["index"] + + # Retrieve document text (fallback if llama.cpp rerank doesn't return it) + text = result.get("document", {}).get("text", docs[index]) + + # Normalize the score + normalized_score = (result["relevance_score"] - min_score) / score_range + + # Create RerankDocument object with normalized score + rerank_document = RerankDocument( + index=index, + text=text, + score=normalized_score, + ) + + # Apply threshold (if defined) + if score_threshold is None or normalized_score >= score_threshold: + rerank_documents.append(rerank_document) + + # Sort rerank_documents by normalized score in descending order + rerank_documents.sort(key=lambda doc: doc.score, reverse=True) + + return RerankResult(model=model, docs=rerank_documents) + + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index 00702ba9367cf4..a490537e51a6ad 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -3,20 +3,20 @@ import requests +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): +class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel): """ Model class for OpenAI Compatible Speech to text model. 
""" - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -61,3 +61,18 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.SPEECH2TEXT, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 363054b084a69c..c2b7297aac596e 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -7,6 +7,7 @@ import numpy as np import requests +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -19,17 +20,22 @@ from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,29 +43,28 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - + # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -70,7 +75,6 @@ def _invoke(self, model: str, credentials: dict, used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -78,7 +82,7 @@ def _invoke(self, model: str, credentials: dict, if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -88,42 +92,25 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return 
TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -145,45 +132,35 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -191,7 +168,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -199,20 +176,19 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -224,10 +200,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -238,7 +211,7 @@ def _calc_response_usage(self, model: str, 
credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index 8ea5819bde1167..e5ecc748fb1b61 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from typing import Optional from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -38,88 +39,115 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials for Baichuan model """ - if not credentials.get('server_url'): - raise CredentialsValidateFailedError('Invalid server URL') + if not credentials.get("server_url"): + raise CredentialsValidateFailedError("Invalid server URL") # ping instance = OpenLLMGenerate() try: instance.generate( - server_url=credentials['server_url'], - model_name=model, - prompt_messages=[ - OpenLLMGenerateMessage(content='ping\nAnswer: ', role='user') - ], + server_url=credentials["server_url"], + model_name=model, + prompt_messages=[OpenLLMGenerateMessage(content="ping\nAnswer: ", role="user")], model_parameters={ - 'max_tokens': 64, - 'temperature': 0.8, - 'top_p': 0.9, - 'top_k': 15, + "max_tokens": 64, + "temperature": 0.8, + "top_p": 0.9, + "top_k": 15, }, stream=False, - user='', + user="", stop=[], ) except InvalidAuthenticationError as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: return self._num_tokens_from_messages(prompt_messages, tools) def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ - Calculate num tokens for OpenLLM model - it's a generate model, so we just join them by spe + Calculate num tokens for OpenLLM model + it's a generate model, so we just join them by spe """ - messages = ','.join([message.content for message in messages]) + messages = ",".join([message.content for message in messages]) return self._get_num_tokens_by_gpt2(messages) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] 
| None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: client = OpenLLMGenerate() response = client.generate( model_name=model, - server_url=credentials['server_url'], + server_url=credentials["server_url"], prompt_messages=[self._convert_prompt_message_to_openllm_message(message) for message in prompt_messages], model_parameters=model_parameters, stop=stop, stream=stream, - user=user + user=user, ) if stream: - return self._handle_chat_generate_stream_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) - return self._handle_chat_generate_response(model=model, prompt_messages=prompt_messages, credentials=credentials, response=response) + return self._handle_chat_generate_stream_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) + return self._handle_chat_generate_response( + model=model, prompt_messages=prompt_messages, credentials=credentials, response=response + ) def _convert_prompt_message_to_openllm_message(self, prompt_message: PromptMessage) -> OpenLLMGenerateMessage: """ - convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface + convert PromptMessage to OpenLLMGenerateMessage so that we can use OpenLLMGenerateMessage interface """ if isinstance(prompt_message, UserPromptMessage): return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): - return OpenLLMGenerateMessage(role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content) + return OpenLLMGenerateMessage( + role=OpenLLMGenerateMessage.Role.ASSISTANT.value, content=prompt_message.content + ) else: - raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') + raise NotImplementedError(f"Prompt message type {type(prompt_message)} is not supported") - def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage) -> LLMResult: - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=response.usage['prompt_tokens'], - completion_tokens=response.usage['completion_tokens'] - ) + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage + ) -> LLMResult: + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, @@ -130,27 +158,29 @@ def _handle_chat_generate_response(self, model: str, prompt_messages: list[Promp usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], - credentials: dict, response: Generator[OpenLLMGenerateMessage, None, None]) \ - -> Generator[LLMResultChunk, None, None]: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: 
Generator[OpenLLMGenerateMessage, None, None], + ) -> Generator[LLMResultChunk, None, None]: for message in response: if message.usage: usage = self._calc_response_usage( - model=model, credentials=credentials, - prompt_tokens=message.usage['prompt_tokens'], - completion_tokens=message.usage['completion_tokens'] + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason or None, ), ) else: @@ -159,73 +189,55 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages: lis prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content=message.content, tool_calls=[]), + finish_reason=message.stop_reason or None, ), ) - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', + name="top_k", type=ParameterType.INT, - use_template='top_k', + use_template="top_k", min=1, default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + label=I18nObject(zh_Hans="Top K", en_US="Top K"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ + model_properties={ ModelPropertyKey.MODE: LLMMode.COMPLETION.value, }, - parameter_rules=rules + parameter_rules=rules, ) return entity @@ -241,22 +253,13 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], InvokeAuthorizationError: [ InvalidAuthenticationError, InsufficientAccountBalanceError, InvalidAPIKeyError, ], - 
InvokeBadRequestError: [ - BadRequestError, - KeyError - ] + InvokeBadRequestError: [BadRequestError, KeyError], } - diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 1c3f084207ff79..351dcced153750 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -15,32 +15,38 @@ class OpenLLMGenerateMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' + USER = "user" + ASSISTANT = "assistant" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - - def __init__(self, content: str, role: str = 'user') -> None: + + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role class OpenLLMGenerate: def generate( - self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], - stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str, + self, + server_url: str, + model_name: str, + stream: bool, + model_parameters: dict[str, Any], + stop: list[str], + prompt_messages: list[OpenLLMGenerateMessage], + user: str, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: - raise InvalidAuthenticationError('Invalid server URL') + raise InvalidAuthenticationError("Invalid server URL") default_llm_config = { "max_new_tokens": 128, @@ -72,40 +78,37 @@ def generate( "frequency_penalty": 0, "use_beam_search": False, "ignore_eos": False, - "skip_special_tokens": True + "skip_special_tokens": True, } - if 'max_tokens' in model_parameters and type(model_parameters['max_tokens']) == int: - default_llm_config['max_new_tokens'] = model_parameters['max_tokens'] + if "max_tokens" in model_parameters and type(model_parameters["max_tokens"]) == int: + default_llm_config["max_new_tokens"] = model_parameters["max_tokens"] - if 'temperature' in model_parameters and type(model_parameters['temperature']) == float: - default_llm_config['temperature'] = model_parameters['temperature'] + if "temperature" in model_parameters and type(model_parameters["temperature"]) == float: + default_llm_config["temperature"] = model_parameters["temperature"] - if 'top_p' in model_parameters and type(model_parameters['top_p']) == float: - default_llm_config['top_p'] = model_parameters['top_p'] + if "top_p" in model_parameters and type(model_parameters["top_p"]) == float: + default_llm_config["top_p"] = model_parameters["top_p"] - if 'top_k' in model_parameters and type(model_parameters['top_k']) == int: - default_llm_config['top_k'] = model_parameters['top_k'] + if "top_k" in model_parameters and type(model_parameters["top_k"]) == int: + default_llm_config["top_k"] = model_parameters["top_k"] - if 'use_cache' in model_parameters and type(model_parameters['use_cache']) == bool: - default_llm_config['use_cache'] = model_parameters['use_cache'] + if "use_cache" in model_parameters and type(model_parameters["use_cache"]) == bool: + default_llm_config["use_cache"] = model_parameters["use_cache"] - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + headers = {"Content-Type": "application/json", "accept": "application/json"} if stream: - url = 
f'{server_url}/v1/generate_stream' + url = f"{server_url}/v1/generate_stream" timeout = 10 else: - url = f'{server_url}/v1/generate' + url = f"{server_url}/v1/generate" timeout = 120 data = { - 'stop': stop if stop else [], - 'prompt': '\n'.join([message.content for message in prompt_messages]), - 'llm_config': default_llm_config, + "stop": stop or [], + "prompt": "\n".join([message.content for message in prompt_messages]), + "llm_config": default_llm_config, } try: @@ -113,10 +116,10 @@ def generate( except (ConnectionError, InvalidSchema, MissingSchema) as e: # could not connect to the server raise InvalidAuthenticationError(f"Invalid server URL: {e}") - + if not response.ok: resp = response.json() - msg = resp['msg'] + msg = resp["msg"] if response.status_code == 400: raise BadRequestError(msg) elif response.status_code == 404: @@ -125,69 +128,71 @@ def generate( raise InternalServerError(msg) else: raise InternalServerError(msg) - + if stream: return self._handle_chat_stream_generate_response(response) return self._handle_chat_generate_response(response) - + def _handle_chat_generate_response(self, response: Response) -> OpenLLMGenerateMessage: try: data = response.json() except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}") - message = data['outputs'][0] - text = message['text'] - token_ids = message['token_ids'] - prompt_token_ids = data['prompt_token_ids'] - stop_reason = message['finish_reason'] + message = data["outputs"][0] + text = message["text"] + token_ids = message["token_ids"] + prompt_token_ids = data["prompt_token_ids"] + stop_reason = message["finish_reason"] message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) message.stop_reason = stop_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': len(token_ids), - 'total_tokens': len(prompt_token_ids) + len(token_ids), + "prompt_tokens": len(prompt_token_ids), + "completion_tokens": len(token_ids), + "total_tokens": len(prompt_token_ids) + len(token_ids), } return message - def _handle_chat_stream_generate_response(self, response: Response) -> Generator[OpenLLMGenerateMessage, None, None]: + def _handle_chat_stream_generate_response( + self, response: Response + ) -> Generator[OpenLLMGenerateMessage, None, None]: completion_usage = 0 for line in response.iter_lines(): if not line: continue - line: str = line.decode('utf-8') - if line.startswith('data: '): + line: str = line.decode("utf-8") + if line.startswith("data: "): line = line[6:].strip() - if line == '[DONE]': + if line == "[DONE]": return try: data = loads(line) except Exception as e: raise InternalServerError(f"Failed to convert response to json: {e} with text: {line}") - - output = data['outputs'] + + output = data["outputs"] for choice in output: - text = choice['text'] - token_ids = choice['token_ids'] + text = choice["text"] + token_ids = choice["token_ids"] completion_usage += len(token_ids) message = OpenLLMGenerateMessage(content=text, role=OpenLLMGenerateMessage.Role.ASSISTANT.value) - if choice.get('finish_reason'): - finish_reason = choice['finish_reason'] - prompt_token_ids = data['prompt_token_ids'] + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + prompt_token_ids = data["prompt_token_ids"] message.stop_reason = finish_reason message.usage = { - 'prompt_tokens': len(prompt_token_ids), - 'completion_tokens': completion_usage, - 'total_tokens': completion_usage + len(prompt_token_ids),
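For readers skimming the hunk above: `_handle_chat_stream_generate_response` consumes a server-sent-events body, skipping blank keep-alive lines, stripping the `data: ` prefix, stopping at the `[DONE]` sentinel, and counting `token_ids` until a `finish_reason` arrives. A self-contained sketch of that loop; the payload shape mirrors the fields the handler reads, everything else is illustrative rather than Dify API:

```python
import json

def iter_sse_messages(lines):
    """Parse OpenLLM-style SSE lines: 'data: {...}' payloads terminated by 'data: [DONE]'."""
    completion_tokens = 0
    for raw in lines:
        if not raw:
            continue  # keep-alive blank lines
        line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        if not line.startswith("data: "):
            continue
        payload = line[6:].strip()
        if payload == "[DONE]":
            return  # end-of-stream sentinel
        data = json.loads(payload)
        for choice in data["outputs"]:
            completion_tokens += len(choice["token_ids"])
            yield choice["text"], choice.get("finish_reason"), completion_tokens

feed = [
    'data: {"outputs": [{"text": "Hel", "token_ids": [1, 2], "finish_reason": null}]}',
    'data: {"outputs": [{"text": "lo", "token_ids": [3], "finish_reason": "stop"}]}',
    "data: [DONE]",
]
for text, reason, used in iter_sse_messages(feed):
    print(text, reason, used)  # "Hel None 2", then "lo stop 3"
```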
+ "prompt_tokens": len(prompt_token_ids), + "completion_tokens": completion_usage, + "total_tokens": completion_usage + len(prompt_token_ids), } - - yield message \ No newline at end of file + + yield message diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py index d9d279e6ca0ed1..309b5cf413bd54 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalanceError(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 4dbd0678e71f7f..43a2e948e2cee5 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -5,6 +5,7 @@ from requests import post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( @@ -23,9 +24,15 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel): """ Model class for OpenLLM text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -33,18 +40,16 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - server_url = credentials['server_url'] + server_url = credentials["server_url"] if not server_url: - raise CredentialsValidateFailedError('server_url is required') - - headers = { - 'Content-Type': 'application/json', - 'accept': 'application/json' - } + raise CredentialsValidateFailedError("server_url is required") - url = f'{server_url}/v1/embeddings' + headers = {"Content-Type": "application/json", "accept": "application/json"} + + url = f"{server_url}/v1/embeddings" data = texts try: @@ -54,7 +59,7 @@ def _invoke(self, model: str, credentials: dict, raise InvokeAuthorizationError(f"Invalid server URL: {e}") except Exception as e: raise InvokeConnectionError(str(e)) - + if response.status_code != 200: if response.status_code == 400: raise InvokeBadRequestError(response.text) @@ -62,21 +67,17 @@ def _invoke(self, model: str, credentials: dict, raise InvokeAuthorizationError(response.text) elif response.status_code == 500: raise InvokeServerUnavailableError(response.text) - + try: resp = response.json()[0] - embeddings = resp['embeddings'] - total_tokens = resp['num_tokens'] + embeddings = resp["embeddings"] + total_tokens = resp["num_tokens"] except KeyError as e: raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") usage = self._calc_response_usage(model=model, credentials=credentials, tokens=total_tokens) - result = TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) return result @@ -104,9 +105,9 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError: - raise CredentialsValidateFailedError('Invalid server_url') + raise CredentialsValidateFailedError("Invalid server_url") @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: @@ -119,23 +120,13 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -147,10 +138,7 @@ def _calc_response_usage(self, model: str, credentials: dict, 
tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -161,7 +149,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/openrouter/llm/_position.yaml b/api/core/model_runtime/model_providers/openrouter/llm/_position.yaml index 7e00dd3f4bc3ce..5a25c84c34c0cb 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/_position.yaml @@ -1,3 +1,5 @@ +- openai/o1-preview +- openai/o1-mini - openai/gpt-4o - openai/gpt-4o-mini - openai/gpt-4 @@ -12,6 +14,10 @@ - google/gemini-pro - cohere/command-r-plus - cohere/command-r +- meta-llama/llama-3.2-1b-instruct +- meta-llama/llama-3.2-3b-instruct +- meta-llama/llama-3.2-11b-vision-instruct +- meta-llama/llama-3.2-90b-vision-instruct - meta-llama/llama-3.1-405b-instruct - meta-llama/llama-3.1-70b-instruct - meta-llama/llama-3.1-8b-instruct @@ -20,6 +26,7 @@ - mistralai/mixtral-8x22b-instruct - mistralai/mixtral-8x7b-instruct - mistralai/mistral-7b-instruct +- qwen/qwen-2.5-72b-instruct - qwen/qwen-2-72b-instruct - deepseek/deepseek-chat - deepseek/deepseek-coder diff --git a/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-haiku.yaml b/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-haiku.yaml new file mode 100644 index 00000000000000..de45093a72a85c --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-haiku.yaml @@ -0,0 +1,38 @@ +model: anthropic/claude-3-5-haiku +label: + en_US: claude-3-5-haiku +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
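The pricing stanzas in these model cards feed `get_price`/`_calc_response_usage` above; a `unit` of `"0.000001"` reads as a per-million-token rate, so cost comes out as tokens x unit x price. A worked sketch using the claude-3-5-haiku numbers above (Decimal to avoid float drift; the formula is an assumption about the config's semantics, not code from this diff):

```python
from decimal import Decimal

def token_cost(tokens: int, price: str, unit: str) -> Decimal:
    """cost = tokens * unit * price; unit '0.000001' makes price a per-1M-token rate."""
    return Decimal(tokens) * Decimal(unit) * Decimal(price)

input_cost = token_cost(2_000_000, price="1", unit="0.000001")  # 2.000000 USD
output_cost = token_cost(500_000, price="5", unit="0.000001")   # 2.500000 USD
print(input_cost + output_cost)                                 # 4.500000
```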
+ required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: "1" + output: "5" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-sonnet.yaml b/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-sonnet.yaml index 40558854e2a7bd..e829048e55e0d8 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-sonnet.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-sonnet.yaml @@ -27,9 +27,9 @@ parameter_rules: - name: max_tokens use_template: max_tokens required: true - default: 4096 + default: 8192 min: 1 - max: 4096 + max: 8192 - name: response_format use_template: response_format pricing: diff --git a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml index 7a1dea69503daa..6743bfcad662fd 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml @@ -35,6 +35,15 @@ parameter_rules: help: zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty default: 0 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml index c05f4769b83354..375a4d2d52ed2a 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml @@ -18,6 +18,15 @@ parameter_rules: min: 0 max: 1 default: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens min: 1 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml index 1737c50bb10869..621ecf065e93d6 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
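The same `top_k` rule is being stamped into card after card in this diff. At runtime such a rule only has to validate, bound, and default one value; a minimal interpreter for a single rule dict shaped like the YAML above (the helper and call sites are illustrative):

```python
def apply_rule(rule: dict, params: dict):
    """Validate one parameter against a rule like the top_k stanzas above."""
    name = rule["name"]
    if name not in params:
        if rule.get("required", False) and "default" not in rule:
            raise ValueError(f"{name} is required")
        return rule.get("default")  # may be None for optional rules without defaults
    value = params[name]
    if rule.get("type") == "int" and not isinstance(value, int):
        raise TypeError(f"{name} must be an int")
    if "min" in rule and value < rule["min"]:
        raise ValueError(f"{name} below minimum {rule['min']}")
    if "max" in rule and value > rule["max"]:
        raise ValueError(f"{name} above maximum {rule['max']}")
    return value

rule = {"name": "top_k", "type": "int", "required": False, "min": 1, "max": 99}
print(apply_rule(rule, {"top_k": 40}))  # 40
print(apply_rule(rule, {}))             # None (optional, no default)
```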
+ required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty @@ -26,7 +35,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml index 2d55cf85656f15..887e6d60f9c129 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty @@ -41,7 +50,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml index 12015f6f6432f4..66d1f9ae67eafd 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty @@ -41,7 +50,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml new file mode 100644 index 00000000000000..695cc3eedf5999 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml @@ -0,0 +1,53 @@ +model: gpt-4o-2024-08-06 +label: + zh_Hans: gpt-4o-2024-08-06 + en_US: gpt-4o-2024-08-06 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml index de0bad413653eb..e1e5889085f258 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml @@ -15,6 +15,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty @@ -27,7 +36,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml index 6945402c7207a8..560bf9d7d05dd6 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml @@ -15,6 +15,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty @@ -27,7 +36,7 @@ parameter_rules: - name: response_format label: zh_Hans: 回复格式 - en_US: response_format + en_US: Response Format type: string help: zh_Hans: 指定模型必须输出的格式 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml index b91c39e729eda3..04a4a90c6d6e34 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
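Several of the GPT cards above also normalize the `response_format` label and expose `text`/`json_object` options. On an OpenAI-compatible endpoint such as OpenRouter's, choosing `json_object` typically travels as a `response_format` object in the chat-completions payload; a request-body sketch (model and prompt are examples only, and nothing is actually sent):

```python
import json

payload = {
    "model": "openai/gpt-4o-2024-08-06",
    "messages": [{"role": "user", "content": 'Reply with a JSON object {"ok": true}'}],
    # maps from the response_format parameter rule above; "text" would omit the field
    "response_format": {"type": "json_object"},
    "max_tokens": 512,
}
print(json.dumps(payload, indent=2))
```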
+ required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml index 84b2c7fac2c0f0..066949d431eaa3 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml index a489ce1b5ad384..0cd89dea7164b8 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml index 12037411b100bd..768ab5ecbb6788 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml index 6f06493f293f9d..67b6b82b5d7e9d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-11b-vision-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-11b-vision-instruct.yaml new file mode 100644 index 00000000000000..6ad2c26cc862d4 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-11b-vision-instruct.yaml @@ -0,0 +1,46 @@ +model: meta-llama/llama-3.2-11b-vision-instruct +label: + zh_Hans: llama-3.2-11b-vision-instruct + en_US: llama-3.2-11b-vision-instruct +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.055' + output: '0.055' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-1b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-1b-instruct.yaml new file mode 100644 index 00000000000000..657ef16835f420 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-1b-instruct.yaml @@ -0,0 +1,45 @@ +model: meta-llama/llama-3.2-1b-instruct +label: + zh_Hans: llama-3.2-1b-instruct + en_US: llama-3.2-1b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
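The new llama-3.2 cards advertise capabilities through the `features` list (`vision`, `agent-thought`, tool-call variants). Roughly, a runtime can use these flags to gate what a request may contain; a tiny guard in that spirit (the feature names mirror the YAML above, the guard itself is illustrative, not Dify's implementation):

```python
def check_request(card_features: list[str], has_images: bool, has_tools: bool) -> None:
    """Reject request shapes the model card does not advertise support for."""
    if has_images and "vision" not in card_features:
        raise ValueError("model card lacks the 'vision' feature; drop image inputs")
    tool_features = ("tool-call", "multi-tool-call", "stream-tool-call")
    if has_tools and not any(f in card_features for f in tool_features):
        raise ValueError("model card lacks tool-call features; drop tool definitions")

check_request(["agent-thought", "vision"], has_images=True, has_tools=False)  # passes
```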
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.01' + output: '0.02' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-3b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-3b-instruct.yaml new file mode 100644 index 00000000000000..7f6e24e591aacc --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-3b-instruct.yaml @@ -0,0 +1,45 @@ +model: meta-llama/llama-3.2-3b-instruct +label: + zh_Hans: llama-3.2-3b-instruct + en_US: llama-3.2-3b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.03' + output: '0.05' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-90b-vision-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-90b-vision-instruct.yaml new file mode 100644 index 00000000000000..c264db0f206f35 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-90b-vision-instruct.yaml @@ -0,0 +1,46 @@ +model: meta-llama/llama-3.2-90b-vision-instruct +label: + zh_Hans: llama-3.2-90b-vision-instruct + en_US: llama-3.2-90b-vision-instruct +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
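These cards also carry a `context_length_exceeded_behavior` rule with `None`/`truncate`/`error` options, i.e. what to do when a prompt overruns `context_size`. A toy illustration in terms of a token list (real tokenization is elided, and the exact server-side semantics are an assumption):

```python
def fit_prompt(tokens: list[str], context_size: int, behavior: str = "None") -> list[str]:
    """Mimic the three context_length_exceeded_behavior options above."""
    if len(tokens) <= context_size or behavior == "None":
        return tokens                    # pass through; let the server decide
    if behavior == "truncate":
        return tokens[-context_size:]    # keep only the most recent tokens
    if behavior == "error":
        raise ValueError(f"prompt of {len(tokens)} tokens exceeds {context_size}")
    raise ValueError(f"unknown behavior: {behavior}")

words = "one two three four five".split()
print(fit_prompt(words, 3, "truncate"))  # ['three', 'four', 'five']
```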
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.35' + output: '0.4' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index e78ac4caf1da1e..736ab8e7a8f6ce 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -1,48 +1,106 @@ from collections.abc import Generator from typing import Optional, Union -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_credential(self, model: str, credentials: dict): - credentials['endpoint_url'] = "https://openrouter.ai/api/v1" - credentials['mode'] = self.get_model_mode(model).value - credentials['function_calling_type'] = 'tool_call' - return - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + credentials["endpoint_url"] = "https://openrouter.ai/api/v1" + credentials["mode"] = self.get_model_mode(model).value + credentials["function_calling_type"] = "tool_call" + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) - return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: self._update_credential(model, credentials) return super().validate_credentials(model, credentials) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) - return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + block_as_stream = False + 
if model.startswith("openai/o1"): + block_as_stream = True + stop = None + + # invoke block as stream + if stream and block_as_stream: + return self._generate_block_as_stream( + model, credentials, prompt_messages, model_parameters, tools, stop, user + ) + else: + return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _generate_block_as_stream( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + user: Optional[str] = None, + ) -> Generator: + resp: LLMResult = super()._generate( + model, credentials, prompt_messages, model_parameters, tools, stop, False, user + ) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=resp.message, + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=resp.usage.prompt_tokens, + completion_tokens=resp.usage.completion_tokens, + ), + finish_reason="stop", + ), + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: self._update_credential(model, credentials) return super().get_customizable_model_schema(model, credentials) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: self._update_credential(model, credentials) return super().get_num_tokens(model, credentials, prompt_messages, tools) diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml index 012dfc55ce18a8..d08c016e95482f 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml @@ -18,6 +18,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml index f4eb4e45d95cb2..e3af0e64d8fe2c 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml @@ -18,6 +18,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml index 7871e1f7a05c17..095ea5a858417e 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml @@ -19,6 +19,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml b/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml new file mode 100644 index 00000000000000..f4202ee814e38c --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml @@ -0,0 +1,49 @@ +model: openai/o1-mini +label: + en_US: o1-mini +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 65536 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: "3.00" + output: "12.00" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml b/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml new file mode 100644 index 00000000000000..1281b842862568 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml @@ -0,0 +1,49 @@ +model: openai/o1-preview +label: + en_US: o1-preview +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 32768 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object +pricing: + input: "15.00" + output: "60.00" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml index 7b75fcb0c986a3..b6058138d30df5 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml new file mode 100644 index 00000000000000..5392b111689a27 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml @@ -0,0 +1,39 @@ +model: qwen/qwen-2.5-72b-instruct +label: + en_US: qwen-2.5-72b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 8192 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: "0.35" + output: "0.4" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/openrouter.py b/api/core/model_runtime/model_providers/openrouter/openrouter.py index 613f71deb1c806..2e59ab50598b8b 100644 --- a/api/core/model_runtime/model_providers/openrouter/openrouter.py +++ b/api/core/model_runtime/model_providers/openrouter/openrouter.py @@ -8,17 +8,13 @@ class OpenRouterProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='openai/gpt-3.5-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="openai/gpt-3.5-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise ex \ No newline at end of file + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml index 87712874b9f5c6..bf91468fcf56bf 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Llama3-Chinese_v2.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml index f16f3de60b8d30..781b837e8eceea 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-70B-Instruct-GPTQ-Int4.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml index 21267c240bc0e5..67210e9020fe6a 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3-8B-Instruct.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml index 80c7ec40f22cb9..482632ff0673d2 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Meta-Llama-3.1-405B-Instruct-AWQ-INT4.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml index 841dd97f353b04..ddb6fd977c5a56 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml +++ 
b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-72B-Chat-GPTQ-Int4.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml index 33d5d12b223d2b..024c79dbcfd76c 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen1.5-7B.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-AWQ-int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-AWQ-int4.yaml new file mode 100644 index 00000000000000..94f661f40d5c90 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-AWQ-int4.yaml @@ -0,0 +1,61 @@ +model: Qwen2-72B-Instruct-AWQ-int4 +label: + en_US: Qwen2-72B-Instruct-AWQ-int4 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. 
The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml index 62255cc7d2e566..a06f8d5ab18f4c 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-72B-Instruct-GPTQ-Int4.yaml @@ -61,3 +61,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B-Instruct.yaml new file mode 100644 index 00000000000000..4369411399e72a --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B-Instruct.yaml @@ -0,0 +1,63 @@ +model: Qwen2-7B-Instruct +label: + en_US: Qwen2-7B-Instruct +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: completion + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. 
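Alongside the new Qwen cards, several older perfxcloud cards gain `deprecated: true`. A plausible reading, sketched below, is that loaders keep the YAML on disk (so existing apps still resolve the model) but skip deprecated cards when listing selectable models; PyYAML is assumed available and the loader shape is illustrative:

```python
import yaml  # PyYAML, assumed available

def load_active_models(paths: list[str]) -> list[dict]:
    """Skip model cards flagged deprecated: true, as added for the old perfxcloud entries."""
    active = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            card = yaml.safe_load(f)
        if card.get("deprecated", False):
            continue  # keep the file for existing apps, hide it from pickers
        active.append(card)
    return active
```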
+ - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml index 2f3f1f0225d712..d549ecd227dfb8 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2-7B.yaml @@ -61,3 +61,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2.5-72B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2.5-72B-Instruct.yaml new file mode 100644 index 00000000000000..15cbf01f1f66da --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2.5-72B-Instruct.yaml @@ -0,0 +1,61 @@ +model: Qwen2.5-72B-Instruct +label: + en_US: Qwen2.5-72B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 30720 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. 
For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set for sampling during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the model's output. Increasing repetition_penalty reduces repetition. 1.0 means no penalty. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2.5-7B-Instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2.5-7B-Instruct.yaml new file mode 100644 index 00000000000000..dadc8f8f3275e5 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Qwen2.5-7B-Instruct.yaml @@ -0,0 +1,61 @@ +model: Qwen2.5-7B-Instruct +label: + en_US: Qwen2.5-7B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0).
The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set for sampling during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the model's output. Increasing repetition_penalty reduces repetition. 1.0 means no penalty. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Reflection-Llama-3.1-70B.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Reflection-Llama-3.1-70B.yaml new file mode 100644 index 00000000000000..649be20b48abef --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Reflection-Llama-3.1-70B.yaml @@ -0,0 +1,61 @@ +model: Reflection-Llama-3.1-70B +label: + en_US: Reflection-Llama-3.1-70B +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 10240 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
+ - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set for sampling during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the model's output. Increasing repetition_penalty reduces repetition. 1.0 means no penalty. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-1_5-9B-Chat-16K.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-1_5-9B-Chat-16K.yaml new file mode 100644 index 00000000000000..92eae6804f61fa --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-1_5-9B-Chat-16K.yaml @@ -0,0 +1,61 @@ +model: Yi-1_5-9B-Chat-16K +label: + en_US: Yi-1_5-9B-Chat-16K +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
+ - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set for sampling during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the model's output. Increasing repetition_penalty reduces repetition. 1.0 means no penalty. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-Coder-1.5B-Chat.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-Coder-1.5B-Chat.yaml new file mode 100644 index 00000000000000..0e21ce148c39bd --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-Coder-1.5B-Chat.yaml @@ -0,0 +1,61 @@ +model: Yi-Coder-1.5B-Chat +label: + en_US: Yi-Coder-1.5B-Chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 20480 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
+ - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set for sampling during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the model's output. Increasing repetition_penalty reduces repetition. 1.0 means no penalty. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-Coder-9B-Chat.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-Coder-9B-Chat.yaml new file mode 100644 index 00000000000000..23b0841ce4ed65 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/Yi-Coder-9B-Chat.yaml @@ -0,0 +1,61 @@ +model: Yi-Coder-9B-Chat +label: + en_US: Yi-Coder-9B-Chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 20480 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.5 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 600 + min: 1 + max: 1248 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated.
+ - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set for sampling during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the model's output. Increasing repetition_penalty reduces repetition. 1.0 means no penalty. +pricing: + input: "0.000" + output: "0.000" + unit: "0.000" + currency: RMB diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml index 2c9eac0e49a4d7..c6930e54f50aa4 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/_position.yaml @@ -1,15 +1,23 @@ -- Meta-Llama-3.1-405B-Instruct-AWQ-INT4 -- Meta-Llama-3.1-8B-Instruct -- Meta-Llama-3-70B-Instruct-GPTQ-Int4 -- Meta-Llama-3-8B-Instruct -- Qwen2-72B-Instruct-GPTQ-Int4 +- Qwen2.5-72B-Instruct +- Qwen2.5-7B-Instruct - Qwen2-72B-Instruct +- Qwen2-72B-Instruct-AWQ-int4 +- Qwen2-72B-Instruct-GPTQ-Int4 +- Qwen2-7B-Instruct - Qwen2-7B -- Qwen-14B-Chat-Int4 +- Qwen1.5-110B-Chat-GPTQ-Int4 - Qwen1.5-72B-Chat-GPTQ-Int4 - Qwen1.5-7B -- Qwen1.5-110B-Chat-GPTQ-Int4 -- deepseek-v2-chat -- deepseek-v2-lite-chat +- Qwen-14B-Chat-Int4 +- Yi-Coder-1.5B-Chat +- Yi-Coder-9B-Chat +- Yi-1_5-9B-Chat-16K +- Reflection-Llama-3.1-70B +- Meta-Llama-3.1-8B-Instruct +- Meta-Llama-3.1-405B-Instruct-AWQ-INT4 +- Meta-Llama-3-70B-Instruct-GPTQ-Int4 +- Meta-Llama-3-8B-Instruct - Llama3-Chinese_v2 +- deepseek-v2-lite-chat +- deepseek-v2-chat - chatglm3-6b diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml index f9c26b7f9002c7..75d80f784a71f5 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/chatglm3-6b.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml index 078922ef95b087..fa9a7b7175e9dc 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-chat.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml index 4ff3af7b517761..75a26d25051aea 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/deepseek-v2-lite-chat.yaml @@ -59,3 +59,4 @@ pricing: output: "0.000" unit: "0.000" currency: RMB +deprecated: true diff --git
a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py index c9116bf68538b4..89cac665aa5a08 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py +++ b/api/core/model_runtime/model_providers/perfxcloud/llm/llm.py @@ -13,11 +13,17 @@ class PerfXCloudLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -27,8 +33,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. @@ -46,8 +51,9 @@ def _num_tokens_from_string(self, model: str, text: str, return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. 
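A rough sketch of the underlying call this counting relies on (illustrative only, assuming the tiktoken package is installed):

            import tiktoken
            encoding = tiktoken.get_encoding("cl100k_base")
            num_tokens = len(encoding.encode("some prompt text"))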
Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -67,10 +73,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -101,10 +107,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://cloud.perfxlab.cn' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://cloud.perfxlab.cn" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py index 0854ef5185143d..9a4ead031d018c 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py +++ b/api/core/model_runtime/model_providers/perfxcloud/perfxcloud.py @@ -1,32 +1,10 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) class PerfXCloudProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: - """ - Validate provider credentials - if validate failed, raise exception - - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
- """ - try: - model_instance = self.get_model_instance(ModelType.LLM) - - # Use `Qwen2_72B_Chat_GPTQ_Int4` model for validate, - # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='Qwen2-72B-Instruct-GPTQ-Int4', - credentials=credentials - ) - except CredentialsValidateFailedError as ex: - raise ex - except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') - raise ex + pass diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/_position.yaml new file mode 100644 index 00000000000000..99163d42931b16 --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/_position.yaml @@ -0,0 +1,4 @@ +- gte-Qwen2-7B-instruct +- BAAI/bge-large-en-v1.5 +- BAAI/bge-large-zh-v1.5 +- BAAI/bge-m3 diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/gte-Qwen2-7B-instruct.yaml b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/gte-Qwen2-7B-instruct.yaml new file mode 100644 index 00000000000000..161d5ea9a2657e --- /dev/null +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/gte-Qwen2-7B-instruct.yaml @@ -0,0 +1,5 @@ +model: gte-Qwen2-7B-instruct +model_type: text-embedding +model_properties: + context_size: 2048 +deprecated: true diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 11d57e3749a8f1..d78bdaa75e5423 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -7,6 +7,7 @@ import numpy as np import requests +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -19,17 +20,22 @@ from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): +class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -37,32 +43,31 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - + # Prepare headers and payload for the request - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") extra_model_kwargs = {} if user: - extra_model_kwargs['user'] = user + extra_model_kwargs["user"] = user - extra_model_kwargs['encoding_format'] = 'float' + extra_model_kwargs["encoding_format"] = "float" # get model properties context_size = self._get_context_size(model, credentials) @@ -73,7 +78,6 @@ def _invoke(self, model: str, credentials: dict, used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer # TODO: Optimize for better token estimation and chunking num_tokens = self._get_num_tokens_by_gpt2(text) @@ -81,7 +85,7 @@ def _invoke(self, model: str, credentials: dict, if num_tokens >= context_size: cutoff = int(np.floor(len(text) * (context_size / num_tokens))) # if num tokens is larger than context length, only use the start - inputs.append(text[0: cutoff]) + inputs.append(text[0:cutoff]) else: inputs.append(text) indices += [i] @@ -91,42 +95,25 @@ def _invoke(self, model: str, credentials: dict, for i in _iter: # Prepare the payload for the request - payload = { - 'input': inputs[i: i + max_chunks], - 'model': model, - **extra_model_kwargs - } + payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} # Make the request to the OpenAI API - response = requests.post( - endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) response.raise_for_status() # Raise an exception for HTTP errors response_data = response.json() # Extract embeddings and used tokens from the response - embeddings_batch = [data['embedding'] for data in response_data['data']] - embedding_used_tokens = response_data['usage']['total_tokens'] + embeddings_batch = [data["embedding"] for data in response_data["data"]] + embedding_used_tokens = response_data["usage"]["total_tokens"] used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - 
tokens=used_tokens - ) - - return TextEmbeddingResult( - embeddings=batched_embeddings, - usage=usage, - model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -148,48 +135,38 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - headers = { - 'Content-Type': 'application/json' - } + headers = {"Content-Type": "application/json"} - api_key = credentials.get('api_key') + api_key = credentials.get("api_key") if api_key: headers["Authorization"] = f"Bearer {api_key}" - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - endpoint_url='https://cloud.perfxlab.cn/v1/' + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + endpoint_url = "https://cloud.perfxlab.cn/v1/" else: - endpoint_url = credentials.get('endpoint_url') - if not endpoint_url.endswith('/'): - endpoint_url += '/' + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" - endpoint_url = urljoin(endpoint_url, 'embeddings') + endpoint_url = urljoin(endpoint_url, "embeddings") - payload = { - 'input': 'ping', - 'model': model - } + payload = {"input": "ping", "model": model} - response = requests.post( - url=endpoint_url, - headers=headers, - data=json.dumps(payload), - timeout=(10, 300) - ) + response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) if response.status_code != 200: raise CredentialsValidateFailedError( - f'Credentials validation failed with status code {response.status_code}') + f"Credentials validation failed with status code {response.status_code}" + ) try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - if 'model' not in json_result: - raise CredentialsValidateFailedError( - 'Credentials validation failed: invalid response') + if "model" not in json_result: + raise CredentialsValidateFailedError("Credentials validation failed: invalid response") except CredentialsValidateFailedError: raise except Exception as ex: @@ -197,7 +174,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -205,20 +182,19 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity - def _calc_response_usage(self, model: str, credentials: 
dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -230,10 +206,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -244,7 +217,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index 29d8427d8ef231..915f6e0eefcd08 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -4,12 +4,6 @@ class _CommonReplicate: - @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: - return { - InvokeBadRequestError: [ - ReplicateError, - ModelError - ] - } + return {InvokeBadRequestError: [ReplicateError, ModelError]} diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 31b81a829e0882..3641b35dc02a39 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -28,16 +28,22 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: - - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] - - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] + + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -48,39 +54,43 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes inputs = {**model_parameters} if prompt_messages[0].role == PromptMessageRole.SYSTEM: - if 'system_prompt' in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: - inputs['system_prompt'] = prompt_messages[0].content - inputs['prompt'] = prompt_messages[1].content + if "system_prompt" in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"]: + inputs["system_prompt"] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[1].content else: - inputs['prompt'] = prompt_messages[0].content + inputs["prompt"] = prompt_messages[0].content - prediction = client.predictions.create( - version=model_info_version, 
input=inputs - ) + prediction = client.predictions.create(version=model_info_version, input=inputs) if stream: return self._handle_generate_stream_response(model, credentials, prediction, stop, prompt_messages) return self._handle_generate_response(model, credentials, prediction, stop, prompt_messages) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] if model.count("/") != 1: - raise CredentialsValidateFailedError('Replicate Model Name must be provided, ' - 'format: {user_name}/{model_name}') + raise CredentialsValidateFailedError( + "Replicate Model Name must be provided, format: {user_name}/{model_name}" + ) try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -91,45 +101,44 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._check_text_generation_model(model_info_version, model, model_version, model_info.description) except ReplicateError as e: raise CredentialsValidateFailedError( - f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}") + f"Model {model}:{model_version} not exists, cause: {e.__class__.__name__}:{str(e)}" + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) @staticmethod def _check_text_generation_model(model_info_version, model_name, version, description): - if 'language model' in description.lower(): + if "language model" in description.lower(): return - if 'temperature' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_p' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties'] \ - or 'top_k' not in model_info_version.openapi_schema['components']['schemas']['Input']['properties']: + if ( + "temperature" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_p" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + or "top_k" not in model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"] + ): raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.") def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: - model_type = LLMMode.CHAT if model.endswith('-chat') else LLMMode.COMPLETION + model_type = LLMMode.CHAT if model.endswith("-chat") else LLMMode.COMPLETION entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + 
label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, - model_properties={ - ModelPropertyKey.MODE: model_type.value - }, - parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) + model_properties={ModelPropertyKey.MODE: model_type.value}, + parameter_rules=self._get_customizable_model_parameter_rules(model, credentials), ) return entity @classmethod def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) -> list[ParameterRule]: - model_version = '' - if 'model_version' in credentials: - model_version = credentials['model_version'] + model_version = "" + if "model_version" in credentials: + model_version = credentials["model_version"] - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) model_info = client.models.get(model) if model_version: @@ -140,15 +149,13 @@ def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) parameter_rules = [] input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for key, value in input_properties: - if key not in ['system_prompt', 'prompt'] and 'stop' not in key: - value_type = value.get('type') + if key not in {"system_prompt", "prompt"} and "stop" not in key: + value_type = value.get("type") if not value_type: continue @@ -157,28 +164,28 @@ def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) rule = ParameterRule( name=key, - label={ - 'en_US': value['title'] - }, + label={"en_US": value["title"]}, type=param_type, help={ - 'en_US': value.get('description'), + "en_US": value.get("description"), }, required=False, - default=value.get('default'), - min=value.get('minimum'), - max=value.get('maximum') + default=value.get("default"), + min=value.get("minimum"), + max=value.get("maximum"), ) parameter_rules.append(rule) return parameter_rules - def _handle_generate_stream_response(self, - model: str, - credentials: dict, - prediction: Prediction, - stop: list[str], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> Generator: index = -1 current_completion: str = "" stop_condition_reached = False @@ -189,7 +196,7 @@ def _handle_generate_stream_response(self, for output in prediction.output_iterator(): current_completion += output - if not is_prediction_output_finished and prediction.status == 'succeeded': + if not is_prediction_output_finished and prediction.status == "succeeded": prediction_output_length = len(prediction.output) - 1 is_prediction_output_finished = True @@ -207,18 +214,13 @@ def _handle_generate_stream_response(self, index += 1 - assistant_prompt_message = AssistantPromptMessage( - content=output if output else '' - ) + assistant_prompt_message = AssistantPromptMessage(content=output or "") if index < prediction_output_length: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: prompt_tokens = self.get_num_tokens(model, 
credentials, prompt_messages) @@ -229,15 +231,17 @@ def _handle_generate_stream_response(self, yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) - def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str], - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, + model: str, + credentials: dict, + prediction: Prediction, + stop: list[str], + prompt_messages: list[PromptMessage], + ) -> LLMResult: current_completion: str = "" stop_condition_reached = False for output in prediction.output_iterator(): @@ -255,9 +259,7 @@ def _handle_generate_response(self, model: str, credentials: dict, prediction: P if stop_condition_reached: break - assistant_prompt_message = AssistantPromptMessage( - content=current_completion - ) + assistant_prompt_message = AssistantPromptMessage(content=current_completion) prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) @@ -275,21 +277,13 @@ def _handle_generate_response(self, model: str, credentials: dict, prediction: P @classmethod def _get_parameter_type(cls, param_type: str) -> str: - type_mapping = { - 'integer': 'int', - 'number': 'float', - 'boolean': 'boolean', - 'string': 'string' - } + type_mapping = {"integer": "int", "number": "float", "boolean": "boolean", "string": "string"} return type_mapping.get(param_type) def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() diff --git a/api/core/model_runtime/model_providers/replicate/replicate.py b/api/core/model_runtime/model_providers/replicate/replicate.py index 3a5c9b84a07b52..ca137579c96f2c 100644 --- a/api/core/model_runtime/model_providers/replicate/replicate.py +++ b/api/core/model_runtime/model_providers/replicate/replicate.py @@ -6,6 +6,5 @@ class ReplicateProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 0e4cdbf5bc13ca..c4e9d0b9c6ceb2 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -4,6 +4,7 @@ from replicate import Client as ReplicateClient +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -13,32 +14,42 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): - def _invoke(self, model: str, credentials: dict, texts: list[str], - user: Optional[str] = None) -> TextEmbeddingResult: - - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) - - 
if 'model_version' in credentials: - model_version = credentials['model_version'] + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) + + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - texts) + embeddings = self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, texts) tokens = self.get_num_tokens(model, credentials, texts) usage = self._calc_response_usage(model, credentials, tokens) - return TextEmbeddingResult( - model=model, - embeddings=embeddings, - usage=usage - ) + return TextEmbeddingResult(model=model, embeddings=embeddings, usage=usage) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: num_tokens = 0 @@ -47,39 +58,35 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - if 'replicate_api_token' not in credentials: - raise CredentialsValidateFailedError('Replicate Access Token must be provided.') + if "replicate_api_token" not in credentials: + raise CredentialsValidateFailedError("Replicate Access Token must be provided.") try: - client = ReplicateClient(api_token=credentials['replicate_api_token'], timeout=30) + client = ReplicateClient(api_token=credentials["replicate_api_token"], timeout=30) - if 'model_version' in credentials: - model_version = credentials['model_version'] + if "model_version" in credentials: + model_version = credentials["model_version"] else: model_info = client.models.get(model) model_version = model_info.latest_version.id - replicate_model_version = f'{model}:{model_version}' + replicate_model_version = f"{model}:{model_version}" text_input_key = self._get_text_input_key(model, model_version, client) - self._generate_embeddings_by_text_input_key(client, replicate_model_version, text_input_key, - ['Hello worlds!']) + self._generate_embeddings_by_text_input_key( + client, replicate_model_version, text_input_key, ["Hello worlds!"] + ) except Exception as e: raise CredentialsValidateFailedError(str(e)) def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={ - 'context_size': 4096, - 'max_chunks': 1 - } + model_properties={"context_size": 4096, "max_chunks": 1}, ) return entity @@ -90,49 +97,45 @@ def _get_text_input_key(model: str, model_version: str, client: ReplicateClient) # sort through the openapi schema to get the name of text, texts or inputs 
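# For instance, a Replicate embedding model's Input schema looks roughly like
# (illustrative values):
#   {"texts": {"title": "Texts", "type": "string", "x-order": 0},
#    "batch_size": {"title": "Batch Size", "type": "integer", "x-order": 1}}
# so sorting by "x-order" walks the declared inputs in order, and the first key
# named "text", "texts" or "inputs" is treated as the text input.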
input_properties = sorted( - model_info_version.openapi_schema["components"]["schemas"]["Input"][ - "properties" - ].items(), + model_info_version.openapi_schema["components"]["schemas"]["Input"]["properties"].items(), key=lambda item: item[1].get("x-order", 0), ) for input_property in input_properties: - if input_property[0] in ('text', 'texts', 'inputs'): + if input_property[0] in {"text", "texts", "inputs"}: text_input_key = input_property[0] return text_input_key - return '' + return "" @staticmethod - def _generate_embeddings_by_text_input_key(client: ReplicateClient, replicate_model_version: str, - text_input_key: str, texts: list[str]) -> list[list[float]]: - - if text_input_key in ('text', 'inputs'): + def _generate_embeddings_by_text_input_key( + client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] + ) -> list[list[float]]: + if text_input_key in {"text", "inputs"}: embeddings = [] for text in texts: - result = client.run(replicate_model_version, input={ - text_input_key: text - }) - embeddings.append(result[0].get('embedding')) + result = client.run(replicate_model_version, input={text_input_key: text}) + embeddings.append(result[0].get("embedding")) return [list(map(float, e)) for e in embeddings] - elif 'texts' == text_input_key: - result = client.run(replicate_model_version, input={ - 'texts': json.dumps(texts), - "batch_size": 4, - "convert_to_numpy": False, - "normalize_embeddings": True - }) + elif "texts" == text_input_key: + result = client.run( + replicate_model_version, + input={ + "texts": json.dumps(texts), + "batch_size": 4, + "convert_to_numpy": False, + "normalize_embeddings": True, + }, + ) return result else: - raise ValueError(f'embeddings input key is invalid: {text_input_key}') + raise ValueError(f"embeddings input key is invalid: {text_input_key}") def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -143,7 +146,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index f8e7757a969f8e..5ff00f008eb621 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -1,17 +1,33 @@ import json import logging -from collections.abc import Generator -from typing import Any, Optional, Union +import re +from collections.abc import Generator, Iterator +from typing import Any, Optional, Union, cast import boto3 -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + ImagePromptMessageContent, PromptMessage, + PromptMessageContent, + PromptMessageContentType, PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities 
import ( + AIModelEntity, + FetchFrom, + I18nObject, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -26,17 +42,160 @@ logger = logging.getLogger(__name__) +def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False): + """ + params: + predictor : SageMaker Predictor + messages (List[Dict[str,Any]]): message list. + messages = [ + {"role": "system", "content":"please answer in Chinese"}, + {"role": "user", "content": "who are you? what are you doing?"}, + ] + params (Dict[str,Any]): model parameters for the LLM. + stream (bool): False by default. + + response: + result of inference if stream is False + Iterator of Chunks if stream is True + """ + payload = { + "model": params.get("model_name"), + "stop": stop, + "messages": messages, + "stream": stream, + "max_tokens": params.get("max_new_tokens", params.get("max_tokens", 2048)), + "temperature": params.get("temperature", 0.1), + "top_p": params.get("top_p", 0.9), + } + + if not stream: + response = predictor.predict(payload) + return response + else: + response_stream = predictor.predict_stream(payload) + return response_stream + + class SageMakerLargeLanguageModel(LargeLanguageModel): """ Model class for Cohere large language model. """ - sagemaker_client: Any = None - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + sagemaker_session: Any = None + predictor: Any = None + sagemaker_endpoint: str = None + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: bytes, + ) -> LLMResult: + """ + handle normal chat generate response + """ + resp_obj = json.loads(resp.decode("utf-8")) + resp_str = resp_obj.get("choices")[0].get("message").get("content") + + if len(resp_str) == 0: + raise InvokeServerUnavailableError("Empty response") + + assistant_prompt_message = AssistantPromptMessage(content=resp_str, tool_calls=[]) + + prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) + completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) + + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) + + response = LLMResult( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=None, + usage=usage, + message=assistant_prompt_message, + ) + + return response + + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[bytes], + ) -> Generator: + """ + handle stream chat generate response + """ + full_response = "" + buffer = "" + for chunk_bytes in resp: + buffer += chunk_bytes.decode("utf-8") + last_idx = 0 + for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer): + try: + data = json.loads(match.group(1).strip()) + last_idx = match.span()[1] + + # .get() keeps chunk_content bound even for chunks (e.g. the final + # finish_reason-only chunk) that carry no content delta. + chunk_content = data["choices"][0]["delta"].get("content", "") + assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[]) + + if data["choices"][0]["finish_reason"] is not None: + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) + prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) + completion_tokens = self._num_tokens_from_messages( + messages=[temp_assistant_prompt_message], tools=[] + ) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=None, + delta=LLMResultChunkDelta( + index=0, + message=assistant_prompt_message, + finish_reason=data["choices"][0]["finish_reason"], + usage=usage, + ), + ) + else: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + system_fingerprint=None, + delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message), + ) + + full_response += chunk_content + except (json.JSONDecodeError, KeyError, IndexError): + logger.info("json parse exception, content: {}".format(match.group(1).strip())) + + buffer = buffer[last_idx:] + + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -50,58 +209,158 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: full response or stream response chunk generator result """ - # get model mode - model_mode = self.get_model_mode(model, credentials) - - if not self.sagemaker_client: - access_key = credentials.get('access_key') - secret_key = credentials.get('secret_key') - aws_region = credentials.get('aws_region') + from sagemaker import Predictor, serializers + from sagemaker.session import Session + + if not self.sagemaker_session: + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") + boto_session = None if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region) + boto_session = boto3.Session( + aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) else: - self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + boto_session = boto3.Session(region_name=aws_region) else: - self.sagemaker_client = boto3.client("sagemaker-runtime") - - - sagemaker_endpoint = credentials.get('sagemaker_endpoint') - response_model = self.sagemaker_client.invoke_endpoint( - EndpointName=sagemaker_endpoint, - Body=json.dumps( - { - "inputs": prompt_messages[0].content, - "parameters": { "stop" : stop}, - "history" : [] - } - ), - ContentType="application/json", - ) - - assistant_text = response_model['Body'].read().decode('utf8') - - # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_text + boto_session = boto3.Session() + + sagemaker_client = boto_session.client("sagemaker") + self.sagemaker_session = Session(boto_session=boto_session,
sagemaker_client=sagemaker_client) + + if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"): + self.sagemaker_endpoint = credentials.get("sagemaker_endpoint") + self.predictor = Predictor( + endpoint_name=self.sagemaker_endpoint, + sagemaker_session=self.sagemaker_session, + serializer=serializers.JSONSerializer(), + ) + + messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages] + response = inference( + predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream ) - usage = self._calc_response_usage(model, credentials, 0, 0) + if stream: + if tools and len(tools) > 0: + raise InvokeBadRequestError(f"{model}'s tool calls does not support stream mode") - response = LLMResult( - model=model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response ) - return response + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: + """ + Convert PromptMessage to dict for OpenAI Compatibility API + """ + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": "user", "content": message.content} + else: + sub_messages = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + message_content = cast(PromptMessageContent, message_content) + sub_message_dict = {"type": "text", "text": message_content.data} + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + sub_message_dict = { + "type": "image_url", + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, + } + sub_messages.append(sub_message_dict) + message_dict = {"role": "user", "content": sub_messages} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls and len(message.tool_calls) > 0: + message_dict["function_call"] = { + "name": message.tool_calls[0].function.name, + "arguments": message.tool_calls[0].function.arguments, + } + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} + else: + raise ValueError(f"Unknown message type {type(message)}") + + return message_dict + + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: + def tokens(text: str): + return self._get_num_tokens_by_gpt2(text) + + if is_completion_model: + return sum(tokens(str(message.content)) for message in messages) + + tokens_per_message = 3 + tokens_per_name = 1 + + num_tokens = 0 + messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] + for message in messages_dict: + num_tokens += tokens_per_message + for 
key, value in message.items(): + if isinstance(value, list): + text = "" + for item in value: + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] + + value = text + + if key == "tool_calls": + for tool_call in value: + for t_key, t_value in tool_call.items(): + num_tokens += tokens(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += tokens(f_key) + num_tokens += tokens(f_value) + else: + num_tokens += tokens(t_key) + num_tokens += tokens(t_value) + if key == "function_call": + for t_key, t_value in value.items(): + num_tokens += tokens(t_key) + if t_key == "function": + for f_key, f_value in t_value.items(): + num_tokens += tokens(f_key) + num_tokens += tokens(f_value) + else: + num_tokens += tokens(t_key) + num_tokens += tokens(t_value) + else: + num_tokens += tokens(str(value)) + + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + if tools: + num_tokens += self._num_tokens_for_tools(tools) + + return num_tokens + + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -112,10 +371,8 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr :return: """ # get model mode - model_mode = self.get_model_mode(model) - try: - return 0 + return self._num_tokens_from_messages(prompt_messages, tools) except Exception as e: raise self._transform_invoke_error(e) @@ -129,7 +386,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: """ try: # get model mode - model_mode = self.get_model_mode(model) + pass except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -144,95 +401,63 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + 
name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] - completion_type = LLMMode.value_of(credentials["mode"]) - - if completion_type == LLMMode.CHAT: - print(f"completion_type : {LLMMode.CHAT.value}") - - if completion_type == LLMMode.COMPLETION: - print(f"completion_type : {LLMMode.COMPLETION.value}") + completion_type = LLMMode.value_of(credentials["mode"]).value features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index 0b06f54ef1823f..49c3fa5921b6c4 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -1,5 +1,6 @@ import json import logging +import operator from typing import Any, Optional import boto3 @@ -20,34 +21,36 @@ logger = logging.getLogger(__name__) + class SageMakerRerankModel(RerankModel): """ - Model class for Cohere rerank model. + Model class for SageMaker rerank model. 
""" + sagemaker_client: Any = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -63,22 +66,21 @@ def _invoke(self, model: str, credentials: dict, line = 0 try: if len(docs) == 0: - return RerankResult( - model=model, - docs=docs - ) + return RerankResult(model=model, docs=docs) line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -86,22 +88,20 @@ def _invoke(self, model: str, credentials: dict, line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") candidate_docs = [] scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) for idx in range(len(scores)): - candidate_docs.append({"content" : docs[idx], "score": scores[idx]}) + candidate_docs.append({"content": docs[idx], "score": scores[idx]}) - sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True) line = 3 rerank_documents = [] for idx, result in enumerate(candidate_docs): rerank_document = RerankDocument( - index=idx, - text=result.get('content'), - score=result.get('score', -100.0) + index=idx, text=result.get("content"), score=result.get("score", -100.0) ) if score_threshold is not None: @@ -110,13 +110,10 @@ def _invoke(self, model: str, credentials: dict, else: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -137,7 +134,7 @@ def validate_credentials(self, 
     def validate_credentials(self, model: str, credentials: dict) -> None:
                 "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                 "are a political division controlled by the United States. Its capital is Saipan.",
             ],
-            score_threshold=0.8
+            score_threshold=0.8,
         )
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -153,38 +150,24 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
         :return: Invoke error mapping
         """
         return {
-            InvokeConnectionError: [
-                InvokeConnectionError
-            ],
-            InvokeServerUnavailableError: [
-                InvokeServerUnavailableError
-            ],
-            InvokeRateLimitError: [
-                InvokeRateLimitError
-            ],
-            InvokeAuthorizationError: [
-                InvokeAuthorizationError
-            ],
-            InvokeBadRequestError: [
-                InvokeBadRequestError,
-                KeyError,
-                ValueError
-            ]
+            InvokeConnectionError: [InvokeConnectionError],
+            InvokeServerUnavailableError: [InvokeServerUnavailableError],
+            InvokeRateLimitError: [InvokeRateLimitError],
+            InvokeAuthorizationError: [InvokeAuthorizationError],
+            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
         }

-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
         """
-        used to define customizable model schema
+        used to define customizable model schema
         """
         entity = AIModelEntity(
             model=model,
-            label=I18nObject(
-                en_US=model
-            ),
+            label=I18nObject(en_US=model),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.RERANK,
-            model_properties={ },
-            parameter_rules=[]
+            model_properties={},
+            parameter_rules=[],
         )
-        return entity
\ No newline at end of file
+        return entity
diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py
index 02d05f406c50f7..042155b1522fef 100644
--- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py
+++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py
@@ -1,4 +1,6 @@
 import logging
+import uuid
+from typing import IO, Any, Optional

 from core.model_runtime.model_providers.__base.model_provider import ModelProvider

@@ -15,3 +17,25 @@ def validate_provider_credentials(self, credentials: dict) -> None:
         :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
""" pass + + +def buffer_to_s3(s3_client: Any, file: IO[bytes], bucket: str, s3_prefix: str) -> str: + """ + return s3_uri of this file + """ + s3_key = f"{s3_prefix}{uuid.uuid4()}.mp3" + s3_client.put_object(Body=file.read(), Bucket=bucket, Key=s3_key, ContentType="audio/mp3") + return s3_key + + +def generate_presigned_url(s3_client: Any, file: IO[bytes], bucket_name: str, s3_prefix: str, expiration=600) -> str: + object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix) + try: + response = s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket_name, "Key": object_key}, ExpiresIn=expiration + ) + except Exception as e: + print(f"Error generating presigned URL: {e}") + return None + + return response diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml index 290cb0edabee09..87cd50f50cbb85 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -21,6 +21,8 @@ supported_model_types: - llm - text-embedding - rerank + - speech2text + - tts configurate_methods: - customizable-model model_credential_schema: @@ -45,14 +47,10 @@ model_credential_schema: zh_Hans: 选择对话类型 en_US: Select completion mode options: - - value: completion - label: - en_US: Completion - zh_Hans: 补全 - value: chat label: en_US: Chat - zh_Hans: 对话 + zh_Hans: Chat - variable: sagemaker_endpoint label: en_US: sagemaker endpoint @@ -61,6 +59,76 @@ model_credential_schema: placeholder: zh_Hans: 请输出你的Sagemaker推理端点 en_US: Enter your Sagemaker Inference endpoint + - variable: audio_s3_cache_bucket + show_on: + - variable: __model_type + value: speech2text + label: + zh_Hans: 音频缓存桶(s3 bucket) + en_US: audio cache bucket(s3 bucket) + type: text-input + required: true + placeholder: + zh_Hans: sagemaker-us-east-1-******207838 + en_US: sagemaker-us-east-1-*******7838 + - variable: audio_model_type + show_on: + - variable: __model_type + value: tts + label: + en_US: Audio model type + type: select + required: true + placeholder: + zh_Hans: 语音模型类型 + en_US: Audio model type + options: + - value: PresetVoice + label: + en_US: preset voice + zh_Hans: 内置音色 + - value: CloneVoice + label: + en_US: clone voice + zh_Hans: 克隆音色 + - value: CloneVoice_CrossLingual + label: + en_US: crosslingual clone voice + zh_Hans: 跨语种克隆音色 + - value: InstructVoice + label: + en_US: Instruct voice + zh_Hans: 文字指令音色 + - variable: prompt_audio + show_on: + - variable: __model_type + value: tts + label: + en_US: Mock Audio Source + type: text-input + required: false + placeholder: + zh_Hans: 被模仿的音色音频 + en_US: source audio to be mocked + - variable: prompt_text + show_on: + - variable: __model_type + value: tts + label: + en_US: Prompt Audio Text + type: text-input + required: false + placeholder: + zh_Hans: 模仿音色的对应文本 + en_US: text for the mocked source audio + - variable: instruct_text + show_on: + - variable: __model_type + value: tts + label: + en_US: instruct text for speaker + type: text-input + required: false - variable: aws_access_key_id required: false label: diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/__init__.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py new file mode 
diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py
new file mode 100644
index 00000000000000..8fdf68abe189a7
--- /dev/null
+++ b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py
@@ -0,0 +1,127 @@
+import json
+import logging
+from typing import IO, Any, Optional
+
+import boto3
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url
+
+logger = logging.getLogger(__name__)
+
+
+class SageMakerSpeech2TextModel(Speech2TextModel):
+    """
+    Model class for SageMaker speech to text model.
+    """
+
+    sagemaker_client: Any = None
+    s3_client: Any = None
+
+    def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        asr_text = None
+
+        try:
+            if not self.sagemaker_client:
+                access_key = credentials.get("aws_access_key_id")
+                secret_key = credentials.get("aws_secret_access_key")
+                aws_region = credentials.get("aws_region")
+                if aws_region:
+                    if access_key and secret_key:
+                        self.sagemaker_client = boto3.client(
+                            "sagemaker-runtime",
+                            aws_access_key_id=access_key,
+                            aws_secret_access_key=secret_key,
+                            region_name=aws_region,
+                        )
+                        self.s3_client = boto3.client(
+                            "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region
+                        )
+                    else:
+                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
+                        self.s3_client = boto3.client("s3", region_name=aws_region)
+                else:
+                    self.sagemaker_client = boto3.client("sagemaker-runtime")
+                    self.s3_client = boto3.client("s3")
+
+            s3_prefix = "dify/speech2text/"
+            sagemaker_endpoint = credentials.get("sagemaker_endpoint")
+            bucket = credentials.get("audio_s3_cache_bucket")
+
+            s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
+            payload = {"audio_s3_presign_uri": s3_presign_url}
+
+            response_model = self.sagemaker_client.invoke_endpoint(
+                EndpointName=sagemaker_endpoint, Body=json.dumps(payload), ContentType="application/json"
+            )
+            json_str = response_model["Body"].read().decode("utf8")
+            json_obj = json.loads(json_str)
+            asr_text = json_obj["text"]
+        except Exception as e:
+            logger.exception(f"failed to invoke speech2text model, {e}")
+            raise CredentialsValidateFailedError(str(e))
+
+        return asr_text
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        pass
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [InvokeConnectionError],
+            InvokeServerUnavailableError: [InvokeServerUnavailableError],
+            InvokeRateLimitError: [InvokeRateLimitError],
+            InvokeAuthorizationError: [InvokeAuthorizationError],
+            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+        """
+        used to define customizable model schema
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.SPEECH2TEXT,
+            model_properties={},
+            parameter_rules=[],
+        )
+
+        return entity
diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
index 4b2858b1a28228..ececfda11a55da 100644
--- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py
@@ -6,25 +6,27 @@
 import boto3

+from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (
-    InvokeAuthorizationError,
-    InvokeBadRequestError,
-    InvokeConnectionError,
-    InvokeError,
-    InvokeRateLimitError,
-    InvokeServerUnavailableError,
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

 BATCH_SIZE = 20
-CONTEXT_SIZE=8192
+CONTEXT_SIZE = 8192

 logger = logging.getLogger(__name__)

+
 def batch_generator(generator, batch_size):
     while True:
         batch = list(itertools.islice(generator, batch_size))
@@ -32,33 +34,33 @@
             break
         yield batch

+
 class SageMakerEmbeddingModel(TextEmbeddingModel):
     """
-    Model class for Cohere text embedding model.
+    Model class for SageMaker text embedding model.
""" + sagemaker_client: Any = None - def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]): + def _sagemaker_embedding(self, sm_client, endpoint_name, content_list: list[str]): response_model = sm_client.invoke_endpoint( EndpointName=endpoint_name, - Body=json.dumps( - { - "inputs": content_list, - "parameters": {}, - "is_query" : False, - "instruction" : '' - } - ), + Body=json.dumps({"inputs": content_list, "parameters": {}, "is_query": False, "instruction": ""}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - embeddings = json_obj['embeddings'] + embeddings = json_obj["embeddings"] return embeddings - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -66,31 +68,34 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ # get model properties try: line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id') - secret_key = credentials.get('aws_secret_access_key') - aws_region = credentials.get('aws_region') + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client("sagemaker-runtime", + self.sagemaker_client = boto3.client( + "sagemaker-runtime", aws_access_key_id=access_key, aws_secret_access_key=secret_key, - region_name=aws_region) + region_name=aws_region, + ) else: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: self.sagemaker_client = boto3.client("sagemaker-runtime") line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint') + sagemaker_endpoint = credentials.get("sagemaker_endpoint") line = 3 - truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ] + truncated_texts = [item[:CONTEXT_SIZE] for item in texts] batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE) all_embeddings = [] @@ -105,18 +110,14 @@ def _invoke(self, model: str, credentials: dict, usage = self._calc_response_usage( model=model, credentials=credentials, - tokens=0 # It's not SAAS API, usage is meaningless + tokens=0, # It's not SAAS API, usage is meaningless ) line = 6 - return TextEmbeddingResult( - embeddings=all_embeddings, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=all_embeddings, usage=usage, model=model) except Exception as e: - logger.exception(f'Exception {e}, line : {line}') + logger.exception(f"Exception {e}, line : {line}") def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -153,10 +154,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -167,7 
+165,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -175,40 +173,28 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE, ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/__init__.py b/api/core/model_runtime/model_providers/sagemaker/tts/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/sagemaker/tts/tts.py b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py new file mode 100644 index 00000000000000..6a5946453be07f --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/tts/tts.py @@ -0,0 +1,275 @@ +import concurrent.futures +import copy +import json +import logging +from enum import Enum +from typing import Any, Optional + +import boto3 +import requests + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.__base.tts_model import TTSModel + +logger = logging.getLogger(__name__) + + +class TTSModelType(Enum): + PresetVoice = "PresetVoice" + CloneVoice = "CloneVoice" + CloneVoice_CrossLingual = "CloneVoice_CrossLingual" + InstructVoice = "InstructVoice" + + +class SageMakerText2SpeechModel(TTSModel): + sagemaker_client: Any = None + s3_client: Any = None + comprehend_client: Any = None + + def __init__(self): + # preset voices, need support custom voice + self.model_voices = { + "__default": { + "all": [ + {"name": "Default", "value": "default"}, + ] + }, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, + ], + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, 
+ {"name": "粤语女", "value": "粤语女"}, + ], + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, + ], + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, + ], + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, + } + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + pass + + def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} + + response = self.comprehend_client.detect_dominant_language(Text=content) + language_code = response["Languages"][0]["LanguageCode"] + + return map_dict.get(language_code, "<|zh|>") + + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): + if model_type == TTSModelType.PresetVoice.value and model_role: + return {"tts_text": content_text, "role": model_role} + if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + lang_tag = self._detect_lang_code(content_text) + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} + + raise RuntimeError(f"Invalid params for {model_type}") + + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param voice: model timbre + :param content_text: text content to be translated + :param user: unique user id + :return: text translated to audio file + """ + if not self.sagemaker_client: + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") + aws_region = credentials.get("aws_region") + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client( + "sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region, + ) + self.s3_client = boto3.client( + "s3", aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region + ) + self.comprehend_client = boto3.client( + "comprehend", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region, + ) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + self.s3_client = boto3.client("s3") + self.comprehend_client = boto3.client("comprehend") + + model_type = credentials.get("audio_model_type", "PresetVoice") + prompt_text = credentials.get("prompt_text") + prompt_audio = credentials.get("prompt_audio") + instruct_text = credentials.get("instruct_text") + sagemaker_endpoint = credentials.get("sagemaker_endpoint") + payload = 
self._build_tts_payload(model_type, content_text, voice, prompt_text, prompt_audio, instruct_text) + + return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TTS, + model_properties={}, + parameter_rules=[], + ) + + return entity + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], + } + + def _get_model_default_voice(self, model: str, credentials: dict) -> Any: + return "" + + def _get_model_word_limit(self, model: str, credentials: dict) -> int: + return 15 + + def _get_model_audio_type(self, model: str, credentials: dict) -> str: + return "mp3" + + def _get_model_workers_limit(self, model: str, credentials: dict) -> int: + return 5 + + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + audio_model_name = "CosyVoice" + for key, voices in self.model_voices.items(): + if key in audio_model_name: + if language and language in voices: + return voices[language] + elif "all" in voices: + return voices["all"] + + return self.model_voices["__default"]["all"] + + def _invoke_sagemaker(self, payload: dict, endpoint: str): + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=endpoint, + Body=json.dumps(payload), + ContentType="application/json", + ) + json_str = response_model["Body"].read().decode("utf8") + json_obj = json.loads(json_str) + return json_obj + + def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> Any: + """ + _tts_invoke_streaming text2speech model + + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + try: + lang_tag = "" + if model_type == TTSModelType.CloneVoice_CrossLingual.value: + lang_tag = payload.pop("lang_tag") + + word_limit = self._get_model_word_limit(model="", credentials={}) + content_text = payload.get("tts_text") + if len(content_text) > word_limit: + split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit) + sentences = [f"{lang_tag}{s}" for s in split_sentences if len(s)] + len_sent = len(sentences) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent)) + payloads = [copy.deepcopy(payload) for i in range(len_sent)] + for idx in range(len_sent): + payloads[idx]["tts_text"] = sentences[idx] + + futures = [ + executor.submit( + self._invoke_sagemaker, + payload=payload, + endpoint=sagemaker_endpoint, + ) + for payload in payloads + ] + + for future in futures: + resp = future.result() + audio_bytes = requests.get(resp.get("s3_presign_url")).content + 
for i in range(0, len(audio_bytes), 1024): + yield audio_bytes[i : i + 1024] + else: + resp = self._invoke_sagemaker(payload, sagemaker_endpoint) + audio_bytes = requests.get(resp.get("s3_presign_url")).content + + for i in range(0, len(audio_bytes), 1024): + yield audio_bytes[i : i + 1024] + except Exception as ex: + raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml index c2f0eb05360327..8d1df82140b79f 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml @@ -1,20 +1,28 @@ +- Qwen/Qwen2.5-72B-Instruct +- Qwen/Qwen2.5-32B-Instruct +- Qwen/Qwen2.5-14B-Instruct +- Qwen/Qwen2.5-7B-Instruct +- Qwen/Qwen2.5-Coder-7B-Instruct +- Qwen/Qwen2.5-Math-72B-Instruct - Qwen/Qwen2-72B-Instruct - Qwen/Qwen2-57B-A14B-Instruct - Qwen/Qwen2-7B-Instruct - Qwen/Qwen2-1.5B-Instruct -- 01-ai/Yi-1.5-34B-Chat -- 01-ai/Yi-1.5-9B-Chat-16K -- 01-ai/Yi-1.5-6B-Chat -- THUDM/glm-4-9b-chat +- deepseek-ai/DeepSeek-V2.5 - deepseek-ai/DeepSeek-V2-Chat - deepseek-ai/DeepSeek-Coder-V2-Instruct +- THUDM/glm-4-9b-chat +- 01-ai/Yi-1.5-34B-Chat-16K +- 01-ai/Yi-1.5-9B-Chat-16K +- 01-ai/Yi-1.5-6B-Chat +- internlm/internlm2_5-20b-chat - internlm/internlm2_5-7b-chat -- google/gemma-2-27b-it -- google/gemma-2-9b-it -- meta-llama/Meta-Llama-3-70B-Instruct -- meta-llama/Meta-Llama-3-8B-Instruct - meta-llama/Meta-Llama-3.1-405B-Instruct - meta-llama/Meta-Llama-3.1-70B-Instruct - meta-llama/Meta-Llama-3.1-8B-Instruct -- mistralai/Mixtral-8x7B-Instruct-v0.1 +- meta-llama/Meta-Llama-3-70B-Instruct +- meta-llama/Meta-Llama-3-8B-Instruct +- google/gemma-2-27b-it +- google/gemma-2-9b-it - mistralai/Mistral-7B-Instruct-v0.2 +- mistralai/Mixtral-8x7B-Instruct-v0.1 diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml index d4431179e5656d..d5f23776ea2672 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml index caa6508b5ed2a2..7aa684ef3813ad 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml new file mode 100644 index 00000000000000..b30fa3e2d159c0 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml @@ -0,0 +1,39 @@ +model: deepseek-ai/DeepSeek-V2.5 +label: + en_US: deepseek-ai/DeepSeek-V2.5 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '1.33' + output: '1.33' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml index 2840e3dcf4b113..f2a1f64bfb524b 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml index d7e19b46f6d6f2..b096b9b647c62e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml index 9b32a024774d06..87acc557b79dba 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. 
If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml new file mode 100644 index 00000000000000..60157c2b46ab89 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml @@ -0,0 +1,39 @@ +model: internlm/internlm2_5-20b-chat +label: + en_US: internlm/internlm2_5-20b-chat +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 4096 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '1' + output: '1' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml index 73ad4480aa2968..faf4af7ea3f7e4 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index a9ce7b98c35a39..6015442c2b2405 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -1,17 +1,33 @@ from collections.abc import Generator from typing import Optional, Union -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +37,55 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model, zh_Hans=model), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + default=512, + min=1, + max=int(credentials.get("max_tokens", 1024)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), + type=ParameterType.INT, + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="top_k", + use_template="top_k", + label=I18nObject(en_US="Top K", zh_Hans="Top K"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="frequency_penalty", + 
use_template="frequency_penalty", + label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"), + type=ParameterType.FLOAT, + ), + ], + ) diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml index 9993d781ac8959..d01770cb0106bb 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml index 60e3764789e1f5..3cd75d89e8bc84 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml index f992660aa2e66f..3506a70bccf9ce 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml index 1c69d63a400219..994a754a82b83c 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. 
- name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml index a97002a5ca3658..ebfa9aac9d9de0 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml index 27664eab6c817a..a71d8688a87962 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: @@ -28,3 +37,4 @@ pricing: output: '0' unit: '0.000001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml index fd7aada42848aa..db45a75c6de019 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
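# A note on the deprecated: true flag added to the two Mistral files in this hunk (an assumption
# drawn from the field name, not from code shown in this patch): the schema stays resolvable for
# apps already configured with the model, but the model is no longer offered for new setups, e.g.:
#
#   model: mistral-7b-instruct-v0.2
#   deprecated: true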
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: @@ -28,3 +37,4 @@ pricing: output: '1.26' unit: '0.000001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml index f6c976af8e7b6e..bec5d37c57b31b 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml index a996e919ea9f27..b2461335f84d6e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml index a6e2c22dac87c0..e0f23bd89eb0e3 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml index d8bea5e12927e7..47a9da8119cf42 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. 
- name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml new file mode 100644 index 00000000000000..9cc5ac4c91d653 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml @@ -0,0 +1,39 @@ +model: Qwen/Qwen2.5-14B-Instruct +label: + en_US: Qwen/Qwen2.5-14B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 8192 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0.7' + output: '0.7' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml new file mode 100644 index 00000000000000..c7fb21e9e10c28 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml @@ -0,0 +1,39 @@ +model: Qwen/Qwen2.5-32B-Instruct +label: + en_US: Qwen/Qwen2.5-32B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 8192 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
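# Reading the pricing blocks in these new Qwen2.5 files (worked example for
# Qwen/Qwen2.5-14B-Instruct above): price per token = price field x unit, so
# 0.7 x 0.000001 = 0.0000007 RMB per token, i.e. 0.7 RMB per million tokens,
# applied to input and output alike.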
+ required: false + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '1.26' + output: '1.26' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml new file mode 100644 index 00000000000000..03136c88a11fb1 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml @@ -0,0 +1,39 @@ +model: Qwen/Qwen2.5-72B-Instruct +label: + en_US: Qwen/Qwen2.5-72B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 8192 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '4.13' + output: '4.13' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml new file mode 100644 index 00000000000000..99412adde7a209 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml @@ -0,0 +1,39 @@ +model: Qwen/Qwen2.5-7B-Instruct +label: + en_US: Qwen/Qwen2.5-7B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: max_tokens + use_template: max_tokens + type: int + default: 512 + min: 1 + max: 8192 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: frequency_penalty + use_template: frequency_penalty +pricing: + input: '0' + output: '0' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-coder-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-coder-7b-instruct.yaml new file mode 100644 index 00000000000000..76526200ccdccc --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-coder-7b-instruct.yaml @@ -0,0 +1,74 @@ +model: Qwen/Qwen2.5-Coder-7B-Instruct +label: + en_US: Qwen/Qwen2.5-Coder-7B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating; the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234.
When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control repetition in the model's output. Increasing repetition_penalty reduces repetition in generation. 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0' + output: '0' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-math-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-math-72b-instruct.yaml new file mode 100644 index 00000000000000..90afa0cfd5b96a --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-math-72b-instruct.yaml @@ -0,0 +1,74 @@ +model: Qwen/Qwen2.5-Math-72B-Instruct +label: + en_US: Qwen/Qwen2.5-Math-72B-Instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set.
The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating; the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control repetition in the model's output. Increasing repetition_penalty reduces repetition in generation. 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '4.13' + output: '4.13' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml index 864ba46f1adfdf..3e25f82369f070 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml index fe4c8b4b3e0350..827b2ce1e5d765 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml index c61f0dc53fe6ec..112fcbfe97be83 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
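# General top_k semantics behind the rule being added to the Yi models here (standard sampling
# behavior, not specific to this patch): top_k: 50 keeps only the 50 highest-probability candidate
# tokens before sampling, top_k: 1 degenerates to greedy decoding, and leaving the field unset
# falls back to the provider's default.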
- name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py index 683591581638e4..58b033d28aa90e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py @@ -16,39 +16,38 @@ class SiliconflowRerankModel(RerankModel): - - def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], - score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: if len(docs) == 0: return RerankResult(model=model, docs=[]) - base_url = credentials.get('base_url', 'https://api.siliconflow.cn/v1') - if base_url.endswith('/'): - base_url = base_url[:-1] + base_url = credentials.get("base_url", "https://api.siliconflow.cn/v1") + base_url = base_url.removesuffix("/") try: response = httpx.post( - base_url + '/rerank', - json={ - "model": model, - "query": query, - "documents": docs, - "top_n": top_n, - "return_documents": True - }, - headers={"Authorization": f"Bearer {credentials.get('api_key')}"} + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_n": top_n, "return_documents": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}"}, ) response.raise_for_status() results = response.json() rerank_documents = [] - for result in results['results']: + for result in results["results"]: rerank_document = RerankDocument( - index=result['index'], - text=result['document']['text'], - score=result['relevance_score'], + index=result["index"], + text=result["document"]["text"], + score=result["relevance_score"], ) - if score_threshold is None or result['relevance_score'] >= score_threshold: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) @@ -57,7 +56,6 @@ def _invoke(self, model: str, credentials: dict, query: str, docs: list[str], def validate_credentials(self, model: str, credentials: dict) -> None: try: - self._invoke( model=model, credentials=credentials, @@ -68,7 +66,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " "are a political division controlled by the United States. 
Its capital is Saipan.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,5 +81,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [httpx.RemoteProtocolError], InvokeRateLimitError: [], InvokeAuthorizationError: [httpx.HTTPStatusError], - InvokeBadRequestError: [httpx.RequestError] - } \ No newline at end of file + InvokeBadRequestError: [httpx.RequestError], + } diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py index dd0eea362a5f83..e121ab8c7e4e2f 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.py +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.py @@ -8,7 +8,6 @@ class SiliconflowProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='deepseek-ai/DeepSeek-V2-Chat', - credentials=credentials - ) + model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml index c46a891604c480..71f9a9238145c0 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml @@ -20,6 +20,7 @@ supported_model_types: - speech2text configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: api_key @@ -30,3 +31,57 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + show_on: + - variable: __model_type + value: llm + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not Supported + zh_Hans: 不支持 + - value: function_call + label: + en_US: Supported + zh_Hans: 支持 + show_on: + - variable: __model_type + value: llm diff --git a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py index 6ad3cab5873c69..8d1932863e09d9 100644 ---
a/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py @@ -8,9 +8,7 @@ class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel): Model class for Siliconflow Speech to text model. """ - def _invoke( - self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None - ) -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c58765cecb9a69..5e29a4827a39f8 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -1,5 +1,6 @@ from typing import Optional +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( OAICompatEmbeddingModel, @@ -10,20 +11,36 @@ class SiliconflowTextEmbeddingModel(OAICompatEmbeddingModel): """ Model class for Siliconflow text embedding model. """ + def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ self._add_custom_parameters(credentials) return super()._invoke(model, credentials, texts, user) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: self._add_custom_parameters(credentials) return super().get_num_tokens(model, credentials, texts) - + @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1' \ No newline at end of file + credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" diff --git a/api/core/model_runtime/model_providers/spark/llm/_client.py b/api/core/model_runtime/model_providers/spark/llm/_client.py index 10da265701a423..48911f657a52e3 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_client.py +++ b/api/core/model_runtime/model_providers/spark/llm/_client.py @@ -15,51 +15,36 @@ class SparkLLMClient: def __init__(self, model: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - domain = 'spark-api.xf-yun.com' - endpoint = 'chat' + domain = "spark-api.xf-yun.com" + endpoint = "chat" if api_domain: domain = api_domain - if model == 'spark-v3': - endpoint = 'multimodal' model_api_configs = { - 'spark-1.5': { - 'version': 'v1.1', - 'chat_domain': 'general' - }, - 'spark-2': { - 'version': 'v2.1', - 'chat_domain': 'generalv2' - }, - 'spark-3': { - 'version': 'v3.1', - 'chat_domain': 'generalv3' - }, - 'spark-3.5': { - 
'version': 'v3.5', - 'chat_domain': 'generalv3.5' - }, - 'spark-4': { - 'version': 'v4.0', - 'chat_domain': '4.0Ultra' - } + "spark-lite": {"version": "v1.1", "chat_domain": "general"}, + "spark-pro": {"version": "v3.1", "chat_domain": "generalv3"}, + "spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"}, + "spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"}, + "spark-max-32k": {"version": "max-32k", "chat_domain": "max-32k"}, + "spark-4.0-ultra": {"version": "v4.0", "chat_domain": "4.0Ultra"}, } - api_version = model_api_configs[model]['version'] + api_version = model_api_configs[model]["version"] + + self.chat_domain = model_api_configs[model]["chat_domain"] + + if model in ["spark-pro-128k", "spark-max-32k"]: + self.api_base = f"wss://{domain}/{endpoint}/{api_version}" + else: + self.api_base = f"wss://{domain}/{api_version}/{endpoint}" - self.chat_domain = model_api_configs[model]['chat_domain'] - self.api_base = f"wss://{domain}/{api_version}/{endpoint}" self.app_id = app_id self.ws_url = self.create_url( - urlparse(self.api_base).netloc, - urlparse(self.api_base).path, - self.api_base, - api_key, - api_secret + urlparse(self.api_base).netloc, urlparse(self.api_base).path, self.api_base, api_key, api_secret ) self.queue = queue.Queue() - self.blocking_message = '' + self.blocking_message = "" def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: # generate timestamp by RFC1123 @@ -71,33 +56,32 @@ def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secr signature_origin += "GET " + path + " HTTP/1.1" # encrypt using hmac-sha256 - signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() + signature_sha = hmac.new( + api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256 + ).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") - authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' + authorization_origin = ( + f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line",' + f' signature="{signature_sha_base64}"' + ) - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") - v = { - "authorization": authorization, - "date": date, - "host": host - } + v = {"authorization": authorization, "date": date, "host": host} # generate url - url = api_base + '?' + urlencode(v) + url = api_base + "?" 
+ urlencode(v) return url - def run(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None, streaming: bool = False): + def run(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None, streaming: bool = False): websocket.enableTrace(False) ws = websocket.WebSocketApp( self.ws_url, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, - on_open=self.on_open + on_open=self.on_open, ) ws.messages = messages ws.user_id = user_id @@ -106,86 +90,71 @@ def run(self, messages: list, user_id: str, ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) def on_error(self, ws, error): - self.queue.put({ - 'status_code': error.status_code, - 'error': error.resp_body.decode('utf-8') - }) + self.queue.put({"status_code": error.status_code, "error": error.resp_body.decode("utf-8")}) ws.close() def on_close(self, ws, close_status_code, close_reason): - self.queue.put({'done': True}) + self.queue.put({"done": True}) def on_open(self, ws): - self.blocking_message = '' - data = json.dumps(self.gen_params( - messages=ws.messages, - user_id=ws.user_id, - model_kwargs=ws.model_kwargs - )) + self.blocking_message = "" + data = json.dumps(self.gen_params(messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs)) ws.send(data) def on_message(self, ws, message): data = json.loads(message) - code = data['header']['code'] + code = data["header"]["code"] if code != 0: - self.queue.put({ - 'status_code': 400, - 'error': f"Code: {code}, Error: {data['header']['message']}" - }) + self.queue.put({"status_code": 400, "error": f"Code: {code}, Error: {data['header']['message']}"}) ws.close() else: choices = data["payload"]["choices"] status = choices["status"] content = choices["text"][0]["content"] if ws.streaming: - self.queue.put({'data': content}) + self.queue.put({"data": content}) else: self.blocking_message += content if status == 2: if not ws.streaming: - self.queue.put({'data': self.blocking_message}) + self.queue.put({"data": self.blocking_message}) ws.close() - def gen_params(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None) -> dict: + def gen_params(self, messages: list, user_id: str, model_kwargs: Optional[dict] = None) -> dict: data = { "header": { "app_id": self.app_id, # resolve this error message => $.header.uid' length must be less or equal than 32 - "uid": user_id[:32] if user_id else None - }, - "parameter": { - "chat": { - "domain": self.chat_domain - } + "uid": user_id[:32] if user_id else None, }, - "payload": { - "message": { - "text": messages - } - } + "parameter": {"chat": {"domain": self.chat_domain}}, + "payload": {"message": {"text": messages}}, } if model_kwargs: - data['parameter']['chat'].update(model_kwargs) + data["parameter"]["chat"].update(model_kwargs) return data def subscribe(self): while True: content = self.queue.get() - if 'error' in content: - if content['status_code'] == 401: - raise SparkError('[Spark] The credentials you provided are incorrect. ' - 'Please double-check and fill them in again.') - elif content['status_code'] == 403: - raise SparkError("[Spark] Sorry, the credentials you provided are access denied. " - "Please try again after obtaining the necessary permissions.") + if "error" in content: + if content["status_code"] == 401: + raise SparkError( + "[Spark] The credentials you provided are incorrect. " + "Please double-check and fill them in again." + ) + elif content["status_code"] == 403: + raise SparkError( + "[Spark] Sorry, the credentials you provided are access denied. 
" + "Please try again after obtaining the necessary permissions." + ) else: raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") - if 'data' not in content: + if "data" not in content: break yield content diff --git a/api/core/model_runtime/model_providers/spark/llm/_position.yaml b/api/core/model_runtime/model_providers/spark/llm/_position.yaml index e49ee97db7cf56..73f39cb1197b48 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/_position.yaml @@ -1,3 +1,9 @@ +- spark-max-32k +- spark-4.0-ultra +- spark-max +- spark-pro-128k +- spark-pro +- spark-lite - spark-4 - spark-3.5 - spark-3 diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 65beae517c72e9..1181ba699af886 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -25,12 +25,17 @@ class SparkLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -47,8 +52,13 @@ def _invoke(self, model: str, credentials: dict, # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -80,15 +90,21 @@ def validate_credentials(self, model: str, credentials: dict) -> None: model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -103,7 +119,7 @@ def _generate(self, model: str, credentials: dict, """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -113,21 +129,33 @@ def _generate(self, model: str, credentials: dict, **credentials_kwargs, ) - thread = threading.Thread(target=client.run, args=( - [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for 
prompt_message in prompt_messages], - user, - model_parameters, - stream - )) + thread = threading.Thread( + target=client.run, + args=( + [ + {"role": prompt_message.role.value, "content": prompt_message.content} + for prompt_message in prompt_messages + ], + user, + model_parameters, + stream, + ), + ) thread.start() if stream: return self._handle_generate_stream_response(thread, model, credentials, client, prompt_messages) return self._handle_generate_response(thread, model, credentials, client, prompt_messages) - - def _handle_generate_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> LLMResult: + + def _handle_generate_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -140,7 +168,7 @@ def _handle_generate_response(self, thread: threading.Thread, model: str, creden for content in client.subscribe(): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content @@ -148,9 +176,7 @@ def _handle_generate_response(self, thread: threading.Thread, model: str, creden thread.join() # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=completion - ) + assistant_prompt_message = AssistantPromptMessage(content=completion) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -168,9 +194,15 @@ def _handle_generate_response(self, thread: threading.Thread, model: str, creden ) return result - - def _handle_generate_stream_response(self, thread: threading.Thread, model: str, credentials: dict, client: SparkLLMClient, - prompt_messages: list[PromptMessage]) -> Generator: + + def _handle_generate_stream_response( + self, + thread: threading.Thread, + model: str, + credentials: dict, + client: SparkLLMClient, + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -181,29 +213,28 @@ def _handle_generate_stream_response(self, thread: threading.Thread, model: str, :param prompt_messages: prompt messages :return: llm response chunk generator result """ + completion = "" for index, content in enumerate(client.subscribe()): if isinstance(content, dict): - delta = content['data'] + delta = content["data"] else: delta = content - + completion += delta assistant_prompt_message = AssistantPromptMessage( - content=delta if delta else '', + content=delta or "", + ) + temp_assistant_prompt_message = AssistantPromptMessage( + content=completion, ) - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + completion_tokens = self.get_num_tokens(model, credentials, [temp_assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message, usage=usage), ) thread.join() @@ -216,9 +247,9 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "app_id": credentials['app_id'], - "api_secret": credentials['api_secret'], - "api_key": credentials['api_key'], + "app_id": 
credentials["app_id"], + "api_secret": credentials["api_secret"], + "api_key": credentials["api_key"], } return credentials_kwargs @@ -244,7 +275,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: raise ValueError(f"Got unknown type {message}") return message_text - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model @@ -254,10 +285,7 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() @@ -277,5 +305,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-1.5.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-1.5.yaml index 41b8765fe6c4f1..fcd65c24e0f60c 100644 --- a/api/core/model_runtime/model_providers/spark/llm/spark-1.5.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/spark-1.5.yaml @@ -1,4 +1,5 @@ model: spark-1.5 +deprecated: true label: en_US: Spark V1.5 model_type: llm diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-3.5.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-3.5.yaml index 6d24932ea83076..86617a53d0d4fb 100644 --- a/api/core/model_runtime/model_providers/spark/llm/spark-3.5.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/spark-3.5.yaml @@ -1,4 +1,5 @@ model: spark-3.5 +deprecated: true label: en_US: Spark V3.5 model_type: llm diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-3.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-3.yaml index 2ef9e10f453f6b..9f296c684db78d 100644 --- a/api/core/model_runtime/model_providers/spark/llm/spark-3.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/spark-3.yaml @@ -1,4 +1,5 @@ model: spark-3 +deprecated: true label: en_US: Spark V3.0 model_type: llm diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-4.0-ultra.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-4.0-ultra.yaml new file mode 100644 index 00000000000000..bbf85764f1c8c1 --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-4.0-ultra.yaml @@ -0,0 +1,42 @@ +model: spark-4.0-ultra +label: + en_US: Spark 4.0 Ultra +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Kernel sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 8192 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. 
+ - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). + required: false + - name: show_ref_label + label: + zh_Hans: 联网检索 + en_US: Web search + type: boolean + default: false + help: + zh_Hans: 该参数仅4.0 Ultra版本支持,当设置为true时,如果输入内容触发联网检索插件,会先返回检索信源列表,然后再返回星火回复结果,否则仅返回星火回复结果 + en_US: The parameter is only supported in the 4.0 Ultra version. When set to true, if the input triggers the online search plugin, it will first return a list of search sources and then return the Spark response. Otherwise, it will only return the Spark response. diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-4.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-4.yaml index 4b0bf27029ff76..4b5529e81c0602 100644 --- a/api/core/model_runtime/model_providers/spark/llm/spark-4.yaml +++ b/api/core/model_runtime/model_providers/spark/llm/spark-4.yaml @@ -1,4 +1,5 @@ model: spark-4 +deprecated: true label: en_US: Spark V4.0 model_type: llm diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-lite.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-lite.yaml new file mode 100644 index 00000000000000..1f6141a816b8e1 --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-lite.yaml @@ -0,0 +1,33 @@ +model: spark-lite +label: + en_US: Spark Lite +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Nucleus sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). + required: false diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-max-32k.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-max-32k.yaml new file mode 100644 index 00000000000000..1a1ab6844c69c5 --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-max-32k.yaml @@ -0,0 +1,33 @@ +model: spark-max-32k +label: + en_US: Spark Max-32K +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Nucleus sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 8192 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability).
+ required: false diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-max.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-max.yaml new file mode 100644 index 00000000000000..71eb2b86d36ac4 --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-max.yaml @@ -0,0 +1,33 @@ +model: spark-max +label: + en_US: Spark Max +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Nucleus sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 8192 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). + required: false diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-pro-128k.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-pro-128k.yaml new file mode 100644 index 00000000000000..da1fead6da940d --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-pro-128k.yaml @@ -0,0 +1,33 @@ +model: spark-pro-128k +label: + en_US: Spark Pro-128K +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Nucleus sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). + required: false diff --git a/api/core/model_runtime/model_providers/spark/llm/spark-pro.yaml b/api/core/model_runtime/model_providers/spark/llm/spark-pro.yaml new file mode 100644 index 00000000000000..9ee479f15b0504 --- /dev/null +++ b/api/core/model_runtime/model_providers/spark/llm/spark-pro.yaml @@ -0,0 +1,33 @@ +model: spark-pro +label: + en_US: Spark Pro +model_type: llm +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.5 + help: + zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。 + en_US: Nucleus sampling threshold. Used to determine the randomness of the results. The higher the value, the stronger the randomness, that is, the higher the possibility of getting different answers to the same question. + - name: max_tokens + use_template: max_tokens + default: 4096 + min: 1 + max: 8192 + help: + zh_Hans: 模型回答的tokens的最大长度。 + en_US: Maximum length of tokens for the model response.
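# Websocket endpoints implied by the SparkLLMClient changes in this patch (assuming the default
# domain spark-api.xf-yun.com; pro-128k and max-32k use the swapped endpoint/version URL shape):
#
#   spark-pro      -> wss://spark-api.xf-yun.com/v3.1/chat       (chat domain generalv3)
#   spark-pro-128k -> wss://spark-api.xf-yun.com/chat/pro-128k   (chat domain pro-128k)
#   spark-max      -> wss://spark-api.xf-yun.com/v3.5/chat       (chat domain generalv3.5)
#   spark-max-32k  -> wss://spark-api.xf-yun.com/chat/max-32k    (chat domain max-32k)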
+ - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + default: 4 + min: 1 + max: 6 + help: + zh_Hans: 从 k 个候选中随机选择一个(非等概率)。 + en_US: Randomly select one from k candidates (non-equal probability). + required: false diff --git a/api/core/model_runtime/model_providers/stepfun/llm/llm.py b/api/core/model_runtime/model_providers/stepfun/llm/llm.py index 6f6ffc8faa9be3..43b91a1aec9c57 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/llm.py +++ b/api/core/model_runtime/model_providers/stepfun/llm/llm.py @@ -30,11 +30,17 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) self._add_function_call(model, credentials) user = user[:32] if user else None @@ -44,56 +50,56 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: return AIModelEntity( model=model, label=I18nObject(en_US=model, zh_Hans=model), model_type=ModelType.LLM, - features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] - if credentials.get('function_calling_type') == 'tool_call' - else [], + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), ModelPropertyKey.MODE: LLMMode.CHAT.value, }, parameter_rules=[ ParameterRule( - name='temperature', - use_template='temperature', - label=I18nObject(en_US='Temperature', zh_Hans='温度'), + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), type=ParameterType.FLOAT, ), ParameterRule( - name='max_tokens', - use_template='max_tokens', + name="max_tokens", + use_template="max_tokens", default=512, min=1, - max=int(credentials.get('max_tokens', 1024)), - label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'), + max=int(credentials.get("max_tokens", 1024)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), type=ParameterType.INT, ), ParameterRule( - name='top_p', - use_template='top_p', - label=I18nObject(en_US='Top P', zh_Hans='Top P'), + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), type=ParameterType.FLOAT, ), - ] + ], ) def _add_custom_parameters(self, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.stepfun.com/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.stepfun.com/v1" def 
_add_function_call(self, model: str, credentials: dict) -> None: model_schema = self.get_model_schema(model, credentials) - if model_schema and { - ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL - }.intersection(model_schema.features or []): - credentials['function_calling_type'] = 'tool_call' + if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection( + model_schema.features or [] + ): + credentials["function_calling_type"] = "tool_call" - def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict: + def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict: """ Convert PromptMessage to dict for OpenAI API format """ @@ -106,10 +112,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Op for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -117,7 +120,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Op "type": "image_url", "image_url": { "url": message_content.data, - } + }, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -127,14 +130,16 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Op if message.tool_calls: message_dict["tool_calls"] = [] for function_call in message.tool_calls: - message_dict["tool_calls"].append({ - "id": function_call.id, - "type": function_call.type, - "function": { - "name": function_call.function.name, - "arguments": function_call.function.arguments + message_dict["tool_calls"].append( + { + "id": function_call.id, + "type": function_call.type, + "function": { + "name": function_call.function.name, + "arguments": function_call.function.arguments, + }, } - }) + ) elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id} @@ -160,21 +165,26 @@ def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[ if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "", - arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else "" + name=response_tool_call["function"]["name"] + if response_tool_call.get("function", {}).get("name") + else "", + arguments=response_tool_call["function"]["arguments"] + if response_tool_call.get("function", {}).get("arguments") + else "", ) tool_call = AssistantPromptMessage.ToolCall( id=response_tool_call["id"] if response_tool_call.get("id") else "", type=response_tool_call["type"] if response_tool_call.get("type") else "", - function=function + function=function, ) tool_calls.append(tool_call) return tool_calls - def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response, - prompt_messages: list[PromptMessage]) -> 
Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -184,11 +194,12 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon :param prompt_messages: prompt messages :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" chunk_index = 0 - def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \ - -> LLMResultChunk: + def create_final_llm_result_chunk( + index: int, message: AssistantPromptMessage, finish_reason: str + ) -> LLMResultChunk: # calculate num tokens prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content) completion_tokens = self._num_tokens_from_string(model, full_assistant_content) @@ -199,12 +210,7 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f return LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=message, - finish_reason=finish_reason, - usage=usage - ) + delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage), ) tools_calls: list[AssistantPromptMessage.ToolCall] = [] @@ -218,9 +224,9 @@ def get_tool_call(tool_name: str): tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None) if tool_call is None: tool_call = AssistantPromptMessage.ToolCall( - id='', - type='', - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="") + id="", + type="", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), ) tools_calls.append(tool_call) @@ -242,9 +248,9 @@ def get_tool_call(tool_name: str): for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"): if chunk: # ignore sse comments - if chunk.startswith(':'): + if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + decoded_chunk = chunk.strip().lstrip("data: ").lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) @@ -253,21 +259,21 @@ def get_tool_call(tool_name: str): yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), - finish_reason="Non-JSON encountered." 
+ finish_reason="Non-JSON encountered.", ) break - if not chunk_json or len(chunk_json['choices']) == 0: + if not chunk_json or len(chunk_json["choices"]) == 0: continue - choice = chunk_json['choices'][0] - finish_reason = chunk_json['choices'][0].get('finish_reason') + choice = chunk_json["choices"][0] + finish_reason = chunk_json["choices"][0].get("finish_reason") chunk_index += 1 - if 'delta' in choice: - delta = choice['delta'] - delta_content = delta.get('content') + if "delta" in choice: + delta = choice["delta"] + delta_content = delta.get("content") - assistant_message_tool_calls = delta.get('tool_calls', None) + assistant_message_tool_calls = delta.get("tool_calls", None) # assistant_message_function_call = delta.delta.function_call # extract tool calls from response @@ -275,19 +281,18 @@ def get_tool_call(tool_name: str): tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) increase_tool_call(tool_calls) - if delta_content is None or delta_content == '': + if delta_content is None or delta_content == "": continue # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta_content, - tool_calls=tool_calls if assistant_message_tool_calls else [] + content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else [] ) full_assistant_content += delta_content - elif 'text' in choice: - choice_text = choice.get('text', '') - if choice_text == '': + elif "text" in choice: + choice_text = choice.get("text", "") + if choice_text == "": continue # transform assistant message to prompt message @@ -303,26 +308,21 @@ def get_tool_call(tool_name: str): delta=LLMResultChunkDelta( index=chunk_index, message=assistant_prompt_message, - ) + ), ) chunk_index += 1 - + if tools_calls: yield LLMResultChunk( model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=chunk_index, - message=AssistantPromptMessage( - tool_calls=tools_calls, - content="" - ), - ) + message=AssistantPromptMessage(tool_calls=tools_calls, content=""), + ), ) yield create_final_llm_result_chunk( - index=chunk_index, - message=AssistantPromptMessage(content=""), - finish_reason=finish_reason - ) \ No newline at end of file + index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason + ) diff --git a/api/core/model_runtime/model_providers/stepfun/stepfun.py b/api/core/model_runtime/model_providers/stepfun/stepfun.py index 50b17392b54276..e1c41a91537cd1 100644 --- a/api/core/model_runtime/model_providers/stepfun/stepfun.py +++ b/api/core/model_runtime/model_providers/stepfun/stepfun.py @@ -8,7 +8,6 @@ class StepfunProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='step-1-8k', - credentials=credentials - ) + model_instance.validate_credentials(model="step-1-8k", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py 
index c3e3b7c2580a6d..c3c21793e8eb39 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py @@ -1,6 +1,7 @@ import base64 import hashlib import hmac +import operator import time import requests @@ -67,10 +68,10 @@ def set_reinforce_hotword(self, reinforce_hotword): class FlashRecognizer: """ - reponse: + response: request_id string - status Integer - message String + status Integer + message String audio_duration Integer flash_result Result Array @@ -81,16 +82,16 @@ class FlashRecognizer: Sentence: text String - start_time Integer - end_time Integer - speaker_id Integer + start_time Integer + end_time Integer + speaker_id Integer word_list Word Array Word: - word String - start_time Integer - end_time Integer - stable_flag: Integer + word String + start_time Integer + end_time Integer + stable_flag: Integer """ def __init__(self, appid, credential): @@ -100,13 +101,13 @@ def __init__(self, appid, credential): def _format_sign_string(self, param): signstr = "POSTasr.cloud.tencent.com/asr/flash/v1/" for t in param: - if 'appid' in t: + if "appid" in t: signstr += str(t[1]) break signstr += "?" for x in param: tmp = x - if 'appid' in x: + if "appid" in x: continue for t in tmp: signstr += str(t) @@ -121,31 +122,38 @@ def _build_header(self): return header def _sign(self, signstr, secret_key): - hmacstr = hmac.new(secret_key.encode('utf-8'), - signstr.encode('utf-8'), hashlib.sha1).digest() + hmacstr = hmac.new(secret_key.encode("utf-8"), signstr.encode("utf-8"), hashlib.sha1).digest() s = base64.b64encode(hmacstr) - s = s.decode('utf-8') + s = s.decode("utf-8") return s def _build_req_with_signature(self, secret_key, params, header): - query = sorted(params.items(), key=lambda d: d[0]) + query = sorted(params.items(), key=operator.itemgetter(0)) signstr = self._format_sign_string(query) signature = self._sign(signstr, secret_key) header["Authorization"] = signature - requrl = "https://" - requrl += signstr[4::] - return requrl + req_url = "https://" + req_url += signstr[4::] + return req_url def _create_query_arr(self, req): return { - 'appid': self.appid, 'secretid': self.credential.secret_id, 'timestamp': str(int(time.time())), - 'engine_type': req.engine_type, 'voice_format': req.voice_format, - 'speaker_diarization': req.speaker_diarization, 'hotword_id': req.hotword_id, - 'customization_id': req.customization_id, 'filter_dirty': req.filter_dirty, - 'filter_modal': req.filter_modal, 'filter_punc': req.filter_punc, - 'convert_num_mode': req.convert_num_mode, 'word_info': req.word_info, - 'first_channel_only': req.first_channel_only, 'reinforce_hotword': req.reinforce_hotword, - 'sentence_max_length': req.sentence_max_length + "appid": self.appid, + "secretid": self.credential.secret_id, + "timestamp": str(int(time.time())), + "engine_type": req.engine_type, + "voice_format": req.voice_format, + "speaker_diarization": req.speaker_diarization, + "hotword_id": req.hotword_id, + "customization_id": req.customization_id, + "filter_dirty": req.filter_dirty, + "filter_modal": req.filter_modal, + "filter_punc": req.filter_punc, + "convert_num_mode": req.convert_num_mode, + "word_info": req.word_info, + "first_channel_only": req.first_channel_only, + "reinforce_hotword": req.reinforce_hotword, + "sentence_max_length": req.sentence_max_length, } def recognize(self, req, data): diff --git a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py 
b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py index 00ec5aa9c8202e..5b427663ca85b0 100644 --- a/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/tencent/speech2text/speech2text.py @@ -18,9 +18,7 @@ class TencentSpeech2TextModel(Speech2TextModel): - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -43,7 +41,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self._speech2text_invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -83,10 +81,6 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - requests.exceptions.ConnectionError - ], - InvokeAuthorizationError: [ - CredentialsValidateFailedError - ] + InvokeConnectionError: [requests.exceptions.ConnectionError], + InvokeAuthorizationError: [CredentialsValidateFailedError], } diff --git a/api/core/model_runtime/model_providers/tencent/tencent.py b/api/core/model_runtime/model_providers/tencent/tencent.py index dd9f90bb474f4e..79c6f577b8d5ef 100644 --- a/api/core/model_runtime/model_providers/tencent/tencent.py +++ b/api/core/model_runtime/model_providers/tencent/tencent.py @@ -18,12 +18,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: """ try: model_instance = self.get_model_instance(ModelType.SPEECH2TEXT) - model_instance.validate_credentials( - model='tencent', - credentials=credentials - ) + model_instance.validate_credentials(model="tencent", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index bb802d407157bf..b96d43979ef54a 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -22,16 +22,21 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _update_endpoint_url(self, credentials: dict): - credentials['endpoint_url'] = "https://api.together.xyz/v1" + credentials["endpoint_url"] = "https://api.together.xyz/v1" return credentials - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super()._invoke(model, 
cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) @@ -41,12 +46,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: return super().validate_credentials(model, cred_with_endpoint) - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) - return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate( + model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user + ) def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) @@ -61,45 +76,45 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, features=features, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get('context_size', "4096")), - ModelPropertyKey.MODE: cred_with_endpoint.get('mode'), + ModelPropertyKey.CONTEXT_SIZE: int(cred_with_endpoint.get("context_size", "4096")), + ModelPropertyKey.MODE: cred_with_endpoint.get("mode"), }, parameter_rules=[ ParameterRule( name=DefaultParameterName.TEMPERATURE.value, label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('temperature', 0.7)), + default=float(cred_with_endpoint.get("temperature", 0.7)), min=0, max=2, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.TOP_P.value, label=I18nObject(en_US="Top P"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('top_p', 1)), + default=float(cred_with_endpoint.get("top_p", 1)), min=0, max=1, - precision=2 + precision=2, ), ParameterRule( name=TOP_K, label=I18nObject(en_US="Top K"), type=ParameterType.INT, - default=int(cred_with_endpoint.get('top_k', 50)), + default=int(cred_with_endpoint.get("top_k", 50)), min=-2147483647, max=2147483647, - precision=0 + precision=0, ), ParameterRule( name=REPETITION_PENALTY, label=I18nObject(en_US="Repetition Penalty"), type=ParameterType.FLOAT, - default=float(cred_with_endpoint.get('repetition_penalty', 1)), + default=float(cred_with_endpoint.get("repetition_penalty", 1)), min=-3.4, max=3.4, - precision=1 + precision=1, ), ParameterRule( name=DefaultParameterName.MAX_TOKENS.value, @@ -107,46 +122,49 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode type=ParameterType.INT, default=512, min=1, - max=int(cred_with_endpoint.get('max_tokens_to_sample', 4096)), + max=int(cred_with_endpoint.get("max_tokens_to_sample", 4096)), ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY.value, label=I18nObject(en_US="Frequency Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('frequency_penalty', 0)), + default=float(credentials.get("frequency_penalty", 0)), min=-2, - max=2 + max=2, ), ParameterRule( 
name=DefaultParameterName.PRESENCE_PENALTY.value, label=I18nObject(en_US="Presence Penalty"), type=ParameterType.FLOAT, - default=float(credentials.get('presence_penalty', 0)), + default=float(credentials.get("presence_penalty", 0)), min=-2, - max=2 + max=2, ), ], pricing=PriceConfig( - input=Decimal(cred_with_endpoint.get('input_price', 0)), - output=Decimal(cred_with_endpoint.get('output_price', 0)), - unit=Decimal(cred_with_endpoint.get('unit', 0)), - currency=cred_with_endpoint.get('currency', "USD") + input=Decimal(cred_with_endpoint.get("input_price", 0)), + output=Decimal(cred_with_endpoint.get("output_price", 0)), + unit=Decimal(cred_with_endpoint.get("unit", 0)), + currency=cred_with_endpoint.get("currency", "USD"), ), ) - if cred_with_endpoint['mode'] == 'chat': + if cred_with_endpoint["mode"] == "chat": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value - elif cred_with_endpoint['mode'] == 'completion': + elif cred_with_endpoint["mode"] == "completion": entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value else: raise ValueError(f"Unknown completion type {cred_with_endpoint['completion_type']}") return entity - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools) - - diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py index ffce4794e7a8ad..aa4100a7c9b4d8 100644 --- a/api/core/model_runtime/model_providers/togetherai/togetherai.py +++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py @@ -6,6 +6,5 @@ class TogetherAIProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/tongyi/_common.py b/api/core/model_runtime/model_providers/tongyi/_common.py index fab18b41fd0487..8a50c7aa05f38c 100644 --- a/api/core/model_runtime/model_providers/tongyi/_common.py +++ b/api/core/model_runtime/model_providers/tongyi/_common.py @@ -21,7 +21,7 @@ class _CommonTongyi: @staticmethod def _to_credential_kwargs(credentials: dict) -> dict: credentials_kwargs = { - "dashscope_api_key": credentials['dashscope_api_key'], + "dashscope_api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -51,5 +51,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } diff --git a/api/core/model_runtime/model_providers/tongyi/llm/_position.yaml b/api/core/model_runtime/model_providers/tongyi/llm/_position.yaml new file mode 100644 index 00000000000000..8ce336d60cb396 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/_position.yaml @@ -0,0 +1,51 @@ +- qwen-vl-max-0809 +- qwen-vl-max-0201 +- qwen-vl-max +- qwen-max-latest +- qwen-max-1201 +- qwen-max-0919 +- qwen-max-0428 +- qwen-max-0403 +- qwen-max-0107 +- qwen-max +- qwen-max-longcontext +- qwen-plus-latest +- qwen-plus-0919 +- qwen-plus-0806 +- qwen-plus-0723 +- qwen-plus-0624 +- qwen-plus-0206 +- qwen-plus-chat +- qwen-plus +- qwen-vl-plus-0809 +- qwen-vl-plus-0201 +- qwen-vl-plus 
+- qwen-turbo-latest +- qwen-turbo-0919 +- qwen-turbo-0624 +- qwen-turbo-0206 +- qwen-turbo-chat +- qwen-turbo +- qwen2.5-72b-instruct +- qwen2.5-32b-instruct +- qwen2.5-14b-instruct +- qwen2.5-7b-instruct +- qwen2.5-3b-instruct +- qwen2.5-1.5b-instruct +- qwen2.5-0.5b-instruct +- qwen2.5-coder-7b-instruct +- qwen2-math-72b-instruct +- qwen2-math-7b-instruct +- qwen2-math-1.5b-instruct +- qwen-long +- qwen-math-plus-latest +- qwen-math-plus-0919 +- qwen-math-plus-0816 +- qwen-math-plus +- qwen-math-turbo-latest +- qwen-math-turbo-0919 +- qwen-math-turbo +- qwen-coder-turbo-latest +- qwen-coder-turbo-0919 +- qwen-coder-turbo +- farui-plus diff --git a/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml index aad07f56736e52..34a57d1fc0c9a5 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/farui-plus.yaml @@ -1,3 +1,4 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models model: farui-plus label: en_US: farui-plus @@ -62,16 +63,11 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. - - name: enable_search - type: boolean - default: false - help: - zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 - en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
- name: response_format use_template: response_format pricing: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 4e1bb0a5a4fa2b..cde5d214d04d97 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -4,6 +4,7 @@ import uuid from collections.abc import Generator from http import HTTPStatus +from pathlib import Path from typing import Optional, Union, cast from dashscope import Generation, MultiModalConversation, get_tokenizer @@ -17,8 +18,7 @@ UnsupportedModel, ) -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -29,8 +29,18 @@ TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, + VideoPromptMessageContent, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + I18nObject, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, ) -from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -46,11 +56,17 @@ class TongyiLargeLanguageModel(LargeLanguageModel): tokenizers = {} - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -64,90 +80,16 @@ def _invoke(self, model: str, credentials: dict, :param user: unique user id :return: full response or stream response chunk generator result """ - # invoke model + # invoke model without code wrapper return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _code_block_mode_wrapper(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \ - -> LLMResult | Generator: - """ - Wrapper for code block mode - """ - block_prompts = """You should always follow the instructions and output a valid {{block}} object. -The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure -if you are not sure about the structure. - - -{{instructions}} - -You should also complete the text started with ``` but not tell ``` directly. 
-""" - - code_block = model_parameters.get("response_format", "") - if not code_block: - return self._invoke( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user - ) - - model_parameters.pop("response_format") - stop = stop or [] - stop.extend(["\n```", "```\n"]) - block_prompts = block_prompts.replace("{{block}}", code_block) - - # check if there is a system message - if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): - # override the system message - prompt_messages[0] = SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", prompt_messages[0].content) - ) - else: - # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=block_prompts - .replace("{{instructions}}", f"Please output a valid {code_block} with markdown codeblocks.") - )) - - if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): - # add ```JSON\n to the last message - prompt_messages[-1].content += f"\n```{code_block}\n" - else: - # append a user message - prompt_messages.append(UserPromptMessage( - content=f"```{code_block}\n" - )) - - response = self._invoke( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user - ) - - if isinstance(response, Generator): - return self._code_block_mode_stream_processor_with_backtick( - model=model, - prompt_messages=prompt_messages, - input_generator=response - ) - - return response - - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -157,10 +99,15 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr :param tools: tools for tool calling :return: """ - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') - if model == 'farui-plus': - model = 'qwen-farui-plus' + # Check if the model was added via get_customizable_model_schema + if self.get_customizable_model_schema(model, credentials) is not None: + # For custom models, tokens are not calculated. 
+ return 0 + + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: + model = model.replace("-chat", "") + if model == "farui-plus": + model = "qwen-farui-plus" if model in self.tokenizers: tokenizer = self.tokenizers[model] @@ -191,16 +138,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: model_parameters={ "temperature": 0.5, }, - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -219,18 +172,18 @@ def _generate(self, model: str, credentials: dict, mode = self.get_model_mode(model, credentials) - if model in ['qwen-turbo-chat', 'qwen-plus-chat']: - model = model.replace('-chat', '') + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: + model = model.replace("-chat", "") extra_model_kwargs = {} if tools: - extra_model_kwargs['tools'] = self._convert_tools(tools) + extra_model_kwargs["tools"] = self._convert_tools(tools) if stop: - extra_model_kwargs['stop'] = stop + extra_model_kwargs["stop"] = stop params = { - 'model': model, + "model": model, **model_parameters, **credentials_kwargs, **extra_model_kwargs, @@ -238,23 +191,22 @@ def _generate(self, model: str, credentials: dict, model_schema = self.get_model_schema(model, credentials) if ModelFeature.VISION in (model_schema.features or []): - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages, rich_content=True) response = MultiModalConversation.call(**params, stream=stream) else: # nothing different between chat model and completion model in tongyi - params['messages'] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) - response = Generation.call(**params, - result_format='message', - stream=stream) + params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages) + response = Generation.call(**params, result_format="message", stream=stream) if stream: return self._handle_generate_stream_response(model, credentials, response, prompt_messages) return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -264,10 +216,8 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen :param prompt_messages: prompt messages :return: llm response """ - if response.status_code != 200 and response.status_code != HTTPStatus.OK: - raise ServiceUnavailableError( - response.message - ) + if response.status_code not in {200, HTTPStatus.OK}: + raise ServiceUnavailableError(response.message) # transform 
assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( content=response.output.choices[0].message.content, @@ -286,9 +236,13 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen return result - def _handle_generate_stream_response(self, model: str, credentials: dict, - responses: Generator[GenerationResponse, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + responses: Generator[GenerationResponse, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -298,10 +252,10 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_text = '' + full_text = "" tool_calls = [] for index, response in enumerate(responses): - if response.status_code != 200 and response.status_code != HTTPStatus.OK: + if response.status_code not in {200, HTTPStatus.OK}: raise ServiceUnavailableError( f"Failed to invoke model {model}, status code: {response.status_code}, " f"message: {response.message}" @@ -309,22 +263,22 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, resp_finish_reason = response.output.choices[0].finish_reason - if resp_finish_reason is not None and resp_finish_reason != 'null': + if resp_finish_reason is not None and resp_finish_reason != "null": resp_content = response.output.choices[0].message.content assistant_prompt_message = AssistantPromptMessage( - content='', + content="", ) - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = response.output.choices[0].message["tool_calls"] elif resp_content: # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message - assistant_prompt_message.content = resp_content.replace(full_text, '', 1) + assistant_prompt_message.content = resp_content.replace(full_text, "", 1) full_text = resp_content @@ -332,12 +286,11 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, message_tool_calls = [] for tool_call_obj in tool_calls: message_tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_obj['function']['name'], - type='function', + id=tool_call_obj["function"]["name"], + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=tool_call_obj['function']['name'], - arguments=tool_call_obj['function']['arguments'] - ) + name=tool_call_obj["function"]["name"], arguments=tool_call_obj["function"]["arguments"] + ), ) message_tool_calls.append(message_tool_call) @@ -351,26 +304,23 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message, - finish_reason=resp_finish_reason, - usage=usage - ) + index=index, message=assistant_prompt_message, finish_reason=resp_finish_reason, usage=usage + ), ) else: resp_content = response.output.choices[0].message.content if not resp_content: - if 'tool_calls' in response.output.choices[0].message: - tool_calls = response.output.choices[0].message['tool_calls'] + if "tool_calls" in response.output.choices[0].message: + tool_calls = 
response.output.choices[0].message["tool_calls"] continue # special for qwen-vl if isinstance(resp_content, list): - resp_content = resp_content[0]['text'] + resp_content = resp_content[0]["text"] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=resp_content.replace(full_text, '', 1), + content=resp_content.replace(full_text, "", 1), ) full_text = resp_content @@ -378,10 +328,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) def _to_credential_kwargs(self, credentials: dict) -> dict: @@ -392,7 +339,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['dashscope_api_key'], + "api_key": credentials["dashscope_api_key"], } return credentials_kwargs @@ -419,9 +366,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: break elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = content - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = content else: raise ValueError(f"Got unknown type {message}") @@ -437,16 +382,14 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[PromptMessage], - rich_content: bool = False) -> list[dict]: + def _convert_prompt_messages_to_tongyi_messages( + self, prompt_messages: list[PromptMessage], rich_content: bool = False + ) -> list[dict]: """ Convert prompt messages to tongyi messages @@ -456,24 +399,28 @@ def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[Prom tongyi_messages = [] for prompt_message in prompt_messages: if isinstance(prompt_message, SystemPromptMessage): - tongyi_messages.append({ - 'role': 'system', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "system", + "content": prompt_message.content if not rich_content else [{"text": prompt_message.content}], + } + ) elif isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message.content, str): - tongyi_messages.append({ - 'role': 'user', - 'content': prompt_message.content if not rich_content else [{"text": prompt_message.content}], - }) + tongyi_messages.append( + { + "role": "user", + "content": prompt_message.content + if not rich_content + else [{"text": prompt_message.content}], + } + ) else: sub_messages = [] for message_content in prompt_message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "text": message_content.data - } + sub_message_dict = {"text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == 
PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -483,35 +430,33 @@ def _convert_prompt_messages_to_tongyi_messages(self, prompt_messages: list[Prom # convert image base64 data to file in /tmp image_url = self._save_base64_image_to_file(message_content.data) - sub_message_dict = { - "image": image_url - } + sub_message_dict = {"image": image_url} + sub_messages.append(sub_message_dict) + elif message_content.type == PromptMessageContentType.VIDEO: + message_content = cast(VideoPromptMessageContent, message_content) + video_url = message_content.data + if message_content.data.startswith("data:"): + raise InvokeError("base64 is not supported; please set MULTIMODAL_SEND_VIDEO_FORMAT to url") + + sub_message_dict = {"video": video_url} sub_messages.append(sub_message_dict) # resort sub_messages to ensure text is always at last - sub_messages = sorted(sub_messages, key=lambda x: 'text' in x) + sub_messages = sorted(sub_messages, key=lambda x: "text" in x) - tongyi_messages.append({ - 'role': 'user', - 'content': sub_messages - }) + tongyi_messages.append({"role": "user", "content": sub_messages}) elif isinstance(prompt_message, AssistantPromptMessage): content = prompt_message.content if not content: - content = ' ' - message = { - 'role': 'assistant', - 'content': content if not rich_content else [{"text": content}] - } + content = " " + message = {"role": "assistant", "content": content if not rich_content else [{"text": content}]} if prompt_message.tool_calls: - message['tool_calls'] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] + message["tool_calls"] = [tool_call.model_dump() for tool_call in prompt_message.tool_calls] tongyi_messages.append(message) elif isinstance(prompt_message, ToolPromptMessage): - tongyi_messages.append({ - "role": "tool", - "content": prompt_message.content, - "name": prompt_message.tool_call_id - }) + tongyi_messages.append( + {"role": "tool", "content": prompt_message.content, "name": prompt_message.tool_call_id} + ) else: raise ValueError(f"Got unknown type {prompt_message}") @@ -526,15 +471,14 @@ def _save_base64_image_to_file(self, base64_image: str) -> str: :return: image file path """ # get mime type and encoded string - mime_type, encoded_string = base64_image.split(',')[0].split(';')[0].split(':')[1], base64_image.split(',')[1] + mime_type, encoded_string = base64_image.split(",")[0].split(";")[0].split(":")[1], base64_image.split(",")[1] # save image to file temp_dir = tempfile.gettempdir() file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}") - with open(file_path, "wb") as image_file: - image_file.write(base64.b64decode(encoded_string)) + Path(file_path).write_bytes(base64.b64decode(encoded_string)) return f"file://{file_path}" @@ -544,19 +488,18 @@ def _convert_tools(self, tools: list[PromptMessageTool]) -> list[dict]: """ tool_definitions = [] for tool in tools: - properties = tool.parameters['properties'] - required_properties = tool.parameters['required'] + properties = tool.parameters["properties"] + required_properties = tool.parameters["required"] properties_definitions = {} for p_key, p_val in properties.items(): - desc = p_val['description'] - if 'enum' in p_val: - desc += (f"; Only accepts one of the following predefined options: " - f"[{', '.join(p_val['enum'])}]") + desc = p_val["description"] + if "enum" in p_val: + desc += f"; Only accepts one of the following predefined options: [{', '.join(p_val['enum'])}]" properties_definitions[p_key] = { -
'description': desc, - 'type': p_val['type'], + "description": desc, + "type": p_val["type"], } tool_definition = { @@ -565,8 +508,8 @@ def _convert_tools(self, tools: list[PromptMessageTool]) -> list[dict]: "name": tool.name, "description": tool.description, "parameters": properties_definitions, - "required": required_properties - } + "required": required_properties, + }, } tool_definitions.append(tool_definition) @@ -598,5 +541,62 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvalidParameter, UnsupportedModel, UnsupportedHTTPMethod, - ] + ], } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + """ + Define the schema for a customizable model + + :param model: model name + :param credentials: model credentials + :return: AIModelEntity or None + """ + return AIModelEntity( + model=model, + label=I18nObject(en_US=model, zh_Hans=model), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + default=512, + min=1, + max=int(credentials.get("max_tokens", 1024)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), + type=ParameterType.INT, + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="top_k", + use_template="top_k", + label=I18nObject(en_US="Top K", zh_Hans="Top K"), + type=ParameterType.INT, + ), + ParameterRule( + name="frequency_penalty", + use_template="frequency_penalty", + label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"), + type=ParameterType.FLOAT, + ), + ], + )
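For orientation, the params dict assembled by `_generate` above ends up in a plain DashScope SDK call. Below is a minimal, self-contained sketch of that call path; the model name, API key, and prompt are illustrative placeholders, not values from this patch:

```python
# Sketch only: mirrors how _generate hands its params dict to the DashScope
# SDK. Model name, API key, and prompt are illustrative placeholders.
from http import HTTPStatus

from dashscope import Generation

params = {
    "model": "qwen-plus",                 # illustrative model name
    "api_key": "your-dashscope-api-key",  # what _to_credential_kwargs supplies
    "temperature": 0.5,                   # one of the validated model_parameters
    "messages": [{"role": "user", "content": "Hello"}],
}
response = Generation.call(**params, result_format="message", stream=False)

# Same status check the response handlers above perform.
if response.status_code not in {200, HTTPStatus.OK}:
    raise RuntimeError(response.message)
print(response.output.choices[0].message.content)
```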
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo-0919.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo-0919.yaml new file mode 100644 index 00000000000000..64a3f331336bc0 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo-0919.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-coder-turbo-0919 +label: + en_US: qwen-coder-turbo-0919 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating; the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the generated text. Increasing repetition_penalty can reduce the repetition in the model's output. 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB
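The pricing block that closes qwen-coder-turbo-0919.yaml quotes prices against a unit. Assuming the usual convention that total cost = tokens × price × unit (so unit: '0.001' means the quoted price applies per 1,000 tokens), the arithmetic can be sketched as below; `estimate_cost` is a hypothetical helper, not Dify's billing API:

```python
# Back-of-envelope cost estimate for the pricing block above, under the
# assumption that total = tokens * price * unit (unit '0.001' => price per
# 1,000 tokens). Hypothetical helper, not part of this patch.
from decimal import Decimal

UNIT = Decimal("0.001")          # from `unit`
INPUT_PRICE = Decimal("0.002")   # RMB, from `input`
OUTPUT_PRICE = Decimal("0.006")  # RMB, from `output`

def estimate_cost(prompt_tokens: int, completion_tokens: int) -> Decimal:
    return (prompt_tokens * INPUT_PRICE + completion_tokens * OUTPUT_PRICE) * UNIT

# 10,000 prompt tokens and 2,000 completion tokens:
print(estimate_cost(10_000, 2_000))  # 0.020 + 0.012 = 0.032 RMB
```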
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo-latest.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo-latest.yaml new file mode 100644 index 00000000000000..a4c93f7047ff58 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo-latest.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-coder-turbo-latest +label: + en_US: qwen-coder-turbo-latest +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating; the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the generated text. Increasing repetition_penalty can reduce the repetition in the model's output. 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB
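Each parameter_rules entry above pairs a default with a min/max range. The enforcement logic is outside this diff, but its intended effect amounts to defaulting and clamping; `apply_rule` below is a hypothetical helper exercised with the values from the top_p rule above (default 0.8, min 0.1, max 0.9):

```python
# Hypothetical sketch of how a parameter_rules entry could be enforced:
# fall back to the default when unset, otherwise clamp into [min, max].
from typing import Optional

def apply_rule(value: Optional[float], default: float, minimum: float, maximum: float) -> float:
    if value is None:
        return default
    return max(minimum, min(maximum, value))

print(apply_rule(None, 0.8, 0.1, 0.9))  # 0.8 (default applied)
print(apply_rule(1.3, 0.8, 0.1, 0.9))   # 0.9 (clamped to max)
```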
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo.yaml new file mode 100644 index 00000000000000..ff68faed80810b --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-coder-turbo.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-coder-turbo +label: + en_US: qwen-coder-turbo +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating; the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in the generated text. Increasing repetition_penalty can reduce the repetition in the model's output. 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB
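The max_tokens rules in these files only cap generation; prompt-side counting is done locally, which is why the tongyi llm.py hunk earlier in this diff imports get_tokenizer from the DashScope SDK. A minimal sketch of that counting path, with an illustrative model name and prompt:

```python
# Counting prompt tokens locally with DashScope's tokenizer, as the tongyi
# get_num_tokens path does for built-in models. Model name is illustrative.
from dashscope import get_tokenizer

tokenizer = get_tokenizer("qwen-turbo")
tokens = tokenizer.encode("Explain this change in one sentence.")
print(len(tokens))
```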
The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-long.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-long.yaml index b2cf3dd486f4fe..c3dbb3616fb961 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-long.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-long.yaml @@ -1,3 +1,4 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models model: qwen-long label: en_US: qwen-long @@ -24,7 +25,7 @@ parameter_rules: type: int default: 2000 min: 1 - max: 2000 + max: 6000 help: zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. @@ -62,16 +63,11 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. - - name: enable_search - type: boolean - default: false - help: - zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 - en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
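# Editor's note (hedged sketch, not part of the upstream files): every entry
# in parameter_rules across these Tongyi YAMLs follows the same shape, which
# the model runtime is expected to render as a UI control and pass through to
# the provider API:
#   - name: top_k            # provider-side parameter name
#     type: int              # float | int | boolean | string
#     required: false        # optional; rules default to optional
#     min: 0
#     max: 99
#     label:                 # per-locale display name
#       zh_Hans: 取样数量
#       en_US: Top k
#     help:                  # per-locale tooltip text
#       zh_Hans: ...
#       en_US: ...
# Rules that set `use_template` (temperature, top_p, max_tokens,
# response_format) inherit shared defaults and only spell out overrides.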
- name: response_format use_template: response_format pricing: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-0816.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-0816.yaml new file mode 100644 index 00000000000000..42fe1f68623bc4 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-0816.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-math-plus-0816 +label: + en_US: qwen-math-plus-0816 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 3072 + min: 1 + max: 3072 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. 
When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-0919.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-0919.yaml new file mode 100644 index 00000000000000..9b6567b8cda4d7 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-0919.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-math-plus-0919 +label: + en_US: qwen-math-plus-0919 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 3072 + min: 1 + max: 3072 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. 
For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-latest.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-latest.yaml new file mode 100644 index 00000000000000..b2a2393b365fcb --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus-latest.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-math-plus-latest +label: + en_US: qwen-math-plus-latest +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 3072 + min: 1 + max: 3072 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. 
For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus.yaml new file mode 100644 index 00000000000000..63f4b7ff0a0879 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-plus.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-math-plus +label: + en_US: qwen-math-plus +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. 
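# Editor's note (recap of values visible in this diff): the qwen-math-* files
# added here all share context_size 4096 and a 3072-token output cap; the
# plus variants are priced at input '0.004' / output '0.012' and the turbo
# variants at input '0.002' / output '0.006' (RMB, unit '0.001').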
+ - name: max_tokens + use_template: max_tokens + type: int + default: 3072 + min: 1 + max: 3072 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
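# Editor's note (hedged arithmetic, assuming Dify's usual
# tokens * price * unit billing convention): with the pricing block below
# (input '0.004', output '0.012', unit '0.001'), a call consuming 10,000
# prompt tokens and 2,000 completion tokens would cost about
#   10000 * 0.004 * 0.001 + 2000 * 0.012 * 0.001 = 0.04 + 0.024 = 0.064 RMB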
+ - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo-0919.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo-0919.yaml new file mode 100644 index 00000000000000..4da90eec3eddfd --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo-0919.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-math-turbo-0919 +label: + en_US: qwen-math-turbo-0919 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 3072 + min: 1 + max: 3072 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. 
When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo-latest.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo-latest.yaml new file mode 100644 index 00000000000000..d29f8851dd3909 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo-latest.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-math-turbo-latest +label: + en_US: qwen-math-turbo-latest +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 3072 + min: 1 + max: 3072 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. 
For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo.yaml new file mode 100644 index 00000000000000..2a8f7f725e9366 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-math-turbo.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-math-turbo +label: + en_US: qwen-math-turbo +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 3072 + min: 1 + max: 3072 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. 
For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0107.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0107.yaml new file mode 100644 index 00000000000000..ef1841b5173bc5 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0107.yaml @@ -0,0 +1,78 @@ +# this model corresponds to qwen-max, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#cf6cc4aa2aokf) +model: qwen-max-0107 +label: + en_US: qwen-max-0107 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 8000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. 
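# Editor's note (clarification, not part of the upstream file): context_size
# above (8000 for this snapshot) bounds prompt and completion together, while
# the max_tokens rule below caps only the completion; e.g. a 7,000-token
# prompt leaves at most 1,000 tokens of output even with max_tokens at 2000.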
+ - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.04' + output: '0.12' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0403.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0403.yaml index 935a16ebcb1166..a2ea5df130f379 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0403.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0403.yaml @@ -1,3 +1,5 @@ +# this model corresponds to qwen-max-0403, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#cf6cc4aa2aokf) model: qwen-max-0403 label: en_US: qwen-max-0403 @@ -8,7 +10,7 @@ features: - stream-tool-call model_properties: mode: chat - context_size: 8192 + context_size: 8000 parameter_rules: - name: temperature use_template: temperature @@ -62,16 +64,11 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. - - name: enable_search - type: boolean - default: false - help: - zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 - en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. - name: response_format use_template: response_format pricing: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0428.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0428.yaml index c39799a71fdcdc..a467665f118a68 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0428.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0428.yaml @@ -1,3 +1,5 @@ +# this model corresponds to qwen-max-0428, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#cf6cc4aa2aokf) model: qwen-max-0428 label: en_US: qwen-max-0428 @@ -8,7 +10,7 @@ features: - stream-tool-call model_properties: mode: chat - context_size: 8192 + context_size: 8000 parameter_rules: - name: temperature use_template: temperature @@ -62,16 +64,11 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. - - name: enable_search - type: boolean - default: false - help: - zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 - en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
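# Editor's note (context for the removal above): this diff drops the
# enable_search rule from the dated qwen-max snapshots, while the base
# qwen-max.yaml (further below) keeps the rule and gains a localized label:
#   - name: enable_search
#     type: boolean
#     default: false
#     label:
#       zh_Hans: 联网搜索
#       en_US: Web Search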
- name: response_format use_template: response_format pricing: diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0919.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0919.yaml new file mode 100644 index 00000000000000..78661eaea065f2 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-0919.yaml @@ -0,0 +1,78 @@ +# this model corresponds to qwen-max-0919, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#cf6cc4aa2aokf) +model: qwen-max-0919 +label: + en_US: qwen-max-0919 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. 
When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.02' + output: '0.06' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml index 0368a4a01e4c6f..6f4674576b4426 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml @@ -1,3 +1,5 @@ +# this model corresponds to qwen-max, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#cf6cc4aa2aokf) model: qwen-max-1201 label: en_US: qwen-max-1201 @@ -66,12 +68,6 @@ parameter_rules: help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. - - name: enable_search - type: boolean - default: false - help: - zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 - en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. - name: response_format use_template: response_format pricing: @@ -79,3 +75,4 @@ pricing: output: '0.12' unit: '0.001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-latest.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-latest.yaml new file mode 100644 index 00000000000000..8b5f0054733455 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-latest.yaml @@ -0,0 +1,78 @@ +# this model corresponds to qwen-max, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#cf6cc4aa2aokf) +model: qwen-max-latest +label: + en_US: qwen-max-latest +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. 
A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
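# Editor's note (assumption about the shared template): `use_template:
# response_format` below is taken to pull Dify's common response_format rule,
# a string option (e.g. JSON or XML) that asks the model for structured
# output, rather than redefining the rule in each file.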
+ - name: response_format + use_template: response_format +pricing: + input: '0.02' + output: '0.06' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml index 1c705670ca6a2f..cc0bb7a117318b 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml @@ -1,3 +1,5 @@ +# this model corresponds to qwen-max, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#cf6cc4aa2aokf) model: qwen-max-longcontext label: en_US: qwen-max-longcontext @@ -8,7 +10,7 @@ features: - stream-tool-call model_properties: mode: chat - context_size: 32768 + context_size: 32000 parameter_rules: - name: temperature use_template: temperature @@ -22,9 +24,9 @@ parameter_rules: - name: max_tokens use_template: max_tokens type: int - default: 2000 + default: 8000 min: 1 - max: 2000 + max: 8000 help: zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. @@ -62,16 +64,11 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. - - name: enable_search - type: boolean - default: false - help: - zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 - en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. - name: response_format use_template: response_format pricing: @@ -79,3 +76,4 @@ pricing: output: '0.12' unit: '0.001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml index 64094effbbec82..4af4822e86051d 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml @@ -1,3 +1,5 @@ +# this model corresponds to qwen-max, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#cf6cc4aa2aokf) model: qwen-max label: en_US: qwen-max @@ -8,7 +10,7 @@ features: - stream-tool-call model_properties: mode: chat - context_size: 8192 + context_size: 32000 parameter_rules: - name: temperature use_template: temperature @@ -24,7 +26,7 @@ parameter_rules: type: int default: 2000 min: 1 - max: 2000 + max: 8192 help: zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. 
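# Editor's note (recap of values visible in this diff): after these changes
# the qwen-max family lines up as follows (RMB per 1K tokens, unit '0.001';
# qwen-max pricing is updated in the hunk below):
#   qwen-max-0107    context  8000, max_tokens 2000, input 0.04, output 0.12
#   qwen-max-0919    context 32768, max_tokens 8192, input 0.02, output 0.06
#   qwen-max-latest  context 32768, max_tokens 8192, input 0.02, output 0.06
#   qwen-max         context 32000, max_tokens 8192, input 0.02, output 0.06
# qwen-max-1201 and qwen-max-longcontext are additionally marked
# `deprecated: true`.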
@@ -62,6 +64,7 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 @@ -69,13 +72,16 @@ parameter_rules: - name: enable_search type: boolean default: false + label: + zh_Hans: 联网搜索 + en_US: Web Search help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. - name: response_format use_template: response_format pricing: - input: '0.04' - output: '0.12' + input: '0.02' + output: '0.06' unit: '0.001' currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0206.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0206.yaml new file mode 100644 index 00000000000000..0b1a6f81df80c0 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0206.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-plus-0206, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#bb0ffee88bwnk) +model: qwen-plus-0206 +label: + en_US: qwen-plus-0206 +model_type: llm +features: + - agent-thought +model_properties: + mode: completion + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8000 + min: 1 + max: 8000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. 
+ - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0624.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0624.yaml new file mode 100644 index 00000000000000..7706005bb535cd --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0624.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-plus-0624, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#bb0ffee88bwnk) +model: qwen-plus-0624 +label: + en_US: qwen-plus-0624 +model_type: llm +features: + - agent-thought +model_properties: + mode: completion + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8000 + min: 1 + max: 8000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. 
+ - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0723.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0723.yaml new file mode 100644 index 00000000000000..348276fc08f98c --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0723.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-plus-0723, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#bb0ffee88bwnk) +model: qwen-plus-0723 +label: + en_US: qwen-plus-0723 +model_type: llm +features: + - agent-thought +model_properties: + mode: completion + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text.
A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 8000 + min: 1 + max: 8000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
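Aside: for the repetition_penalty rule above, one common formulation is the logit rescaling from the CTRL paper; whether DashScope implements exactly this form is not stated here, so treat the sketch as illustrative only:

import numpy as np

def apply_repetition_penalty(logits: np.ndarray, seen: set[int], penalty: float = 1.1) -> np.ndarray:
    # Make every already-generated token less likely to be picked again.
    # penalty == 1.0 leaves the logits unchanged, matching "1.0 means no penalty".
    out = logits.copy()
    for token_id in seen:
        out[token_id] = out[token_id] / penalty if out[token_id] > 0 else out[token_id] * penalty
    return out

print(apply_repetition_penalty(np.array([3.0, 1.0, -1.0]), {0, 2}))  # [2.7272..., 1.0, -1.1]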
+ - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0806.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0806.yaml new file mode 100644 index 00000000000000..29f125135eaa3f --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0806.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-plus-0806, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#bb0ffee88bwnk) +model: qwen-plus-0806 +label: + en_US: qwen-plus-0806 +model_type: llm +features: + - agent-thought +model_properties: + mode: completion + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output.
Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0919.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0919.yaml new file mode 100644 index 00000000000000..905fa1e1028bbf --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-0919.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-plus-0919, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#bb0ffee88bwnk) +model: qwen-plus-0919 +label: + en_US: qwen-plus-0919 +model_type: llm +features: + - agent-thought +model_properties: + mode: completion + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation.
For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.0008' + output: '0.002' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-chat.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-chat.yaml index bc848072edd7fa..c7a3549727ce8e 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-chat.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-chat.yaml @@ -1,3 +1,5 @@ +# this model corresponds to qwen-plus, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#bb0ffee88bwnk) model: qwen-plus-chat label: en_US: qwen-plus-chat @@ -62,16 +64,11 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. - - name: enable_search - type: boolean - default: false - help: - zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 - en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text.
- name: response_format use_template: response_format pricing: @@ -79,3 +76,4 @@ pricing: output: '0.012' unit: '0.001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-latest.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-latest.yaml new file mode 100644 index 00000000000000..608f52c2964ea3 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus-latest.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-plus-latest, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#bb0ffee88bwnk) +model: qwen-plus-latest +label: + en_US: qwen-plus-latest +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. 
Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.0008' + output: '0.002' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml index 4be78627f0495a..529a29b1b5bfb8 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml @@ -1,12 +1,16 @@ +# this model corresponds to qwen-plus, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#bb0ffee88bwnk) model: qwen-plus label: en_US: qwen-plus model_type: llm features: + - multi-tool-call - agent-thought + - stream-tool-call model_properties: - mode: completion - context_size: 32768 + mode: chat + context_size: 128000 parameter_rules: - name: temperature use_template: temperature @@ -20,9 +24,9 @@ parameter_rules: - name: max_tokens use_template: max_tokens type: int - default: 2000 + default: 8192 min: 1 - max: 2000 + max: 8192 help: zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. @@ -60,6 +64,7 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 @@ -67,13 +72,16 @@ parameter_rules: - name: enable_search type: boolean default: false + label: + zh_Hans: 联网搜索 + en_US: Web Search help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model decides, based on its internal logic, whether to actually use them.
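Aside: on the caller side, enable_search is simply forwarded to the provider. A hypothetical sketch against the DashScope SDK (the enable_search keyword follows Alibaba Cloud's documented parameter name; verify against the help page linked in the YAML comments before relying on it):

# Hypothetical client-side sketch; assumes `pip install dashscope` and a
# DASHSCOPE_API_KEY set in the environment.
import dashscope

response = dashscope.Generation.call(
    model="qwen-plus",
    messages=[{"role": "user", "content": "今天有什么新闻?"}],
    result_format="message",
    enable_search=True,  # let the model consult web search results
)
print(response.output.choices[0].message.content)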
- name: response_format use_template: response_format pricing: - input: '0.004' - output: '0.012' + input: '0.0008' + output: '0.002' unit: '0.001' currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0206.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0206.yaml new file mode 100644 index 00000000000000..7ee0d44f2f2834 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0206.yaml @@ -0,0 +1,77 @@ +# this model corresponds to qwen-turbo-0206, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#ff492e2c10lub) + +model: qwen-turbo-0206 +label: + en_US: qwen-turbo-0206 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output.
Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0624.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0624.yaml new file mode 100644 index 00000000000000..20a3f7eb6460f3 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0624.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-turbo-0624, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#ff492e2c10lub) +model: qwen-turbo-0624 +label: + en_US: qwen-turbo-0624 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation.
For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0919.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0919.yaml new file mode 100644 index 00000000000000..ba73dec3631fb5 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-0919.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-turbo-0919, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#ff492e2c10lub) +model: qwen-turbo-0919 +label: + en_US: qwen-turbo-0919 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time.
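Aside: every pricing block in these files follows the same convention, prices in RMB scaled by unit: '0.001'. Assuming the usual interpretation cost = tokens x price x unit (i.e. the listed price is effectively per 1,000 tokens), a worked example:

from decimal import Decimal

def request_cost(prompt_tokens: int, completion_tokens: int,
                 input_price: str, output_price: str, unit: str = "0.001") -> Decimal:
    # cost = tokens * price * unit, so a price of '0.002' with unit '0.001'
    # works out to 0.002 RMB per 1,000 tokens.
    scale = Decimal(unit)
    return (prompt_tokens * Decimal(input_price)
            + completion_tokens * Decimal(output_price)) * scale

# qwen-turbo-0624 prices, 10,000 prompt + 2,000 completion tokens:
print(request_cost(10_000, 2_000, "0.002", "0.006"))  # Decimal('0.032000'), i.e. 0.032 RMB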
+ - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.0003' + output: '0.0006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-chat.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-chat.yaml index f1950577ec03ad..d785b7fe857878 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-chat.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-chat.yaml @@ -1,3 +1,5 @@ +# this model corresponds to qwen-turbo, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#ff492e2c10lub) model: qwen-turbo-chat label: en_US: qwen-turbo-chat @@ -62,16 +64,11 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. - - name: enable_search - type: boolean - default: false - help: - zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 - en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text.
When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. - name: response_format use_template: response_format pricing: @@ -79,3 +76,4 @@ output: '0.006' unit: '0.001' currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-latest.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-latest.yaml new file mode 100644 index 00000000000000..fe38a4283c2d1e --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo-latest.yaml @@ -0,0 +1,76 @@ +# this model corresponds to qwen-turbo-latest, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#ff492e2c10lub) +model: qwen-turbo-latest +label: + en_US: qwen-turbo-latest +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output.
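Aside: the top_k rule above is plain truncated sampling. An illustrative sketch (NumPy; treating k == 0 as "disabled" is this sketch's assumption, not something the YAML specifies):

import numpy as np

rng = np.random.default_rng(0)

def top_k_sample(logits: np.ndarray, k: int) -> int:
    # Restrict sampling to the k highest-scoring tokens, then renormalize.
    if k <= 0 or k >= logits.size:
        candidates = np.arange(logits.size)  # assumption: 0 means "no truncation"
    else:
        candidates = np.argpartition(logits, -k)[-k:]
    probs = np.exp(logits[candidates] - logits[candidates].max())
    probs /= probs.sum()
    return int(rng.choice(candidates, p=probs))

print(top_k_sample(np.array([2.0, 1.0, 0.5, -3.0]), k=2))  # samples only token 0 or 1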
+ - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. + - name: response_format + use_template: response_format +pricing: + input: '0.0003' + output: '0.0006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml index d4c03100ecbee8..a0c4ba682023ef 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml @@ -1,12 +1,16 @@ +# this model corresponds to qwen-turbo, for more details +# please refer to (https://help.aliyun.com/zh/model-studio/getting-started/models#ff492e2c10lub) model: qwen-turbo label: en_US: qwen-turbo model_type: llm features: + - multi-tool-call - agent-thought + - stream-tool-call model_properties: - mode: completion - context_size: 8192 + mode: chat + context_size: 128000 parameter_rules: - name: temperature use_template: temperature @@ -20,9 +24,9 @@ parameter_rules: - name: max_tokens use_template: max_tokens type: int - default: 1500 + default: 2000 min: 1 - max: 1500 + max: 8192 help: zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. @@ -60,6 +64,7 @@ parameter_rules: type: float default: 1.1 label: + zh_Hans: 重复惩罚 en_US: Repetition penalty help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 @@ -67,13 +72,16 @@ parameter_rules: - name: enable_search type: boolean default: false + label: + zh_Hans: 联网搜索 + en_US: Web Search help: zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model decides, based on its internal logic, whether to actually use them.
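Aside: the seed rule above buys best-effort repeatability. A local analogy, with a seeded generator standing in for the service-side sampler:

import numpy as np

def sample_tokens(seed: int, probs: list[float], n: int = 5) -> list[int]:
    # A fixed seed fixes the sampling decisions, so repeated runs agree.
    rng = np.random.default_rng(seed)
    return [int(rng.choice(len(probs), p=probs)) for _ in range(n)]

probs = [0.5, 0.3, 0.2]
assert sample_tokens(1234, probs) == sample_tokens(1234, probs)  # repeatable
print(sample_tokens(1234, probs))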
- name: response_format use_template: response_format pricing: - input: '0.002' - output: '0.006' + input: '0.0003' + output: '0.0006' unit: '0.001' currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0201.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0201.yaml new file mode 100644 index 00000000000000..d80168ffc3fb55 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0201.yaml @@ -0,0 +1,49 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-vl-max-0201 +label: + en_US: qwen-vl-max-0201 +model_type: llm +features: + - vision + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
+ - name: response_format + use_template: response_format +pricing: + input: '0.02' + output: '0.02' + unit: '0.001' + currency: RMB +deprecated: true diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml new file mode 100644 index 00000000000000..50e10226a5f5c4 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml @@ -0,0 +1,77 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-vl-max-0809 +label: + en_US: qwen-vl-max-0809 +model_type: llm +features: + - vision + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: max_tokens + required: false + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output.
Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: response_format + use_template: response_format + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. +pricing: + input: '0.02' + output: '0.02' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml index f917ccaa5d8577..21b127f56c47d9 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml @@ -1,3 +1,4 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models model: qwen-vl-max label: en_US: qwen-vl-max @@ -7,8 +8,17 @@ features: - agent-thought model_properties: mode: chat - context_size: 8192 + context_size: 32000 parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. - name: top_p use_template: top_p type: float @@ -28,6 +38,16 @@ help: zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output. + - name: max_tokens + required: false + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. - name: seed required: false type: int @@ -40,6 +60,16 @@ en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234.
When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. - name: response_format use_template: response_format + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. pricing: input: '0.02' output: '0.02' diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0201.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0201.yaml new file mode 100644 index 00000000000000..03cb039d15a7dd --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0201.yaml @@ -0,0 +1,77 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-vl-plus-0201 +label: + en_US: qwen-vl-plus-0201 +model_type: llm +features: + - vision + - agent-thought +model_properties: + mode: chat + context_size: 8000 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the more random the output; the lower the value, the more deterministic the output. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the more random the output; the smaller the value, the more deterministic the output.
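Aside: a consumer of these YAML files is expected to enforce each rule's default/min/max. An illustrative loader-side check (PyYAML; qwen-vl-plus-0201.yaml read as a local copy of the file added above; this is not Dify's actual loader):

import yaml  # pip install pyyaml

def resolve_parameter(rules: list[dict], name: str, value=None):
    # Fall back to the rule's default, then clamp to [min, max] when present.
    rule = next(r for r in rules if r["name"] == name)
    if value is None:
        return rule.get("default")
    if "min" in rule:
        value = max(rule["min"], value)
    if "max" in rule:
        value = min(rule["max"], value)
    return value

with open("qwen-vl-plus-0201.yaml") as f:
    config = yaml.safe_load(f)

print(resolve_parameter(config["parameter_rules"], "top_k", 500))   # clamped to 99
print(resolve_parameter(config["parameter_rules"], "temperature"))  # default 0.3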
+ - name: max_tokens + required: false + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: response_format + use_template: response_format + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the degree of repetition in generated text. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. +pricing: + input: '0.02' + output: '0.02' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml new file mode 100644 index 00000000000000..67b2b2ebddc616 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml @@ -0,0 +1,77 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen-vl-plus-0809 +label: + en_US: qwen-vl-plus-0809 +model_type: llm +features: + - vision + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected, and the generated results are more deterministic. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the nucleus sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0).
+ - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, at 50 only the 50 highest-scoring tokens of a single generation step form the sampling candidate set. Larger values make the output more random, smaller values make it more deterministic. + - name: max_tokens + required: false + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: Specifies the maximum number of tokens the model may generate. It is an upper bound on generation, not a guaranteed length. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default is 1234. With a fixed seed the model will try to produce the same or similar results, but identical results are currently not guaranteed. + - name: response_format + use_template: response_format + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Controls the degree of repetition in generated content. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. +pricing: + input: '0.008' + output: '0.008' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml index e2dd8c4e576a2c..f55764c6c05500 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml @@ -1,3 +1,4 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models model: qwen-vl-plus label: en_US: qwen-vl-plus @@ -7,8 +8,17 @@ features: - agent-thought model_properties: mode: chat - context_size: 32768 + context_size: 8000 parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls how much the probability distribution over candidate words is smoothed during generation. A higher temperature flattens the distribution, allowing more low-probability words to be selected and making the output more diverse; a lower temperature sharpens the distribution, making high-probability words more likely to be selected and the output more deterministic.
- name: top_p use_template: top_p type: float @@ -28,6 +38,16 @@ parameter_rules: help: zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: max_tokens + required: false + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: Specifies the maximum number of tokens the model may generate. It is an upper bound on generation, not a guaranteed length. - name: seed required: false type: int @@ -40,6 +60,16 @@ parameter_rules: en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. - name: response_format use_template: response_format + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Controls the degree of repetition in generated content. Increasing repetition_penalty reduces repetition; 1.0 means no penalty. pricing: input: '0.008' output: '0.008' diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-1.5b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-1.5b-instruct.yaml new file mode 100644 index 00000000000000..ea157f42ded914 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-1.5b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2-math-1.5b-instruct +label: + en_US: qwen2-math-1.5b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls how much the probability distribution over candidate words is smoothed during generation. A higher temperature flattens the distribution, allowing more low-probability words to be selected and making the output more diverse; a lower temperature sharpens the distribution, making high-probability words more likely to be selected and the output more deterministic.
+ - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: Specifies the maximum number of tokens the model may generate. It is an upper bound on generation, not a guaranteed length. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold for nucleus sampling during generation. For example, at 0.8 only the smallest set of the most likely tokens whose probabilities sum to at least 0.8 is kept as the candidate set. The value range is (0, 1.0); larger values make the output more random, lower values make it more deterministic. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the candidate set sampled from during generation. For example, at 50 only the 50 highest-scoring tokens of a single generation step form the sampling candidate set. Larger values make the output more random, smaller values make it more deterministic. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random seed used during generation, which lets the user control the randomness of the model's output. Supports unsigned 64-bit integers; the default is 1234. With a fixed seed the model will try to produce the same or similar results, but identical results are currently not guaranteed. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Controls the degree of repetition in generated content. Increasing repetition_penalty reduces repetition; 1.0 means no penalty.
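The sampling rules in these files (temperature, top_p, top_k, max_tokens, seed, repetition_penalty) map onto keyword arguments of the DashScope SDK that this provider wraps. As a rough sketch of that mapping, not part of the patch itself (the model name, prompt, and key are placeholders, and the values shown are simply the YAML defaults):

import dashscope
from dashscope import Generation

dashscope.api_key = "sk-..."  # placeholder credential

# Each keyword below mirrors one of the parameter rules declared above.
response = Generation.call(
    model="qwen2-math-1.5b-instruct",
    messages=[{"role": "user", "content": "Factor 91."}],
    result_format="message",
    temperature=0.3,         # 0.0-2.0; higher flattens the distribution, more diverse output
    top_p=0.8,               # nucleus sampling threshold, range (0, 1.0)
    top_k=50,                # sample only from the 50 highest-scoring tokens
    max_tokens=2000,         # upper bound on generated tokens, not a target
    seed=1234,               # best-effort reproducibility only
    repetition_penalty=1.1,  # 1.0 disables the penalty
)
print(response.output.choices[0].message.content)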
+ - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-72b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-72b-instruct.yaml new file mode 100644 index 00000000000000..37052a923317d9 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-72b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2-math-72b-instruct +label: + en_US: qwen2-math-72b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. 
When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-7b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-7b-instruct.yaml new file mode 100644 index 00000000000000..e182f1c27f7ea9 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2-math-7b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2-math-7b-instruct +label: + en_US: qwen2-math-7b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 2000 + min: 1 + max: 2000 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. 
For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-0.5b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-0.5b-instruct.yaml new file mode 100644 index 00000000000000..9e75ccc1f210d9 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-0.5b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2.5-0.5b-instruct +label: + en_US: qwen2.5-0.5b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. 
For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.000' + output: '0.000' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-1.5b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-1.5b-instruct.yaml new file mode 100644 index 00000000000000..67c9d312432af7 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-1.5b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2.5-1.5b-instruct +label: + en_US: qwen2.5-1.5b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. 
+ - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
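A note on the pricing blocks that close each of these files: in the model runtime the listed price is charged per unit of tokens, so with unit '0.001' the figures read as RMB per 1,000 tokens. A back-of-envelope check under that assumption, using the qwen2.5-14b-instruct input rate of '0.002' that appears further down:

from decimal import Decimal

# Assumed billing convention: cost = tokens * price * unit,
# i.e. the price applies per 1,000 tokens when unit is '0.001'.
tokens = 12_345
price = Decimal("0.002")  # RMB, qwen2.5-14b-instruct input rate
unit = Decimal("0.001")
print(f"{tokens * price * unit} RMB")  # 0.024690 RMB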
+ - name: response_format + use_template: response_format +pricing: + input: '0.000' + output: '0.000' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-14b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-14b-instruct.yaml new file mode 100644 index 00000000000000..2a38be921cf3fd --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-14b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2.5-14b-instruct +label: + en_US: qwen2.5-14b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. 
When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.002' + output: '0.006' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-32b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-32b-instruct.yaml new file mode 100644 index 00000000000000..e6e4fbf97808be --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-32b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2.5-32b-instruct +label: + en_US: qwen2.5-32b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. 
For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.0035' + output: '0.007' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-3b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-3b-instruct.yaml new file mode 100644 index 00000000000000..8f250379a788ab --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-3b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2.5-3b-instruct +label: + en_US: qwen2.5-3b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. 
For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.000' + output: '0.000' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-72b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-72b-instruct.yaml new file mode 100644 index 00000000000000..bb3cdd6141f1ea --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-72b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2.5-72b-instruct +label: + en_US: qwen2.5-72b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. 
+ - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.004' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-7b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-7b-instruct.yaml new file mode 100644 index 00000000000000..fdcd3d42757edb --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-7b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2.5-7b-instruct +label: + en_US: qwen2.5-7b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. 
Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. 
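Several of the help strings stress that seed offers only best-effort determinism. A quick illustrative way to see what that means in practice (again a sketch, not part of the patch; the model name is an example):

from dashscope import Generation

def ask(seed: int) -> str:
    rsp = Generation.call(
        model="qwen2.5-7b-instruct",
        messages=[{"role": "user", "content": "Name three prime numbers."}],
        result_format="message",
        seed=seed,
        temperature=0.3,
    )
    return rsp.output.choices[0].message.content

# Matching seeds make matching outputs likely, but as the help text says,
# exact equality is not guaranteed.
print(ask(1234) == ask(1234))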
+ - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.002' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-coder-7b-instruct.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-coder-7b-instruct.yaml new file mode 100644 index 00000000000000..7ebeec395393c7 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen2.5-coder-7b-instruct.yaml @@ -0,0 +1,75 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models +model: qwen2.5-coder-7b-instruct +label: + en_US: qwen2.5-coder-7b-instruct +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 用于控制随机性和多样性的程度。具体来说,temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度。较高的temperature值会降低概率分布的峰值,使得更多的低概率词被选择,生成结果更加多样化;而较低的temperature值则会增强概率分布的峰值,使得高概率词更容易被选择,生成结果更加确定。 + en_US: Used to control the degree of randomness and diversity. Specifically, the temperature value controls the degree to which the probability distribution of each candidate word is smoothed when generating text. A higher temperature value will reduce the peak value of the probability distribution, allowing more low-probability words to be selected, and the generated results will be more diverse; while a lower temperature value will enhance the peak value of the probability distribution, making it easier for high-probability words to be selected. , the generated results are more certain. + - name: max_tokens + use_template: max_tokens + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 用于指定模型在生成内容时token的最大数量,它定义了生成的上限,但不保证每次都会生成到这个数量。 + en_US: It is used to specify the maximum number of tokens when the model generates content. It defines the upper limit of generation, but does not guarantee that this number will be generated every time. + - name: top_p + use_template: top_p + type: float + default: 0.8 + min: 0.1 + max: 0.9 + help: + zh_Hans: 生成过程中核采样方法概率阈值,例如,取值为0.8时,仅保留概率加起来大于等于0.8的最可能token的最小集合作为候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的确定性越高。 + en_US: The probability threshold of the kernel sampling method during the generation process. For example, when the value is 0.8, only the smallest set of the most likely tokens with a sum of probabilities greater than or equal to 0.8 is retained as the candidate set. The value range is (0,1.0). The larger the value, the higher the randomness generated; the lower the value, the higher the certainty generated. + - name: top_k + type: int + min: 0 + max: 99 + label: + zh_Hans: 取样数量 + en_US: Top k + help: + zh_Hans: 生成时,采样候选集的大小。例如,取值为50时,仅将单次生成中得分最高的50个token组成随机采样的候选集。取值越大,生成的随机性越高;取值越小,生成的确定性越高。 + en_US: The size of the sample candidate set when generated. For example, when the value is 50, only the 50 highest-scoring tokens in a single generation form a randomly sampled candidate set. The larger the value, the higher the randomness generated; the smaller the value, the higher the certainty generated. + - name: seed + required: false + type: int + default: 1234 + label: + zh_Hans: 随机种子 + en_US: Random seed + help: + zh_Hans: 生成时使用的随机数种子,用户控制模型生成内容的随机性。支持无符号64位整数,默认值为 1234。在使用seed时,模型将尽可能生成相同或相似的结果,但目前不保证每次生成的结果完全相同。 + en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. 
Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time. + - name: repetition_penalty + required: false + type: float + default: 1.1 + label: + zh_Hans: 重复惩罚 + en_US: Repetition penalty + help: + zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 + en_US: Used to control the repeatability when generating models. Increasing repetition_penalty can reduce the duplication of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.002' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/__init__.py b/api/core/model_runtime/model_providers/tongyi/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml b/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml new file mode 100644 index 00000000000000..439afda99263ad --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml @@ -0,0 +1 @@ +- gte-rerank diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml b/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml new file mode 100644 index 00000000000000..44d51b9b0d9cdd --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml @@ -0,0 +1,4 @@ +model: gte-rerank +model_type: rerank +model_properties: + context_size: 4000 diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py new file mode 100644 index 00000000000000..c9245bd82ddb08 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py @@ -0,0 +1,136 @@ +from typing import Optional + +import dashscope +from dashscope.common.error import ( + AuthenticationError, + InvalidParameter, + RequestFailure, + ServiceUnavailableError, + UnsupportedHTTPMethod, + UnsupportedModel, +) + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class GTERerankModel(RerankModel): + """ + Model class for GTE rerank model. 
+ """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=docs) + + # initialize client + dashscope.api_key = credentials["dashscope_api_key"] + + response = dashscope.TextReRank.call( + query=query, + documents=docs, + model=model, + top_n=top_n, + return_documents=True, + ) + + rerank_documents = [] + for _, result in enumerate(response.output.results): + # format document + rerank_document = RerankDocument( + index=result.index, + score=result.relevance_score, + text=result["document"]["text"], + ) + + # score threshold check + if score_threshold is not None: + if result.relevance_score >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self.invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + print(ex) + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + RequestFailure, + ], + InvokeServerUnavailableError: [ + ServiceUnavailableError, + ], + InvokeRateLimitError: [], + InvokeAuthorizationError: [ + AuthenticationError, + ], + InvokeBadRequestError: [ + InvalidParameter, + UnsupportedModel, + UnsupportedHTTPMethod, + ], + } diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml index f4303c53d38b80..52e35d8b50afd8 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v1.yaml @@ -1,3 +1,4 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models#3383780daf8hw model: text-embedding-v1 model_type: text-embedding model_properties: diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml index f6be3544ed8f65..5bb6a8f4243d53 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v2.yaml @@ -1,3 +1,4 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models#3383780daf8hw model: text-embedding-v2 model_type: text-embedding model_properties: diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml new file mode 100644 index 00000000000000..d8af0e2b63565d --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text-embedding-v3.yaml @@ -0,0 +1,10 @@ +# for more details, please refer to https://help.aliyun.com/zh/model-studio/getting-started/models#3383780daf8hw +model: text-embedding-v3 +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 25 +pricing: + input: "0.0007" + unit: "0.001" + currency: RMB diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index e7e1b5c764c093..2ef7f3f5774481 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -4,6 +4,7 @@ import dashscope import numpy as np +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, @@ -27,6 +28,7 @@ def _invoke( credentials: dict, texts: list[str], user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -35,6 +37,7 @@ def _invoke( :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) @@ -46,7 +49,6 @@ def _invoke( used_tokens = 0 for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer num_tokens = self._get_num_tokens_by_gpt2(text) @@ -71,12 +73,8 @@ def 
_invoke( batched_embeddings += embeddings_batch # calc usage - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=used_tokens - ) - return TextEmbeddingResult( - embeddings=batched_embeddings, usage=usage, model=model - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) + return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -108,16 +106,12 @@ def validate_credentials(self, model: str, credentials: dict) -> None: credentials_kwargs = self._to_credential_kwargs(credentials) # call embedding model - self.embed_documents( - credentials_kwargs=credentials_kwargs, model=model, texts=["ping"] - ) + self.embed_documents(credentials_kwargs=credentials_kwargs, model=model, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @staticmethod - def embed_documents( - credentials_kwargs: dict, model: str, texts: list[str] - ) -> tuple[list[list[float]], int]: + def embed_documents(credentials_kwargs: dict, model: str, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to Tongyi's embedding endpoint. Args: @@ -137,15 +131,23 @@ def embed_documents( input=text, text_type="document", ) - data = response.output["embeddings"][0] - embeddings.append(data["embedding"]) - embedding_used_tokens += response.usage["total_tokens"] + if response.output and "embeddings" in response.output and response.output["embeddings"]: + data = response.output["embeddings"][0] + if "embedding" in data: + embeddings.append(data["embedding"]) + else: + raise ValueError("Embedding data is missing in the response.") + else: + raise ValueError("Response output is missing or does not contain embeddings.") + + if response.usage and "total_tokens" in response.usage: + embedding_used_tokens += response.usage["total_tokens"] + else: + raise ValueError("Response usage is missing or does not contain total tokens.") return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def _calc_response_usage( - self, model: str, credentials: dict, tokens: int - ) -> EmbeddingUsage: + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.py b/api/core/model_runtime/model_providers/tongyi/tongyi.py index d5e25e6ecf87ae..a084512de9a885 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.py +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `qwen-turbo` model for validate, - model_instance.validate_credentials( - model='qwen-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="qwen-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.yaml b/api/core/model_runtime/model_providers/tongyi/tongyi.yaml index b251391e3410dc..6349c227142656 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.yaml +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.yaml 
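The hardened embed_documents above now rejects malformed DashScope payloads with explicit ValueError messages instead of an opaque KeyError or IndexError. A minimal sketch of that parsing pattern in isolation — FakeResponse is an illustrative stand-in for the object returned by dashscope.TextEmbedding.call(), not code from this PR:

from dataclasses import dataclass, field

@dataclass
class FakeResponse:
    # Illustrative stand-in for a DashScope embedding response.
    output: dict = field(default_factory=dict)
    usage: dict = field(default_factory=dict)

def parse_embedding_response(response: FakeResponse) -> tuple[list[float], int]:
    # Mirrors the checks in the diff: validate each layer of the payload and
    # fail loudly with a descriptive message instead of a bare lookup error.
    if response.output and "embeddings" in response.output and response.output["embeddings"]:
        data = response.output["embeddings"][0]
        if "embedding" not in data:
            raise ValueError("Embedding data is missing in the response.")
    else:
        raise ValueError("Response output is missing or does not contain embeddings.")
    if not (response.usage and "total_tokens" in response.usage):
        raise ValueError("Response usage is missing or does not contain total tokens.")
    return data["embedding"], response.usage["total_tokens"]

ok = FakeResponse(output={"embeddings": [{"embedding": [0.1, 0.2]}]}, usage={"total_tokens": 3})
assert parse_embedding_response(ok) == ([0.1, 0.2], 3)
# FakeResponse(output={}) would raise: "Response output is missing or does not contain embeddings."

Because validate_credentials calls embed_documents with a "ping" text, these descriptive errors also make the CredentialsValidateFailedError raised during provider setup far easier to act on.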
@@ -11,15 +11,17 @@ background: "#EFF1FE" help: title: en_US: Get your API key from AliCloud - zh_Hans: 从阿里云获取 API Key + zh_Hans: 从阿里云百炼获取 API Key url: - en_US: https://dashscope.console.aliyun.com/api-key_management + en_US: https://bailian.console.aliyun.com/?apiKey=1#/api-key supported_model_types: - llm - tts - text-embedding + - rerank configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: dashscope_api_key @@ -30,3 +32,57 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: dashscope_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + show_on: + - variable: __model_type + value: llm + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not Supported + zh_Hans: 不支持 + - value: function_call + label: + en_US: Supported + zh_Hans: 支持 + show_on: + - variable: __model_type + value: llm diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index 664b02cd92fc09..ca3b9fbc1c3c00 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -1,6 +1,6 @@ import threading from queue import Queue -from typing import Optional +from typing import Any, Optional import dashscope from dashscope import SpeechSynthesizer @@ -18,8 +18,9 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): Model class for Tongyi Speech to text model.
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None) -> any: + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> Any: """ _invoke text2speech model @@ -31,14 +32,12 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: s :param user: unique user id :return: text translated to audio file """ - if not voice or voice not in [d['value'] for d in - self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: voice = self._get_model_default_voice(model, credentials) - return self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice) + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -53,14 +52,13 @@ def validate_credentials(self, model: str, credentials: dict, user: Optional[str self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: """ _tts_invoke_streaming text2speech model @@ -82,15 +80,21 @@ def invoke_remote(content, v, api_key, cb, at, wl): else: sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl)) for sentence in sentences: - SpeechSynthesizer.call(model=v, sample_rate=16000, - api_key=api_key, - text=sentence.strip(), - callback=cb, - format=at, word_timestamp_enabled=True, - phoneme_timestamp_enabled=True) - - threading.Thread(target=invoke_remote, args=( - content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start() + SpeechSynthesizer.call( + model=v, + sample_rate=16000, + api_key=api_key, + text=sentence.strip(), + callback=cb, + format=at, + word_timestamp_enabled=True, + phoneme_timestamp_enabled=True, + ) + + threading.Thread( + target=invoke_remote, + args=(content_text, voice, credentials.get("dashscope_api_key"), callback, audio_type, word_limit), + ).start() while True: audio = audio_queue.get() @@ -112,16 +116,18 @@ def _process_sentence(sentence: str, credentials: dict, voice: str, audio_type: :param audio_type: audio file type :return: text translated to audio file """ - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, - api_key=credentials.get('dashscope_api_key'), - text=sentence.strip(), - format=audio_type) + response = dashscope.audio.tts.SpeechSynthesizer.call( + model=voice, + sample_rate=48000, + api_key=credentials.get("dashscope_api_key"), + text=sentence.strip(), + format=audio_type, + ) if isinstance(response.get_audio_data(), bytes): return response.get_audio_data() class Callback(ResultCallback): - def __init__(self, queue: Queue): self._queue = queue diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py index 
95272a41c2e1a8..47a4b992146405 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from typing import Optional from httpx import Response, post from yarl import URL @@ -33,198 +34,223 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials + validate credentials """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") + try: - self._invoke(model=model, credentials=credentials, prompt_messages=[ - UserPromptMessage(content='ping') - ], model_parameters={}, stream=False) + self._invoke( + model=model, + credentials=credentials, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={}, + stream=False, + ) except InvokeError as ex: - raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}') + raise CredentialsValidateFailedError(f"An error occurred during connection: {str(ex)}") - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause TritonInference LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + cause TritonInference LLM is a customized model, we could not detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) - + def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: if isinstance(item, UserPromptMessage): - text += f'User: {item.content}' + text += f"User: {item.content}" elif isinstance(item, SystemPromptMessage): - text +=
f'System: {item.content}' + text += f"System: {item.content}" elif isinstance(item, AssistantPromptMessage): - text += f'Assistant: {item.content}' + text += f"Assistant: {item.content}" else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=int(credentials.get('context_length', 2048)), - default=min(512, int(credentials.get('context_length', 2048))), - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) - ) + max=int(credentials.get("context_length", 2048)), + default=min(512, int(credentials.get("context_length", 2048))), + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') - + entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), parameter_rules=rules, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_length', 2048)), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_length", 2048)), }, ) return entity - - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') - - if 'stream' in credentials and not bool(credentials['stream']) and stream: - raise ValueError(f'stream is not supported by model {model}') + if "server_url" not in credentials: + raise 
CredentialsValidateFailedError("server_url is required in credentials") + + if "stream" in credentials and not bool(credentials["stream"]) and stream: + raise ValueError(f"stream is not supported by model {model}") try: parameters = {} - if 'temperature' in model_parameters: - parameters['temperature'] = model_parameters['temperature'] - if 'top_p' in model_parameters: - parameters['top_p'] = model_parameters['top_p'] - if 'top_k' in model_parameters: - parameters['top_k'] = model_parameters['top_k'] - if 'presence_penalty' in model_parameters: - parameters['presence_penalty'] = model_parameters['presence_penalty'] - if 'frequency_penalty' in model_parameters: - parameters['frequency_penalty'] = model_parameters['frequency_penalty'] + if "temperature" in model_parameters: + parameters["temperature"] = model_parameters["temperature"] + if "top_p" in model_parameters: + parameters["top_p"] = model_parameters["top_p"] + if "top_k" in model_parameters: + parameters["top_k"] = model_parameters["top_k"] + if "presence_penalty" in model_parameters: + parameters["presence_penalty"] = model_parameters["presence_penalty"] + if "frequency_penalty" in model_parameters: + parameters["frequency_penalty"] = model_parameters["frequency_penalty"] - response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={ - 'text_input': self._convert_prompt_message_to_text(prompt_messages), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'parameters': { - 'stream': False, - **parameters + response = post( + str(URL(credentials["server_url"]) / "v2" / "models" / model / "generate"), + json={ + "text_input": self._convert_prompt_message_to_text(prompt_messages), + "max_tokens": model_parameters.get("max_tokens", 512), + "parameters": {"stream": False, **parameters}, }, - }, timeout=(10, 120)) + timeout=(10, 120), + ) response.raise_for_status() if response.status_code != 200: - raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}') - + raise InvokeBadRequestError(f"Invoke failed with status code {response.status_code}, {response.text}") + if stream: - return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=response) - return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages, - tools=tools, resp=response) + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) + return self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=response + ) except Exception as ex: - raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}') - - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Response) -> LLMResult: + raise InvokeConnectionError(f"An error occurred during connection: {str(ex)}") + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response, + ) -> LLMResult: """ - handle normal chat generate response + handle normal chat generate response """ - text = resp.json()['text_output'] + text = resp.json()["text_output"] usage = LLMUsage.empty_usage() usage.prompt_tokens = self.get_num_tokens(model, credentials, 
prompt_messages) usage.completion_tokens = self._get_num_tokens_by_gpt2(text) return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text - ), - usage=usage + model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage ) - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Response) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Response, + ) -> Generator: """ - handle normal chat generate response + handle normal chat generate response """ - text = resp.json()['text_output'] + text = resp.json()["text_output"] usage = LLMUsage.empty_usage() usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -233,13 +259,7 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage( - content=text - ), - usage=usage - ) + delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text), usage=usage), ) @property @@ -253,15 +273,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - ], - InvokeRateLimitError: [ - ], - InvokeAuthorizationError: [ - ], - InvokeBadRequestError: [ - ValueError - ] - } \ No newline at end of file + InvokeConnectionError: [], + InvokeServerUnavailableError: [], + InvokeRateLimitError: [], + InvokeAuthorizationError: [], + InvokeBadRequestError: [ValueError], + } diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py index 06846825ab6e35..d85f7c82e7db71 100644 --- a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py +++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py @@ -4,6 +4,7 @@ logger = logging.getLogger(__name__) + class XinferenceAIProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: pass diff --git a/api/core/model_runtime/model_providers/upstage/_common.py b/api/core/model_runtime/model_providers/upstage/_common.py index 13b73181e95ffb..47ebaccd84ab8a 100644 --- a/api/core/model_runtime/model_providers/upstage/_common.py +++ b/api/core/model_runtime/model_providers/upstage/_common.py @@ -1,4 +1,3 @@ - from collections.abc import Mapping import openai @@ -20,13 +19,13 @@ def _to_credential_kwargs(self, credentials: Mapping) -> dict: Transform credentials to kwargs for model instance :param credentials: - :return: + :return: """ credentials_kwargs = { - "api_key": credentials['upstage_api_key'], + "api_key": credentials["upstage_api_key"], "base_url": "https://api.upstage.ai/v1/solar", "timeout": Timeout(315.0, read=300.0, write=20.0, connect=10.0), - "max_retries": 1 + "max_retries": 1, } return credentials_kwargs @@ -53,5 +52,3 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] openai.APIError, ], } - - diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py 
b/api/core/model_runtime/model_providers/upstage/llm/llm.py index d1ed4619d6bbbf..a18ee906248a49 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -34,17 +34,25 @@ {{instructions}} -""" +""" # noqa: E501 + class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel): """ - Model class for Upstage large language model. + Model class for Upstage large language model. """ - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -67,15 +75,25 @@ def _invoke(self, model: str, credentials: dict, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _code_block_mode_wrapper(self, - model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] self._transform_chat_json_prompts( model=model, @@ -86,9 +104,9 @@ def _code_block_mode_wrapper(self, stop=stop, stream=stream, user=user, - response_format=model_parameters['response_format'] + response_format=model_parameters["response_format"], ) - model_parameters.pop('response_format') + model_parameters.pop("response_format") return self._invoke( model=model, @@ -98,15 +116,23 @@ def _code_block_mode_wrapper(self, tools=tools, stop=stop, stream=stream, - user=user + user=user, ) - def _transform_chat_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None, response_format: str = 'JSON') -> None: + def _transform_chat_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ - Transform json prompts + Transform json prompts """ if stop is None: stop = [] @@ -117,20 +143,29 @@ def _transform_chat_json_prompts(self, model: str, credentials: dict, if len(prompt_messages) > 0 and 
isinstance(prompt_messages[0], SystemPromptMessage): prompt_messages[0] = SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=UPSTAGE_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) else: - prompt_messages.insert(0, SystemPromptMessage( - content=UPSTAGE_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=UPSTAGE_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -155,30 +190,31 @@ def validate_credentials(self, model: str, credentials: dict) -> None: client = OpenAI(**credentials_kwargs) client.chat.completions.create( - messages=[{"role": "user", "content": "ping"}], - model=model, - temperature=0, - max_tokens=10, - stream=False + messages=[{"role": "user", "content": "ping"}], model=model, temperature=0, max_tokens=10, stream=False ) except Exception as e: raise CredentialsValidateFailedError(str(e)) - def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _chat_generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) extra_model_kwargs = {} if tools: - extra_model_kwargs["functions"] = [{ - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters - } for tool in tools] + extra_model_kwargs["functions"] = [ + {"name": tool.name, "description": tool.description, "parameters": tool.parameters} for tool in tools + ] if stop: extra_model_kwargs["stop"] = stop @@ -198,10 +234,15 @@ def _chat_generate(self, model: str, credentials: dict, if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - def _handle_chat_generate_response(self, model: str, credentials: dict, response: ChatCompletion, - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> LLMResult: + + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + response: ChatCompletion, + prompt_messages: list[PromptMessage], + tools: 
Optional[list[PromptMessageTool]] = None, + ) -> LLMResult: """ Handle llm chat response @@ -222,10 +263,7 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response tool_calls = [function_call] if function_call else [] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls) # calculate num tokens if response.usage: @@ -251,9 +289,14 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, response return response - def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: Stream[ChatCompletionChunk], - prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + credentials: dict, + response: Stream[ChatCompletionChunk], + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> Generator: """ Handle llm chat stream response @@ -263,7 +306,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r :param tools: tools for tool calling :return: llm response chunk generator """ - full_assistant_content = '' + full_assistant_content = "" delta_assistant_message_function_call_storage: Optional[ChoiceDeltaFunctionCall] = None prompt_tokens = 0 completion_tokens = 0 @@ -273,8 +316,8 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage(content=''), - ) + message=AssistantPromptMessage(content=""), + ), ) for chunk in response: @@ -288,8 +331,11 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta = chunk.choices[0] has_finish_reason = delta.finish_reason is not None - if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ - delta.delta.function_call is None: + if ( + not has_finish_reason + and (delta.delta.content is None or delta.delta.content == "") + and delta.delta.function_call is None + ): continue # assistant_message_tool_calls = delta.delta.tool_calls @@ -311,7 +357,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call if delta_assistant_message_function_call_storage.arguments is None: - delta_assistant_message_function_call_storage.arguments = '' + delta_assistant_message_function_call_storage.arguments = "" if not has_finish_reason: continue @@ -322,12 +368,9 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r final_tool_calls.extend(tool_calls) # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=tool_calls - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content or "" if has_finish_reason: final_chunk = LLMResultChunk( @@ -338,7 +381,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r index=delta.index, 
message=assistant_prompt_message, finish_reason=delta.finish_reason, - ) + ), ) else: yield LLMResultChunk( @@ -348,7 +391,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, - ) + ), ) if not prompt_tokens: @@ -356,8 +399,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r if not completion_tokens: full_assistant_prompt_message = AssistantPromptMessage( - content=full_assistant_content, - tool_calls=final_tool_calls + content=full_assistant_content, tool_calls=final_tool_calls ) completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message]) @@ -367,9 +409,9 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r yield final_chunk - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -380,21 +422,19 @@ def _extract_response_tool_calls(self, if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -404,14 +444,11 @@ def _extract_response_function_call(self, response_function_call: FunctionCall | tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call @@ -429,19 +466,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } 
sub_messages.append(sub_message_dict) @@ -467,11 +498,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: # "content": message.content, # "tool_call_id": message.tool_call_id # } - message_dict = { - "role": "function", - "content": message.content, - "name": message.tool_call_id - } + message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id} else: raise ValueError(f"Got unknown type {message}") @@ -483,16 +510,17 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """ Calculate num tokens for solar with Huggingface Solar tokenizer. - Solar tokenizer is opened in huggingface https://huggingface.co/upstage/solar-1-mini-tokenizer + Solar tokenizer is publicly available on Hugging Face: https://huggingface.co/upstage/solar-1-mini-tokenizer """ tokenizer = self._get_tokenizer() - tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> - tokens_prefix = 1 # <|startoftext|> - tokens_suffix = 3 # <|im_start|>assistant\n + tokens_per_message = 5 # <|im_start|>{role}\n{message}<|im_end|> + tokens_prefix = 1 # <|startoftext|> + tokens_suffix = 3 # <|im_start|>assistant\n num_tokens = 0 num_tokens += tokens_prefix @@ -502,10 +530,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text if key == "tool_calls": @@ -538,37 +566,37 @@ def _num_tokens_for_tools(self, tokenizer: Tokenizer, tools: list[PromptMessageT """ num_tokens = 0 for tool in tools: - num_tokens += len(tokenizer.encode('type')) - num_tokens += len(tokenizer.encode('function')) + num_tokens += len(tokenizer.encode("type")) + num_tokens += len(tokenizer.encode("function")) # calculate num tokens for function object - num_tokens += len(tokenizer.encode('name')) + num_tokens += len(tokenizer.encode("name")) num_tokens += len(tokenizer.encode(tool.name)) - num_tokens += len(tokenizer.encode('description')) + num_tokens += len(tokenizer.encode("description")) num_tokens += len(tokenizer.encode(tool.description)) parameters = tool.parameters - num_tokens += len(tokenizer.encode('parameters')) - if 'title' in parameters: - num_tokens += len(tokenizer.encode('title')) + num_tokens += len(tokenizer.encode("parameters")) + if "title" in parameters: + num_tokens += len(tokenizer.encode("title")) num_tokens += len(tokenizer.encode(parameters.get("title"))) - num_tokens += len(tokenizer.encode('type')) + num_tokens += len(tokenizer.encode("type")) num_tokens += len(tokenizer.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(tokenizer.encode('properties')) - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += len(tokenizer.encode("properties")) + for key, value in parameters.get("properties").items(): num_tokens += len(tokenizer.encode(key)) for field_key,
field_value in value.items(): num_tokens += len(tokenizer.encode(field_key)) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += len(tokenizer.encode(enum_field)) else: num_tokens += len(tokenizer.encode(field_key)) num_tokens += len(tokenizer.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(tokenizer.encode('required')) - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += len(tokenizer.encode("required")) + for required_field in parameters["required"]: num_tokens += 3 num_tokens += len(tokenizer.encode(required_field)) diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index 05ae8665d65bdd..7dd495b55ef4e6 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -7,6 +7,7 @@ from openai import OpenAI from tokenizers import Tokenizer +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -18,10 +19,18 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel): """ Model class for Upstage text embedding model. """ + def _get_tokenizer(self) -> Tokenizer: return Tokenizer.from_pretrained("upstage/solar-1-mini-tokenizer") - def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | None = None) -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: str | None = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -29,6 +38,7 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | N :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ @@ -53,9 +63,9 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | N for i, text in enumerate(texts): token = tokenizer.encode(text, add_special_tokens=False).tokens for j in range(0, len(token), context_size): - tokens += [token[j:j+context_size]] + tokens += [token[j : j + context_size]] indices += [i] - + batched_embeddings = [] _iter = range(0, len(tokens), max_chunks) @@ -63,20 +73,20 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | N embeddings_batch, embedding_used_tokens = self._embedding_invoke( model=model, client=client, - texts=tokens[i:i+max_chunks], + texts=tokens[i : i + max_chunks], extra_model_kwargs=extra_model_kwargs, ) used_tokens += embedding_used_tokens batched_embeddings += embeddings_batch - + results: list[list[list[float]]] = [[] for _ in range(len(texts))] num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))] for i in range(len(indices)): results[indices[i]].append(batched_embeddings[i]) num_tokens_in_batch[indices[i]].append(len(tokens[i])) - + for i in range(len(texts)): _result = results[i] if len(_result) == 0: @@ -91,15 +101,11 @@ def _invoke(self, model: str, credentials: dict, texts: list[str], user: str | N else: average = 
np.average(_result, axis=0, weights=num_tokens_in_batch[i]) embeddings[i] = (average / np.linalg.norm(average)).tolist() - - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=used_tokens - ) + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model) - + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: tokenizer = self._get_tokenizer() """ @@ -122,7 +128,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int total_num_tokens += len(tokenized_text) return total_num_tokens - + def validate_credentials(self, model: str, credentials: Mapping) -> None: """ Validate model credentials @@ -137,16 +143,13 @@ def validate_credentials(self, model: str, credentials: Mapping) -> None: client = OpenAI(**credentials_kwargs) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'], - extra_model_kwargs={} - ) + self._embedding_invoke(model=model, client=client, texts=["ping"], extra_model_kwargs={}) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict) -> tuple[list[list[float]], int]: + + def _embedding_invoke( + self, model: str, client: OpenAI, texts: Union[list[str], str], extra_model_kwargs: dict + ) -> tuple[list[list[float]], int]: """ Invoke embedding model :param model: model name @@ -155,17 +158,19 @@ def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], :param extra_model_kwargs: extra model kwargs :return: embeddings and used tokens """ - response = client.embeddings.create( - model=model, - input=texts, - **extra_model_kwargs - ) + response = client.embeddings.create(model=model, input=texts, **extra_model_kwargs) + + if "encoding_format" in extra_model_kwargs and extra_model_kwargs["encoding_format"] == "base64": + return ( + [ + list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) + for embedding in response.data + ], + response.usage.total_tokens, + ) - if 'encoding_format' in extra_model_kwargs and extra_model_kwargs['encoding_format'] == 'base64': - return ([list(np.frombuffer(base64.b64decode(embedding.embedding), dtype=np.float32)) for embedding in response.data], response.usage.total_tokens) - return [data.embedding for data in response.data], response.usage.total_tokens - + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -176,10 +181,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em :return: usage """ input_price_info = self.get_price( - model=model, - credentials=credentials, - tokens=tokens, - price_type=PriceType.INPUT + model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT ) usage = EmbeddingUsage( @@ -189,7 +191,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/upstage/upstage.py b/api/core/model_runtime/model_providers/upstage/upstage.py index 56c91c00618922..e45d4aae19eb6c 100644 --- 
a/api/core/model_runtime/model_providers/upstage/upstage.py +++ b/api/core/model_runtime/model_providers/upstage/upstage.py @@ -8,7 +8,6 @@ class UpstageProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -19,14 +18,10 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model="solar-1-mini-chat", - credentials=credentials - ) + model_instance.validate_credentials(model="solar-1-mini-chat", credentials=credentials) except CredentialsValidateFailedError as e: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise e except Exception as e: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise e - diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3.5-sonnet-v2.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3.5-sonnet-v2.yaml new file mode 100644 index 00000000000000..0be3e26e7ad851 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/anthropic.claude-3.5-sonnet-v2.yaml @@ -0,0 +1,55 @@ +model: claude-3-5-sonnet-v2@20241022 +label: + en_US: Claude 3.5 Sonnet v2 +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意，Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中，Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布，并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p，但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # note: the AWS docs contain an error here; the max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记，仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
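The pricing block that closes this file (immediately below) follows what appears to be the model-runtime convention in which `unit` scales the raw token count: with unit '0.001', an input price of '0.003' reads as USD 0.003 per 1,000 tokens. A quick sanity check under that reading — the helper name is mine, not part of this PR:

from decimal import Decimal

def estimate_cost(tokens: int, price: str, unit: str) -> Decimal:
    # tokens * unit converts raw tokens into billing units (thousands here),
    # then the per-unit price is applied.
    return Decimal(tokens) * Decimal(unit) * Decimal(price)

# A full 200K-token prompt against the input price below:
print(estimate_cost(200_000, "0.003", "0.001"))  # Decimal('0.600000') -> USD 0.60

The same arithmetic applies to the text-embedding-v3 pricing added earlier in this diff: 1,000 tokens at input '0.0007' with unit '0.001' comes to 0.0007 RMB.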
+pricing: + input: '0.003' + output: '0.015' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash-001.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash-001.yaml new file mode 100644 index 00000000000000..f5386be06da6be --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash-001.yaml @@ -0,0 +1,37 @@ +model: gemini-1.5-flash-001 +label: + en_US: Gemini 1.5 Flash 001 +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash-002.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash-002.yaml new file mode 100644 index 00000000000000..97bd44f06b5145 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash-002.yaml @@ -0,0 +1,37 @@ +model: gemini-1.5-flash-002 +label: + en_US: Gemini 1.5 Flash 002 +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml deleted file mode 100644 index c308f0a322fddd..00000000000000 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-flash.yaml +++ /dev/null @@ -1,37 +0,0 @@ -model: gemini-1.5-flash-001 -label: - en_US: Gemini 1.5 Flash -model_type: llm -features: - - agent-thought - - vision -model_properties: - mode: chat - context_size: 1048576 -parameter_rules: - - name: temperature - use_template: temperature - - name: top_p - use_template: top_p - - name: top_k - label: - en_US: Top k - type: int - help: - en_US: Only sample from the top K options for each subsequent token. 
- required: false - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - - name: max_output_tokens - use_template: max_tokens - required: true - default: 8192 - min: 1 - max: 8192 -pricing: - input: '0.00' - output: '0.00' - unit: '0.000001' - currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro-001.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro-001.yaml new file mode 100644 index 00000000000000..5e08f2294e2ebf --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro-001.yaml @@ -0,0 +1,37 @@ +model: gemini-1.5-pro-001 +label: + en_US: Gemini 1.5 Pro 001 +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro-002.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro-002.yaml new file mode 100644 index 00000000000000..8f327ea2f3d37e --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro-002.yaml @@ -0,0 +1,37 @@ +model: gemini-1.5-pro-002 +label: + en_US: Gemini 1.5 Pro 002 +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml deleted file mode 100644 index 744863e7731e15..00000000000000 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-1.5-pro.yaml +++ /dev/null @@ -1,37 +0,0 @@ -model: gemini-1.5-pro-001 -label: - en_US: Gemini 1.5 Pro -model_type: llm -features: - - agent-thought - - vision -model_properties: - mode: chat - context_size: 1048576 -parameter_rules: - - name: temperature - use_template: temperature - - name: top_p - use_template: top_p - - name: top_k - label: - en_US: Top k - type: int - help: - en_US: Only sample from the top K options for each subsequent token. 
- required: false - - name: presence_penalty - use_template: presence_penalty - - name: frequency_penalty - use_template: frequency_penalty - - name: max_output_tokens - use_template: max_tokens - required: true - default: 8192 - min: 1 - max: 8192 -pricing: - input: '0.00' - output: '0.00' - unit: '0.000001' - currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-flash-experimental.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-flash-experimental.yaml new file mode 100644 index 00000000000000..0f5eb34c0cdf03 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-flash-experimental.yaml @@ -0,0 +1,37 @@ +model: gemini-flash-experimental +label: + en_US: Gemini Flash Experimental +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-pro-experimental.yaml b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-pro-experimental.yaml new file mode 100644 index 00000000000000..fa31cabb85abb0 --- /dev/null +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/gemini-pro-experimental.yaml @@ -0,0 +1,37 @@ +model: gemini-pro-experimental +label: + en_US: Gemini Pro Experimental +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 1048576 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + en_US: Top k + type: int + help: + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_output_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 8901549110ee07..1469de605525ef 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -1,11 +1,13 @@ import base64 +import io import json import logging +import time from collections.abc import Generator from typing import Optional, Union, cast -import google.api_core.exceptions as exceptions import google.auth.transport.requests +import requests import vertexai.generative_models as glm from anthropic import AnthropicVertex, Stream from anthropic.types import ( @@ -16,9 +18,10 @@ MessageStopEvent, MessageStreamEvent, ) +from google.api_core import exceptions from google.cloud import aiplatform from google.oauth2 import service_account -from vertexai.generative_models import HarmBlockThreshold, HarmCategory +from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( @@ -32,6 +35,7 @@ ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -47,12 +51,17 @@ class VertexAiLargeLanguageModel(LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -72,8 +81,16 @@ def _invoke(self, model: str, credentials: dict, # invoke Gemini model return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate_anthropic( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke Anthropic large language model @@ -90,7 +107,7 @@ def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: li service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) project_id = credentials["vertex_project_id"] SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] - token = '' + token = "" # get access token from service account credential if service_account_info: @@ -99,41 +116,35 @@ def 
_generate_anthropic(self, model: str, credentials: dict, prompt_messages: li credentials.refresh(request) token = credentials.token - # Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available in us-central1 region - if 'opus' or 'claude-3-5-sonnet' in model: - location = 'us-east5' + # Vertex AI Anthropic Claude3 Opus model available in us-east5 region, Sonnet and Haiku available + # in us-central1 region + if "opus" in model or "claude-3-5-sonnet" in model: + location = "us-east5" else: - location = 'us-central1' - + location = "us-central1" + # use access token to authenticate if token: - client = AnthropicVertex( - region=location, - project_id=project_id, - access_token=token - ) - # When access token is empty, try to use the Google Cloud VM's built-in service account or the GOOGLE_APPLICATION_CREDENTIALS environment variable + client = AnthropicVertex(region=location, project_id=project_id, access_token=token) + # When access token is empty, try to use the Google Cloud VM's built-in service account + # or the GOOGLE_APPLICATION_CREDENTIALS environment variable else: client = AnthropicVertex( - region=location, + region=location, project_id=project_id, ) extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs["stop_sequences"] = stop system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages) if system: - extra_model_kwargs['system'] = system + extra_model_kwargs["system"] = system response = client.messages.create( - model=model, - messages=prompt_message_dicts, - stream=stream, - **model_parameters, - **extra_model_kwargs + model=model, messages=prompt_message_dicts, stream=stream, **model_parameters, **extra_model_kwargs ) if stream: @@ -141,8 +152,9 @@ def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: li return self._handle_claude_response(model, credentials, response, prompt_messages) - def _handle_claude_response(self, model: str, credentials: dict, response: Message, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_claude_response( + self, model: str, credentials: dict, response: Message, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm chat response @@ -154,9 +166,7 @@ def _handle_claude_response(self, model: str, credentials: dict, response: Messa """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.content[0].text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.content[0].text) # calculate num tokens if response.usage: @@ -173,16 +183,18 @@ def _handle_claude_response(self, model: str, credentials: dict, response: Messa # transform response response = LLMResult( - model=response.model, - prompt_messages=prompt_messages, - message=assistant_prompt_message, - usage=usage + model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage ) return response - def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent], - prompt_messages: list[PromptMessage], ) -> Generator: + def _handle_claude_stream_response( + self, + model: str, + credentials: dict, + response: Stream[MessageStreamEvent], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm chat stream response @@ -194,7 +206,7 @@ def _handle_claude_stream_response(self, model: str, credentials: dict, response """ try: - full_assistant_content = '' + 
full_assistant_content = "" return_model = None input_tokens = 0 output_tokens = 0 @@ -215,18 +227,16 @@ def _handle_claude_stream_response(self, model: str, credentials: dict, response prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=index + 1, - message=AssistantPromptMessage( - content='' - ), + message=AssistantPromptMessage(content=""), finish_reason=finish_reason, - usage=usage - ) + usage=usage, + ), ) elif isinstance(chunk, ContentBlockDeltaEvent): - chunk_text = chunk.delta.text if chunk.delta.text else '' + chunk_text = chunk.delta.text or "" full_assistant_content += chunk_text assistant_prompt_message = AssistantPromptMessage( - content=chunk_text if chunk_text else '', + content=chunk_text or "", ) index = chunk.index yield LLMResultChunk( @@ -235,12 +245,14 @@ def _handle_claude_stream_response(self, model: str, credentials: dict, response delta=LLMResultChunkDelta( index=index, message=assistant_prompt_message, - ) + ), ) except Exception as ex: raise InvokeError(str(ex)) - def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage: + def _calc_claude_response_usage( + self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int + ) -> LLMUsage: """ Calculate response usage @@ -260,10 +272,7 @@ def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_toke # get completion price info completion_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.OUTPUT, - tokens=completion_tokens + model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens ) # transform usage @@ -279,7 +288,7 @@ def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_toke total_tokens=prompt_tokens + completion_tokens, total_price=prompt_price_info.total_amount + completion_price_info.total_amount, currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage @@ -293,13 +302,13 @@ def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) first_loop = True for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content=message.content.strip() + message.content = message.content.strip() if first_loop: - system=message.content - first_loop=False + system = message.content + first_loop = False else: - system+="\n" - system+=message.content + system += "\n" + system += message.content prompt_message_dicts = [] for message in prompt_messages: @@ -321,10 +330,7 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(TextPromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text": message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) @@ -332,8 +338,9 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict # fetch image data from url try: image_content = requests.get(message_content.data).content - mime_type, _ = mimetypes.guess_type(message_content.data) - base64_data = base64.b64encode(image_content).decode('utf-8') + with Image.open(io.BytesIO(image_content)) as 
img: + mime_type = f"image/{img.format.lower()}" + base64_data = base64.b64encode(image_content).decode("utf-8") except Exception as ex: raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") else: @@ -341,17 +348,15 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: - raise ValueError(f"Unsupported image type {mime_type}, " - f"only support image/jpeg, image/png, image/gif, and image/webp") + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: + raise ValueError( + f"Unsupported image type {mime_type}, " + f"only support image/jpeg, image/png, image/gif, and image/webp" + ) sub_message_dict = { "type": "image", - "source": { - "type": "base64", - "media_type": mime_type, - "data": base64_data - } + "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) @@ -367,8 +372,13 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict return message_dict - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given prompt messages @@ -381,7 +391,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr prompt = self._convert_messages_to_prompt(prompt_messages) return self._get_num_tokens_by_gpt2(prompt) - + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Google model @@ -391,13 +401,10 @@ def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) return text.rstrip() - + def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool: """ Convert tool messages to glm tools @@ -413,14 +420,16 @@ def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool type=glm.Type.OBJECT, properties={ key: { - 'type_': value.get('type', 'string').upper(), - 'description': value.get('description', ''), - 'enum': value.get('enum', []) - } for key, value in tool.parameters.get('properties', {}).items() + "type_": value.get("type", "string").upper(), + "description": value.get("description", ""), + "enum": value.get("enum", []), + } + for key, value in tool.parameters.get("properties", {}).items() }, - required=tool.parameters.get('required', []) + required=tool.parameters.get("required", []), ), - ) for tool in tools + ) + for tool in tools ] ) @@ -432,20 +441,25 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ - + try: ping_message = SystemPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5}) - + except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - - def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], 
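The image-handling hunk above swaps `mimetypes.guess_type`, which guesses from the URL extension, for Pillow sniffing of the downloaded bytes, so the declared media type reflects the actual payload. A standalone sketch of that pattern (the function name and timeout are illustrative):

```python
# Sketch of the Pillow-based MIME detection used above; not part of the diff.
import base64
import io

import requests
from PIL import Image


def image_url_to_base64(url: str) -> tuple[str, str]:
    """Return (mime_type, base64_data) for a remote image.

    Pillow reads the container format (PNG, JPEG, WEBP, ...) from the bytes
    themselves, which is more reliable than guessing from the URL extension.
    """
    content = requests.get(url, timeout=10).content
    with Image.open(io.BytesIO(content)) as img:
        mime_type = f"image/{img.format.lower()}"
    return mime_type, base64.b64encode(content).decode("utf-8")
```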
model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None - ) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -459,7 +473,7 @@ def _generate(self, model: str, credentials: dict, :return: full response or stream response chunk generator result """ config_kwargs = model_parameters.copy() - config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None) + config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None) if stop: config_kwargs["stop_sequences"] = stop @@ -491,26 +505,13 @@ def _generate(self, model: str, credentials: dict, else: history.append(content) - safety_settings={ - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - } - - google_model = glm.GenerativeModel( - model_name=model, - system_instruction=system_instruction - ) + google_model = glm.GenerativeModel(model_name=model, system_instruction=system_instruction) response = google_model.generate_content( contents=history, - generation_config=glm.GenerationConfig( - **config_kwargs - ), + generation_config=glm.GenerationConfig(**config_kwargs), stream=stream, - safety_settings=safety_settings, - tools=self._convert_tools_to_glm_tool(tools) if tools else None + tools=self._convert_tools_to_glm_tool(tools) if tools else None, ) if stream: @@ -518,8 +519,9 @@ def _generate(self, model: str, credentials: dict, return self._handle_generate_response(model, credentials, response, prompt_messages) - def _handle_generate_response(self, model: str, credentials: dict, response: glm.GenerationResponse, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _handle_generate_response( + self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + ) -> LLMResult: """ Handle llm response @@ -530,9 +532,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: glm :return: llm response """ # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=response.candidates[0].content.parts[0].text - ) + assistant_prompt_message = AssistantPromptMessage(content=response.candidates[0].content.parts[0].text) # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) @@ -551,8 +551,9 @@ def _handle_generate_response(self, model: str, credentials: dict, response: glm return result - def _handle_generate_stream_response(self, model: str, credentials: dict, response: glm.GenerationResponse, - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage] + ) -> Generator: """ Handle llm stream response @@ -565,9 +566,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index = -1 for chunk in response: for part in 
chunk.candidates[0].content.parts: - assistant_prompt_message = AssistantPromptMessage( - content='' - ) + assistant_prompt_message = AssistantPromptMessage(content="") if part.text: assistant_prompt_message.content += part.text @@ -576,35 +575,31 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon assistant_prompt_message.tool_calls = [ AssistantPromptMessage.ToolCall( id=part.function_call.name, - type='function', + type="function", function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=part.function_call.name, - arguments=json.dumps(dict(part.function_call.args.items())) - ) + arguments=json.dumps(dict(part.function_call.args.items())), + ), ) ] index += 1 - - if not hasattr(chunk, 'finish_reason') or not chunk.finish_reason: + + if not hasattr(chunk, "finish_reason") or not chunk.finish_reason: # transform assistant message to prompt message yield LLMResultChunk( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=assistant_prompt_message - ) + delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message), ) else: - # calculate num tokens prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) - + yield LLMResultChunk( model=model, prompt_messages=prompt_messages, @@ -612,8 +607,8 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon index=index, message=assistant_prompt_message, finish_reason=chunk.candidates[0].finish_reason, - usage=usage - ) + usage=usage, + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -628,17 +623,13 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: content = message.content if isinstance(content, list): - content = "".join( - c.data for c in content if c.type != PromptMessageContentType.IMAGE - ) + content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE) if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = f"{human_prompt} {content}" - elif isinstance(message, ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = f"{human_prompt} {content}" else: raise ValueError(f"Got unknown type {message}") @@ -655,7 +646,7 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content: if isinstance(message, UserPromptMessage): glm_content = glm.Content(role="user", parts=[]) - if (isinstance(message.content, str)): + if isinstance(message.content, str): glm_content = glm.Content(role="user", parts=[glm.Part.from_text(message.content)]) else: parts = [] @@ -663,61 +654,73 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content: if c.type == PromptMessageContentType.TEXT: parts.append(glm.Part.from_text(c.data)) else: - metadata, data = c.data.split(',', 1) - mime_type = metadata.split(';', 1)[0].split(':')[1] - parts.append(glm.Part.from_data(mime_type=mime_type, data=data)) + message_content = cast(ImagePromptMessageContent, c) + if not message_content.data.startswith("data:"): + url_arr = message_content.data.split(".") + mime_type = f"image/{url_arr[-1]}" + 
parts.append(glm.Part.from_uri(mime_type=mime_type, uri=message_content.data)) + else: + metadata, data = c.data.split(",", 1) + mime_type = metadata.split(";", 1)[0].split(":")[1] + parts.append(glm.Part.from_data(mime_type=mime_type, data=data)) glm_content = glm.Content(role="user", parts=parts) return glm_content elif isinstance(message, AssistantPromptMessage): if message.content: glm_content = glm.Content(role="model", parts=[glm.Part.from_text(message.content)]) if message.tool_calls: - glm_content = glm.Content(role="model", parts=[glm.Part.from_function_response(glm.FunctionCall( - name=message.tool_calls[0].function.name, - args=json.loads(message.tool_calls[0].function.arguments), - ))]) + glm_content = glm.Content( + role="model", + parts=[ + glm.Part.from_function_response( + glm.FunctionCall( + name=message.tool_calls[0].function.name, + args=json.loads(message.tool_calls[0].function.arguments), + ) + ) + ], + ) return glm_content elif isinstance(message, ToolPromptMessage): - glm_content = glm.Content(role="function", parts=[glm.Part(function_response=glm.FunctionResponse( - name=message.name, - response={ - "response": message.content - } - ))]) + glm_content = glm.Content( + role="function", + parts=[ + glm.Part( + function_response=glm.FunctionResponse( + name=message.name, response={"response": message.content} + ) + ) + ], + ) return glm_content else: raise ValueError(f"Got unknown type {message}") - + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ Map model invoke error to unified error - The key is the ermd = gml.GenerativeModel(model)ror type thrown to the caller - The value is the md = gml.GenerativeModel(model)error type thrown by the model, + The key is the error type thrown to the caller + The value is the error type thrown by the model, which needs to be converted into a unified error type for the caller.
- :return: Invoke emd = gml.GenerativeModel(model)rror mapping + :return: Invoke error mapping """ return { - InvokeConnectionError: [ - exceptions.RetryError - ], + InvokeConnectionError: [exceptions.RetryError], InvokeServerUnavailableError: [ exceptions.ServiceUnavailable, exceptions.InternalServerError, exceptions.BadGateway, exceptions.GatewayTimeout, - exceptions.DeadlineExceeded - ], - InvokeRateLimitError: [ - exceptions.ResourceExhausted, - exceptions.TooManyRequests + exceptions.DeadlineExceeded, ], + InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests], InvokeAuthorizationError: [ exceptions.Unauthenticated, exceptions.PermissionDenied, exceptions.Unauthenticated, - exceptions.Forbidden + exceptions.Forbidden, ], InvokeBadRequestError: [ exceptions.BadRequest, @@ -733,5 +736,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] exceptions.PreconditionFailed, exceptions.RequestRangeNotSatisfiable, exceptions.Cancelled, - ] + ], } diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index 2404ba589431a7..43233e61262264 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -9,6 +9,7 @@ from google.oauth2 import service_account from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -29,15 +30,22 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): Model class for Vertex AI text embedding model.
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model :param model: model name :param credentials: model credentials :param texts: texts to embed + :param user: unique user id + :param input_type: input type :return: embeddings result """ service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) @@ -51,23 +59,12 @@ def _invoke(self, model: str, credentials: dict, client = VertexTextEmbeddingModel.from_pretrained(model) - embeddings_batch, embedding_used_tokens = self._embedding_invoke( - client=client, - texts=texts - ) + embeddings_batch, embedding_used_tokens = self._embedding_invoke(client=client, texts=texts) # calc usage - usage = self._calc_response_usage( - model=model, - credentials=credentials, - tokens=embedding_used_tokens - ) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=embedding_used_tokens) - return TextEmbeddingResult( - embeddings=embeddings_batch, - usage=usage, - model=model - ) + return TextEmbeddingResult(embeddings=embeddings_batch, usage=usage, model=model) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: """ @@ -115,15 +112,11 @@ def validate_credentials(self, model: str, credentials: dict) -> None: client = VertexTextEmbeddingModel.from_pretrained(model) # call embedding model - self._embedding_invoke( - model=model, - client=client, - texts=['ping'] - ) + self._embedding_invoke(model=model, client=client, texts=["ping"]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore + def _embedding_invoke(self, client: VertexTextEmbeddingModel, texts: list[str]) -> [list[float], int]: # type: ignore """ Invoke embedding model @@ -154,10 +147,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -168,14 +158,14 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ entity = AIModelEntity( model=model, @@ -183,15 +173,15 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), 
- currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity diff --git a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py index 3cbfb088d12536..466a86fd36a181 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py +++ b/api/core/model_runtime/model_providers/vertex_ai/vertex_ai.py @@ -20,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `gemini-1.0-pro-002` model for validate, - model_instance.validate_credentials( - model='gemini-1.0-pro-002', - credentials=credentials - ) + model_instance.validate_credentials(model="gemini-1.0-pro-002", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/vessl_ai/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png new file mode 100644 index 00000000000000..18ba350fa0c98f Binary files /dev/null and b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg new file mode 100644 index 00000000000000..242f4e82b278b2 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py new file mode 100644 index 00000000000000..034c066ab5f071 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py @@ -0,0 +1,83 @@ +from decimal import Decimal + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + features = [] + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties={ + ModelPropertyKey.MODE: credentials.get("mode"), + }, + parameter_rules=[ + ParameterRule( + name=DefaultParameterName.TEMPERATURE.value, + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + 
default=float(credentials.get("temperature", 0.7)), + min=0, + max=2, + precision=2, + ), + ParameterRule( + name=DefaultParameterName.TOP_P.value, + label=I18nObject(en_US="Top P"), + type=ParameterType.FLOAT, + default=float(credentials.get("top_p", 1)), + min=0, + max=1, + precision=2, + ), + ParameterRule( + name=DefaultParameterName.TOP_K.value, + label=I18nObject(en_US="Top K"), + type=ParameterType.INT, + default=int(credentials.get("top_k", 50)), + min=-2147483647, + max=2147483647, + precision=0, + ), + ParameterRule( + name=DefaultParameterName.MAX_TOKENS.value, + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + default=512, + min=1, + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), + ], + pricing=PriceConfig( + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), + ) + + if credentials["mode"] == "chat": + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value + elif credentials["mode"] == "completion": + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value + else: + raise ValueError(f"Unknown completion mode {credentials['mode']}") + + return entity diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py new file mode 100644 index 00000000000000..7a987c67107994 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VesslAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml new file mode 100644 index 00000000000000..6052756cae4887 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml @@ -0,0 +1,56 @@ +provider: vessl_ai +label: + en_US: vessl_ai +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.png +background: "#F1EFED" +help: + title: + en_US: How to deploy VESSL AI LLM Model Endpoint + url: + en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment +supported_model_types: + - llm +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + placeholder: + en_US: Enter your model name + credential_form_schemas: + - variable: endpoint_url + label: + en_US: Endpoint URL + type: text-input + required: true + placeholder: + en_US: Enter the URL of your endpoint + - variable: api_key + required: true + label: + en_US: API Key + type: secret-input + placeholder: + en_US: Enter your VESSL AI secret key + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + - value: chat + label: + en_US: Chat diff --git a/api/core/model_runtime/model_providers/volcengine_maas/client.py b/api/core/model_runtime/model_providers/volcengine_maas/client.py index 471cb3c94e01f8..cfe21e4b9f4617 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/client.py +++
b/api/core/model_runtime/model_providers/volcengine_maas/client.py @@ -1,6 +1,25 @@ import re -from collections.abc import Callable, Generator -from typing import cast +from collections.abc import Generator +from typing import Optional, cast + +from volcenginesdkarkruntime import Ark +from volcenginesdkarkruntime.types.chat import ( + ChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionChunk, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam, + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, +) +from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL +from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function +from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse +from volcenginesdkarkruntime.types.shared_params import FunctionDefinition from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -12,123 +31,186 @@ ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService +DEFAULT_V2_ENDPOINT = "maas-api.ml-platform-cn-beijing.volces.com" +DEFAULT_V3_ENDPOINT = "https://ark.cn-beijing.volces.com/api/v3" + + +class ArkClientV3: + endpoint_id: Optional[str] = None + ark: Optional[Ark] = None -class MaaSClient(MaasService): - def __init__(self, host: str, region: str): + def __init__(self, *args, **kwargs): + self.ark = Ark(*args, **kwargs) self.endpoint_id = None - super().__init__(host, region) - def set_endpoint_id(self, endpoint_id: str): - self.endpoint_id = endpoint_id + @staticmethod + def is_legacy(credentials: dict) -> bool: + # match default v2 endpoint + if ArkClientV3.is_compatible_with_legacy(credentials): + return False + # match default v3 endpoint + if credentials.get("api_endpoint_host") == DEFAULT_V3_ENDPOINT: + return False + # only v3 support api_key + if credentials.get("auth_method") == "api_key": + return False + # these cases are considered as sdk v2 + # - modified default v2 endpoint + # - modified default v3 endpoint and auth without api_key + return True - @classmethod - def from_credential(cls, credentials: dict) -> 'MaaSClient': - host = credentials['api_endpoint_host'] - region = credentials['volc_region'] - ak = credentials['volc_access_key_id'] - sk = credentials['volc_secret_access_key'] - endpoint_id = credentials['endpoint_id'] - - client = cls(host, region) - client.set_endpoint_id(endpoint_id) - client.set_ak(ak) - client.set_sk(sk) - return client + @staticmethod + def is_compatible_with_legacy(credentials: dict) -> bool: + endpoint = credentials.get("api_endpoint_host") + return endpoint == DEFAULT_V2_ENDPOINT - def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: - req = { - 'parameters': params, - 'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], - **extra_model_kwargs, + @classmethod + def from_credentials(cls, credentials): + """Initialize the client using the credentials provided.""" + args = { + "base_url": credentials["api_endpoint_host"], + "region": credentials["volc_region"], } - if not stream: - return super().chat( - self.endpoint_id, - 
req, - ) - return super().stream_chat( - self.endpoint_id, - req, - ) + if credentials.get("auth_method") == "api_key": + args = { + **args, + "api_key": credentials["volc_api_key"], + } + else: + args = { + **args, + "ak": credentials["volc_access_key_id"], + "sk": credentials["volc_secret_access_key"], + } - def embeddings(self, texts: list[str]) -> dict: - req = { - 'input': texts - } - return super().embeddings(self.endpoint_id, req) + if cls.is_compatible_with_legacy(credentials): + args = {**args, "base_url": DEFAULT_V3_ENDPOINT} + + client = ArkClientV3(**args) + client.endpoint_id = credentials["endpoint_id"] + return client @staticmethod - def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam: + """Converts a PromptMessage to a ChatCompletionMessageParam""" if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): - message_dict = {"role": ChatRole.USER, - "content": message.content} + content = message.content else: content = [] for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: - raise ValueError( - 'Content object type only support image_url') + content.append( + ChatCompletionContentPartTextParam( + text=message_content.text, + type="text", + ) + ) elif message_content.type == PromptMessageContentType.IMAGE: - message_content = cast( - ImagePromptMessageContent, message_content) - image_data = re.sub( - r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data) - content.append({ - 'type': 'image_url', - 'image_url': { - 'url': '', - 'image_bytes': image_data, - 'detail': message_content.detail, - } - }) - - message_dict = {'role': ChatRole.USER, 'content': content} + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + ChatCompletionContentPartImageParam( + image_url=ImageURL( + url=image_data, + detail=message_content.detail.value, + ), + type="image_url", + ) + ) + message_dict = ChatCompletionUserMessageParam(role="user", content=content) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) - message_dict = {'role': ChatRole.ASSISTANT, - 'content': message.content} - if message.tool_calls: - message_dict['tool_calls'] = [ - { - 'name': call.function.name, - 'arguments': call.function.arguments - } for call in message.tool_calls - ] + message_dict = ChatCompletionAssistantMessageParam( + content=message.content, + role="assistant", + tool_calls=None + if not message.tool_calls + else [ + ChatCompletionMessageToolCallParam( + id=call.id, + function=Function(name=call.function.name, arguments=call.function.arguments), + type="function", + ) + for call in message.tool_calls + ], + ) elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) - message_dict = {'role': ChatRole.SYSTEM, - 'content': message.content} + message_dict = ChatCompletionSystemMessageParam(content=message.content, role="system") elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {'role': ChatRole.FUNCTION, - 'content': message.content, - 'name': message.tool_call_id} + message_dict = ChatCompletionToolMessageParam( + content=message.content, role="tool", tool_call_id=message.tool_call_id + ) else: raise ValueError(f"Got unknown PromptMessage type 
{message}") return message_dict @staticmethod - def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: - try: - resp = fn() - except MaasException as e: - raise wrap_error(e) + def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam: + return ChatCompletionToolParam( + type="function", + function=FunctionDefinition( + name=message.name, + description=message.description, + parameters=message.parameters, + ), + ) - return resp + def chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> ChatCompletion: + """Block chat""" + return self.ark.chat.completions.create( + model=self.endpoint_id, + messages=[self.convert_prompt_message(message) for message in messages], + tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None, + stop=stop, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + top_p=top_p, + temperature=temperature, + ) - @staticmethod - def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): - return { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - } - } + def stream_chat( + self, + messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + frequency_penalty: Optional[float] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ) -> Generator[ChatCompletionChunk]: + """Stream chat""" + chunks = self.ark.chat.completions.create( + stream=True, + model=self.endpoint_id, + messages=[self.convert_prompt_message(message) for message in messages], + tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None, + stop=stop, + frequency_penalty=frequency_penalty, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + top_p=top_p, + temperature=temperature, + stream_options={"include_usage": True}, + ) + yield from chunks + + def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse: + return self.ark.embeddings.create(model=self.endpoint_id, input=texts) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/errors.py deleted file mode 100644 index 63397a456ece96..00000000000000 --- a/api/core/model_runtime/model_providers/volcengine_maas/errors.py +++ /dev/null @@ -1,156 +0,0 @@ -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException - - -class ClientSDKRequestError(MaasException): - pass - - -class SignatureDoesNotMatch(MaasException): - pass - - -class RequestTimeout(MaasException): - pass - - -class ServiceConnectionTimeout(MaasException): - pass - - -class MissingAuthenticationHeader(MaasException): - pass - - -class AuthenticationHeaderIsInvalid(MaasException): - pass - - -class InternalServiceError(MaasException): - pass - - -class MissingParameter(MaasException): - pass - - -class InvalidParameter(MaasException): - pass - - -class AuthenticationExpire(MaasException): - pass - - -class EndpointIsInvalid(MaasException): - pass - - -class EndpointIsNotEnable(MaasException): - pass - - -class 
ModelNotSupportStreamMode(MaasException): - pass - - -class ReqTextExistRisk(MaasException): - pass - - -class RespTextExistRisk(MaasException): - pass - - -class EndpointRateLimitExceeded(MaasException): - pass - - -class ServiceConnectionRefused(MaasException): - pass - - -class ServiceConnectionClosed(MaasException): - pass - - -class UnauthorizedUserForEndpoint(MaasException): - pass - - -class InvalidEndpointWithNoURL(MaasException): - pass - - -class EndpointAccountRpmRateLimitExceeded(MaasException): - pass - - -class EndpointAccountTpmRateLimitExceeded(MaasException): - pass - - -class ServiceResourceWaitQueueFull(MaasException): - pass - - -class EndpointIsPending(MaasException): - pass - - -class ServiceNotOpen(MaasException): - pass - - -AuthErrors = { - 'SignatureDoesNotMatch': SignatureDoesNotMatch, - 'MissingAuthenticationHeader': MissingAuthenticationHeader, - 'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid, - 'AuthenticationExpire': AuthenticationExpire, - 'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint, -} - -BadRequestErrors = { - 'MissingParameter': MissingParameter, - 'InvalidParameter': InvalidParameter, - 'EndpointIsInvalid': EndpointIsInvalid, - 'EndpointIsNotEnable': EndpointIsNotEnable, - 'ModelNotSupportStreamMode': ModelNotSupportStreamMode, - 'ReqTextExistRisk': ReqTextExistRisk, - 'RespTextExistRisk': RespTextExistRisk, - 'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL, - 'ServiceNotOpen': ServiceNotOpen, -} - -RateLimitErrors = { - 'EndpointRateLimitExceeded': EndpointRateLimitExceeded, - 'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded, - 'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded, -} - -ServerUnavailableErrors = { - 'InternalServiceError': InternalServiceError, - 'EndpointIsPending': EndpointIsPending, - 'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull, -} - -ConnectionErrors = { - 'ClientSDKRequestError': ClientSDKRequestError, - 'RequestTimeout': RequestTimeout, - 'ServiceConnectionTimeout': ServiceConnectionTimeout, - 'ServiceConnectionRefused': ServiceConnectionRefused, - 'ServiceConnectionClosed': ServiceConnectionClosed, -} - -ErrorCodeMap = { - **AuthErrors, - **BadRequestErrors, - **RateLimitErrors, - **ServerUnavailableErrors, - **ConnectionErrors, -} - - -def wrap_error(e: MaasException) -> Exception: - if ErrorCodeMap.get(e.code): - return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) - return e diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py new file mode 100644 index 00000000000000..266f1216f82b29 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py @@ -0,0 +1,123 @@ +import re +from collections.abc import Callable, Generator +from typing import cast + +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import 
ChatRole, MaasError, MaasService + + +class MaaSClient(MaasService): + def __init__(self, host: str, region: str): + self.endpoint_id = None + super().__init__(host, region) + + def set_endpoint_id(self, endpoint_id: str): + self.endpoint_id = endpoint_id + + @classmethod + def from_credential(cls, credentials: dict) -> "MaaSClient": + host = credentials["api_endpoint_host"] + region = credentials["volc_region"] + ak = credentials["volc_access_key_id"] + sk = credentials["volc_secret_access_key"] + endpoint_id = credentials["endpoint_id"] + + client = cls(host, region) + client.set_endpoint_id(endpoint_id) + client.set_ak(ak) + client.set_sk(sk) + return client + + def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict: + req = { + "parameters": params, + "messages": [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages], + **extra_model_kwargs, + } + if not stream: + return super().chat( + self.endpoint_id, + req, + ) + return super().stream_chat( + self.endpoint_id, + req, + ) + + def embeddings(self, texts: list[str]) -> dict: + req = {"input": texts} + return super().embeddings(self.endpoint_id, req) + + @staticmethod + def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict: + if isinstance(message, UserPromptMessage): + message = cast(UserPromptMessage, message) + if isinstance(message.content, str): + message_dict = {"role": ChatRole.USER, "content": message.content} + else: + content = [] + for message_content in message.content: + if message_content.type == PromptMessageContentType.TEXT: + raise ValueError("Content object type only support image_url") + elif message_content.type == PromptMessageContentType.IMAGE: + message_content = cast(ImagePromptMessageContent, message_content) + image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data) + content.append( + { + "type": "image_url", + "image_url": { + "url": "", + "image_bytes": image_data, + "detail": message_content.detail, + }, + } + ) + + message_dict = {"role": ChatRole.USER, "content": content} + elif isinstance(message, AssistantPromptMessage): + message = cast(AssistantPromptMessage, message) + message_dict = {"role": ChatRole.ASSISTANT, "content": message.content} + if message.tool_calls: + message_dict["tool_calls"] = [ + {"name": call.function.name, "arguments": call.function.arguments} for call in message.tool_calls + ] + elif isinstance(message, SystemPromptMessage): + message = cast(SystemPromptMessage, message) + message_dict = {"role": ChatRole.SYSTEM, "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"role": ChatRole.FUNCTION, "content": message.content, "name": message.tool_call_id} + else: + raise ValueError(f"Got unknown PromptMessage type {message}") + + return message_dict + + @staticmethod + def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: + try: + resp = fn() + except MaasError as e: + raise wrap_error(e) + + return resp + + @staticmethod + def transform_tool_prompt_to_maas_config(tool: PromptMessageTool): + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py new file mode 100644 index 00000000000000..91dbe21a616195 --- 
/dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/errors.py @@ -0,0 +1,156 @@ +from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError + + +class ClientSDKRequestError(MaasError): + pass + + +class SignatureDoesNotMatchError(MaasError): + pass + + +class RequestTimeoutError(MaasError): + pass + + +class ServiceConnectionTimeoutError(MaasError): + pass + + +class MissingAuthenticationHeaderError(MaasError): + pass + + +class AuthenticationHeaderIsInvalidError(MaasError): + pass + + +class InternalServiceError(MaasError): + pass + + +class MissingParameterError(MaasError): + pass + + +class InvalidParameterError(MaasError): + pass + + +class AuthenticationExpireError(MaasError): + pass + + +class EndpointIsInvalidError(MaasError): + pass + + +class EndpointIsNotEnableError(MaasError): + pass + + +class ModelNotSupportStreamModeError(MaasError): + pass + + +class ReqTextExistRiskError(MaasError): + pass + + +class RespTextExistRiskError(MaasError): + pass + + +class EndpointRateLimitExceededError(MaasError): + pass + + +class ServiceConnectionRefusedError(MaasError): + pass + + +class ServiceConnectionClosedError(MaasError): + pass + + +class UnauthorizedUserForEndpointError(MaasError): + pass + + +class InvalidEndpointWithNoURLError(MaasError): + pass + + +class EndpointAccountRpmRateLimitExceededError(MaasError): + pass + + +class EndpointAccountTpmRateLimitExceededError(MaasError): + pass + + +class ServiceResourceWaitQueueFullError(MaasError): + pass + + +class EndpointIsPendingError(MaasError): + pass + + +class ServiceNotOpenError(MaasError): + pass + + +AuthErrors = { + "SignatureDoesNotMatch": SignatureDoesNotMatchError, + "MissingAuthenticationHeader": MissingAuthenticationHeaderError, + "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError, + "AuthenticationExpire": AuthenticationExpireError, + "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError, +} + +BadRequestErrors = { + "MissingParameter": MissingParameterError, + "InvalidParameter": InvalidParameterError, + "EndpointIsInvalid": EndpointIsInvalidError, + "EndpointIsNotEnable": EndpointIsNotEnableError, + "ModelNotSupportStreamMode": ModelNotSupportStreamModeError, + "ReqTextExistRisk": ReqTextExistRiskError, + "RespTextExistRisk": RespTextExistRiskError, + "InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError, + "ServiceNotOpen": ServiceNotOpenError, +} + +RateLimitErrors = { + "EndpointRateLimitExceeded": EndpointRateLimitExceededError, + "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError, + "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError, +} + +ServerUnavailableErrors = { + "InternalServiceError": InternalServiceError, + "EndpointIsPending": EndpointIsPendingError, + "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError, +} + +ConnectionErrors = { + "ClientSDKRequestError": ClientSDKRequestError, + "RequestTimeout": RequestTimeoutError, + "ServiceConnectionTimeout": ServiceConnectionTimeoutError, + "ServiceConnectionRefused": ServiceConnectionRefusedError, + "ServiceConnectionClosed": ServiceConnectionClosedError, +} + +ErrorCodeMap = { + **AuthErrors, + **BadRequestErrors, + **RateLimitErrors, + **ServerUnavailableErrors, + **ConnectionErrors, +} + + +def wrap_error(e: MaasError) -> Exception: + if ErrorCodeMap.get(e.code): + return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) + return e diff --git 
a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py new file mode 100644 index 00000000000000..8b3eb157be5cfe --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py @@ -0,0 +1,4 @@ +from .common import ChatRole +from .maas import MaasError, MaasService + +__all__ = ["MaasService", "ChatRole", "MaasError"] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py new file mode 100644 index 00000000000000..c22bf8e76de36a --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -0,0 +1,159 @@ +# coding : utf-8 +import datetime +from itertools import starmap + +import pytz + +from .util import Util + + +class MetaData: + def __init__(self): + self.algorithm = "" + self.credential_scope = "" + self.signed_headers = "" + self.date = "" + self.region = "" + self.service = "" + + def set_date(self, date): + self.date = date + + def set_service(self, service): + self.service = service + + def set_region(self, region): + self.region = region + + def set_algorithm(self, algorithm): + self.algorithm = algorithm + + def set_credential_scope(self, credential_scope): + self.credential_scope = credential_scope + + def set_signed_headers(self, signed_headers): + self.signed_headers = signed_headers + + +class SignResult: + def __init__(self): + self.xdate = "" + self.xCredential = "" + self.xAlgorithm = "" + self.xSignedHeaders = "" + self.xSignedQueries = "" + self.xSignature = "" + self.xContextSha256 = "" + self.xSecurityToken = "" + + self.authorization = "" + + def __str__(self): + return "\n".join(list(starmap("{}:{}".format, self.__dict__.items()))) + + +class Credentials: + def __init__(self, ak, sk, service, region, session_token=""): + self.ak = ak + self.sk = sk + self.service = service + self.region = region + self.session_token = session_token + + def set_ak(self, ak): + self.ak = ak + + def set_sk(self, sk): + self.sk = sk + + def set_session_token(self, session_token): + self.session_token = session_token + + +class Signer: + @staticmethod + def sign(request, credentials): + if request.path == "": + request.path = "/" + if request.method != "GET" and "Content-Type" not in request.headers: + request.headers["Content-Type"] = "application/x-www-form-urlencoded; charset=utf-8" + + format_date = Signer.get_current_format_date() + request.headers["X-Date"] = format_date + if credentials.session_token != "": + request.headers["X-Security-Token"] = credentials.session_token + + md = MetaData() + md.set_algorithm("HMAC-SHA256") + md.set_service(credentials.service) + md.set_region(credentials.region) + md.set_date(format_date[:8]) + + hashed_canon_req = Signer.hashed_canonical_request_v4(request, md) + md.set_credential_scope("/".join([md.date, md.region, md.service, "request"])) + + signing_str = "\n".join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) + 
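# Signature derivation (see get_signing_secret_key_v4 below): the signing
# key is a chained HMAC-SHA256 over the secret key,
#   kDate = HMAC(sk, date); kRegion = HMAC(kDate, region);
#   kService = HMAC(kRegion, service); kSigning = HMAC(kService, "request"),
# and the hex-encoded HMAC of signing_str under kSigning becomes the
# Signature field of the Authorization header built in build_auth_header_v4.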
signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) + sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) + request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) + + @staticmethod + def hashed_canonical_request_v4(request, meta): + body_hash = Util.sha256(request.body) + request.headers["X-Content-Sha256"] = body_hash + + signed_headers = {} + for key in request.headers: + if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"): + signed_headers[key.lower()] = request.headers[key] + + if "host" in signed_headers: + v = signed_headers["host"] + if v.find(":") != -1: + split = v.split(":") + port = split[1] + if str(port) == "80" or str(port) == "443": + signed_headers["host"] = split[0] + + signed_str = "" + for key in sorted(signed_headers.keys()): + signed_str += key + ":" + signed_headers[key] + "\n" + + meta.set_signed_headers(";".join(sorted(signed_headers.keys()))) + + canonical_request = "\n".join( + [ + request.method, + Util.norm_uri(request.path), + Util.norm_query(request.query), + signed_str, + meta.signed_headers, + body_hash, + ] + ) + + return Util.sha256(canonical_request) + + @staticmethod + def get_signing_secret_key_v4(sk, date, region, service): + date = Util.hmac_sha256(bytes(sk, encoding="utf-8"), date) + region = Util.hmac_sha256(date, region) + service = Util.hmac_sha256(region, service) + return Util.hmac_sha256(service, "request") + + @staticmethod + def build_auth_header_v4(signature, meta, credentials): + credential = credentials.ak + "/" + meta.credential_scope + return ( + meta.algorithm + + " Credential=" + + credential + + ", SignedHeaders=" + + meta.signed_headers + + ", Signature=" + + signature + ) + + @staticmethod + def get_current_format_date(): + return datetime.datetime.now(tz=pytz.timezone("UTC")).strftime("%Y%m%dT%H%M%SZ") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py new file mode 100644 index 00000000000000..33c41f3eb331a3 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py @@ -0,0 +1,216 @@ +import json +from collections import OrderedDict +from urllib.parse import urlencode + +import requests + +from .auth import Signer + +VERSION = "v1.0.137" + + +class Service: + def __init__(self, service_info, api_info): + self.service_info = service_info + self.api_info = api_info + self.session = requests.session() + + def set_ak(self, ak): + self.service_info.credentials.set_ak(ak) + + def set_sk(self, sk): + self.service_info.credentials.set_sk(sk) + + def set_session_token(self, session_token): + self.service_info.credentials.set_session_token(session_token) + + def set_host(self, host): + self.service_info.host = host + + def set_scheme(self, scheme): + self.service_info.scheme = scheme + + def get(self, api, params, doseq=0): + if api not in self.api_info: + raise Exception("no such api") + api_info = self.api_info[api] + + r = self.prepare_request(api_info, params, doseq) + + Signer.sign(r, self.service_info.credentials) + + url = r.build(doseq) + resp = self.session.get( + url, headers=r.headers, timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout) + ) + if resp.status_code == 200: + return resp.text + else: + raise Exception(resp.text) + + def post(self, api, params, form): + if api not in self.api_info: + raise 
Exception("no such api") + api_info = self.api_info[api] + r = self.prepare_request(api_info, params) + r.headers["Content-Type"] = "application/x-www-form-urlencoded" + r.form = self.merge(api_info.form, form) + r.body = urlencode(r.form, True) + Signer.sign(r, self.service_info.credentials) + + url = r.build() + + resp = self.session.post( + url, + headers=r.headers, + data=r.form, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) + if resp.status_code == 200: + return resp.text + else: + raise Exception(resp.text) + + def json(self, api, params, body): + if api not in self.api_info: + raise Exception("no such api") + api_info = self.api_info[api] + r = self.prepare_request(api_info, params) + r.headers["Content-Type"] = "application/json" + r.body = body + + Signer.sign(r, self.service_info.credentials) + + url = r.build() + resp = self.session.post( + url, + headers=r.headers, + data=r.body, + timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout), + ) + if resp.status_code == 200: + return json.dumps(resp.json()) + else: + raise Exception(resp.text.encode("utf-8")) + + def put(self, url, file_path, headers): + with open(file_path, "rb") as f: + resp = self.session.put(url, headers=headers, data=f) + if resp.status_code == 200: + return True, resp.text.encode("utf-8") + else: + return False, resp.text.encode("utf-8") + + def put_data(self, url, data, headers): + resp = self.session.put(url, headers=headers, data=data) + if resp.status_code == 200: + return True, resp.text.encode("utf-8") + else: + return False, resp.text.encode("utf-8") + + def prepare_request(self, api_info, params, doseq=0): + for key in params: + if type(params[key]) == int or type(params[key]) == float or type(params[key]) == bool: + params[key] = str(params[key]) + elif type(params[key]) == list: + if not doseq: + params[key] = ",".join(params[key]) + + connection_timeout = self.service_info.connection_timeout + socket_timeout = self.service_info.socket_timeout + + r = Request() + r.set_schema(self.service_info.scheme) + r.set_method(api_info.method) + r.set_connection_timeout(connection_timeout) + r.set_socket_timeout(socket_timeout) + + headers = self.merge(api_info.header, self.service_info.header) + headers["Host"] = self.service_info.host + headers["User-Agent"] = "volc-sdk-python/" + VERSION + r.set_headers(headers) + + query = self.merge(api_info.query, params) + r.set_query(query) + + r.set_host(self.service_info.host) + r.set_path(api_info.path) + + return r + + @staticmethod + def merge(param1, param2): + od = OrderedDict() + for key in param1: + od[key] = param1[key] + + for key in param2: + od[key] = param2[key] + + return od + + +class Request: + def __init__(self): + self.schema = "" + self.method = "" + self.host = "" + self.path = "" + self.headers = OrderedDict() + self.query = OrderedDict() + self.body = "" + self.form = {} + self.connection_timeout = 0 + self.socket_timeout = 0 + + def set_schema(self, schema): + self.schema = schema + + def set_method(self, method): + self.method = method + + def set_host(self, host): + self.host = host + + def set_path(self, path): + self.path = path + + def set_headers(self, headers): + self.headers = headers + + def set_query(self, query): + self.query = query + + def set_body(self, body): + self.body = body + + def set_connection_timeout(self, connection_timeout): + self.connection_timeout = connection_timeout + + def set_socket_timeout(self, socket_timeout): + self.socket_timeout = 
socket_timeout + + def build(self, doseq=0): + return self.schema + "://" + self.host + self.path + "?" + urlencode(self.query, doseq) + + +class ServiceInfo: + def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme="http"): + self.host = host + self.header = header + self.credentials = credentials + self.connection_timeout = connection_timeout + self.socket_timeout = socket_timeout + self.scheme = scheme + + +class ApiInfo: + def __init__(self, method, path, query, form, header): + self.method = method + self.path = path + self.query = query + self.form = form + self.header = header + + def __str__(self): + return "method: " + self.method + ", path: " + self.path diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py new file mode 100644 index 00000000000000..178d63714e9cf1 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py @@ -0,0 +1,44 @@ +import hashlib +import hmac +import operator +from functools import reduce +from urllib.parse import quote + + +class Util: + @staticmethod + def norm_uri(path): + return quote(path).replace("%2F", "/").replace("+", "%20") + + @staticmethod + def norm_query(params): + query = "" + for key in sorted(params.keys()): + if type(params[key]) == list: + for k in params[key]: + query = query + quote(key, safe="-_.~") + "=" + quote(k, safe="-_.~") + "&" + else: + query = query + quote(key, safe="-_.~") + "=" + quote(params[key], safe="-_.~") + "&" + query = query[:-1] + return query.replace("+", "%20") + + @staticmethod + def hmac_sha256(key, content): + return hmac.new(key, bytes(content, encoding="utf-8"), hashlib.sha256).digest() + + @staticmethod + def sha256(content): + if isinstance(content, str) is True: + return hashlib.sha256(content.encode("utf-8")).hexdigest() + else: + return hashlib.sha256(content).hexdigest() + + @staticmethod + def to_hex(content): + lst = [] + for ch in content: + hv = hex(ch).replace("0x", "") + if len(hv) == 1: + hv = "0" + hv + lst.append(hv) + return reduce(operator.add, lst) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py similarity index 75% rename from api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py rename to api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py index 8b14d026d96795..3825fd65741ef5 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py @@ -43,9 +43,7 @@ def json_to_object(json_str, req_id=None): def gen_req_id(): - return datetime.now().strftime("%Y%m%d%H%M%S") + format( - random.randint(0, 2 ** 64 - 1), "020X" - ) + return datetime.now().strftime("%Y%m%d%H%M%S") + format(random.randint(0, 2**64 - 1), "020X") class SSEDecoder: @@ -53,13 +51,13 @@ def __init__(self, source): self.source = source def _read(self): - data = b'' + data = b"" for chunk in self.source: for line in chunk.splitlines(True): data += line - if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')): + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): yield data - data = b'' + data = b"" if data: yield data @@ -67,13 +65,13 @@ def next(self): for chunk in self._read(): for line in chunk.splitlines(): # skip comment - if 
line.startswith(b':'): + if line.startswith(b":"): continue - if b':' in line: - field, value = line.split(b':', 1) + if b":" in line: + field, value = line.split(b":", 1) else: - field, value = line, b'' + field, value = line, b"" - if field == b'data' and len(value) > 0: + if field == b"data" and len(value) > 0: yield value diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py new file mode 100644 index 00000000000000..a3836685f1fbf4 --- /dev/null +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py @@ -0,0 +1,198 @@ +import copy +import json +from collections.abc import Iterator + +from .base.auth import Credentials, Signer +from .base.service import ApiInfo, Service, ServiceInfo +from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object + + +class MaasService(Service): + def __init__(self, host, region, connection_timeout=60, socket_timeout=60): + service_info = self.get_service_info(host, region, connection_timeout, socket_timeout) + self._apikey = None + api_info = self.get_api_info() + super().__init__(service_info, api_info) + + def set_apikey(self, apikey): + self._apikey = apikey + + @staticmethod + def get_service_info(host, region, connection_timeout, socket_timeout): + service_info = ServiceInfo( + host, + {"Accept": "application/json"}, + Credentials("", "", "ml_maas", region), + connection_timeout, + socket_timeout, + "https", + ) + return service_info + + @staticmethod + def get_api_info(): + api_info = { + "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}), + "embeddings": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}), + } + return api_info + + def chat(self, endpoint_id, req): + req["stream"] = False + return self._request(endpoint_id, "chat", req) + + def stream_chat(self, endpoint_id, req): + req_id = gen_req_id() + self._validate("chat", req_id) + apikey = self._apikey + + try: + req["stream"] = True + res = self._call(endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True) + + decoder = SSEDecoder(res) + + def iter_fn(): + for data in decoder.next(): + if data == b"[DONE]": + return + + try: + res = json_to_object(str(data, encoding="utf-8"), req_id=req_id) + except Exception: + raise + + if res.error is not None and res.error.code_n != 0: + raise MaasError( + res.error.code_n, + res.error.code, + res.error.message, + req_id, + ) + yield res + + return iter_fn() + except MaasError: + raise + except Exception as e: + raise new_client_sdk_request_error(str(e)) + + def embeddings(self, endpoint_id, req): + return self._request(endpoint_id, "embeddings", req) + + def _request(self, endpoint_id, api, req, params={}): + req_id = gen_req_id() + + self._validate(api, req_id) + + apikey = self._apikey + + try: + res = self._call(endpoint_id, api, req_id, params, json.dumps(req).encode("utf-8"), apikey) + resp = dict_to_object(res.json()) + if resp and isinstance(resp, dict): + resp["req_id"] = req_id + return resp + + except MaasError as e: + raise e + except Exception as e: + raise new_client_sdk_request_error(str(e), req_id) + + def _validate(self, api, req_id): + credentials_exist = ( + self.service_info.credentials is not None + and self.service_info.credentials.sk is not None + and self.service_info.credentials.ak is not None + ) + + if not self._apikey and not credentials_exist: + raise new_client_sdk_request_error("no valid 
credential", req_id) + + if api not in self.api_info: + raise new_client_sdk_request_error("no such api", req_id) + + def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False): + api_info = copy.deepcopy(self.api_info[api]) + api_info.path = api_info.path.format(endpoint_id=endpoint_id) + + r = self.prepare_request(api_info, params) + r.headers["x-tt-logid"] = req_id + r.headers["Content-Type"] = "application/json" + r.body = body + + if apikey is None: + Signer.sign(r, self.service_info.credentials) + elif apikey is not None: + r.headers["Authorization"] = "Bearer " + apikey + + url = r.build() + res = self.session.post( + url, + headers=r.headers, + data=r.body, + timeout=( + self.service_info.connection_timeout, + self.service_info.socket_timeout, + ), + stream=stream, + ) + + if res.status_code != 200: + raw = res.text.encode() + res.close() + try: + resp = json_to_object(str(raw, encoding="utf-8"), req_id=req_id) + except Exception: + raise new_client_sdk_request_error(raw, req_id) + + if resp.error: + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id) + else: + raise new_client_sdk_request_error(resp, req_id) + + return res + + +class MaasError(Exception): + def __init__(self, code_n, code, message, req_id): + self.code_n = code_n + self.code = code + self.message = message + self.req_id = req_id + + def __str__(self): + return ( + "Detailed exception information is listed below.\n" + + "req_id: {}\n" + + "code_n: {}\n" + + "code: {}\n" + + "message: {}" + ).format(self.req_id, self.code_n, self.code, self.message) + + +def new_client_sdk_request_error(raw, req_id=""): + return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) + + +class BinaryResponseContent: + def __init__(self, response, request_id) -> None: + self.response = response + self.request_id = request_id + + def stream_to_file(self, file: str) -> None: + is_first = True + error_bytes = b"" + with open(file, mode="wb") as f: + for data in self.response: + if len(error_bytes) > 0 or (is_first and '"error":' in str(data)): + error_bytes += data + else: + f.write(data) + + if len(error_bytes) > 0: + resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) + raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) + + def iter_bytes(self) -> Iterator[bytes]: + yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py index 8bea30324b5ad8..1c776cec7e3096 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py @@ -1,8 +1,11 @@ import logging from collections.abc import Generator +from typing import Optional + +from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -27,61 +30,94 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from 
core.model_runtime.model_providers.volcengine_maas.client import MaaSClient -from core.model_runtime.model_providers.volcengine_maas.errors import ( +from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3 +from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) -from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.llm.models import ( + get_model_config, + get_v2_req_params, + get_v3_req_params, +) logger = logging.getLogger(__name__) class VolcengineMaaSLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + if ArkClientV3.is_legacy(credentials): + return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: """ Validate credentials """ - # ping + if ArkClientV3.is_legacy(credentials): + return self._validate_credentials_v2(credentials) + return self._validate_credentials_v3(credentials) + + @staticmethod + def _validate_credentials_v2(credentials: dict) -> None: client = MaaSClient.from_credential(credentials) try: client.chat( { - 'max_new_tokens': 16, - 'temperature': 0.7, - 'top_p': 0.9, - 'top_k': 15, + "max_new_tokens": 16, + "temperature": 0.7, + "top_p": 0.9, + "top_k": 15, }, - [UserPromptMessage(content='ping\nAnswer: ')], + [UserPromptMessage(content="ping\nAnswer: ")], ) - except MaasException as e: + except MaasError as e: raise CredentialsValidateFailedError(e.message) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: - if len(prompt_messages) == 0: - return 0 - return self._num_tokens_from_messages(prompt_messages) + @staticmethod + def _validate_credentials_v3(credentials: dict) -> None: + client = ArkClientV3.from_credentials(credentials) + try: + client.chat( + max_tokens=16, + temperature=0.7, + top_p=0.9, + messages=[UserPromptMessage(content="ping\nAnswer: ")], + ) + except Exception as e: + raise CredentialsValidateFailedError(e) - def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: - """ - Calculate num tokens. 
+ def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: + if ArkClientV3.is_legacy(credentials): + return self._get_num_tokens_v2(prompt_messages) + return self._get_num_tokens_v3(prompt_messages) - :param messages: messages - """ + def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int: + if len(messages) == 0: + return 0 num_tokens = 0 - messages_dict = [ - MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] + messages_dict = [MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages] for message in messages_dict: for key, value in message.items(): num_tokens += self._get_num_tokens_by_gpt2(str(key)) @@ -89,204 +125,247 @@ def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: return num_tokens - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - - client = MaaSClient.from_credential(credentials) + def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int: + if len(messages) == 0: + return 0 + num_tokens = 0 + messages_dict = [ArkClientV3.convert_prompt_message(m) for m in messages] + for message in messages_dict: + for key, value in message.items(): + num_tokens += self._get_num_tokens_by_gpt2(str(key)) + num_tokens += self._get_num_tokens_by_gpt2(str(value)) - req_params = ModelConfigs.get( - credentials['base_model_name'], {}).get('req_params', {}).copy() - if credentials.get('context_size'): - req_params['max_prompt_tokens'] = credentials.get('context_size') - if credentials.get('max_tokens'): - req_params['max_new_tokens'] = credentials.get('max_tokens') - if model_parameters.get('max_tokens'): - req_params['max_new_tokens'] = model_parameters.get('max_tokens') - if model_parameters.get('temperature'): - req_params['temperature'] = model_parameters.get('temperature') - if model_parameters.get('top_p'): - req_params['top_p'] = model_parameters.get('top_p') - if model_parameters.get('top_k'): - req_params['top_k'] = model_parameters.get('top_k') - if model_parameters.get('presence_penalty'): - req_params['presence_penalty'] = model_parameters.get( - 'presence_penalty') - if model_parameters.get('frequency_penalty'): - req_params['frequency_penalty'] = model_parameters.get( - 'frequency_penalty') - if stop: - req_params['stop'] = stop + return num_tokens + def _generate_v2( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + client = MaaSClient.from_credential(credentials) + req_params = get_v2_req_params(credentials, model_parameters, stop) extra_model_kwargs = {} - if tools: - extra_model_kwargs['tools'] = [ - MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools - ] + extra_model_kwargs["tools"] = [MaaSClient.transform_tool_prompt_to_maas_config(tool) for tool in tools] + resp = MaaSClient.wrap_exception(lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) - resp = MaaSClient.wrap_exception( - lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs)) - if not stream: - return 
self._handle_chat_response(model, credentials, prompt_messages, resp) - return self._handle_stream_chat_response(model, credentials, prompt_messages, resp) + def _handle_stream_chat_response() -> Generator: + for index, r in enumerate(resp): + choices = r["choices"] + if not choices: + continue + choice = choices[0] + message = choice["message"] + usage = None + if r.get("usage"): + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=r["usage"]["prompt_tokens"], + completion_tokens=r["usage"]["completion_tokens"], + ) + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=index, + message=AssistantPromptMessage(content=message["content"] or "", tool_calls=[]), + usage=usage, + finish_reason=choice.get("finish_reason"), + ), + ) - def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator: - for index, r in enumerate(resp): - choices = r['choices'] + def _handle_chat_response() -> LLMResult: + choices = resp["choices"] if not choices: - continue + raise ValueError("No choices found") + choice = choices[0] - message = choice['message'] - usage = None - if r.get('usage'): - usage = self._calc_usage(model, credentials, r['usage']) - yield LLMResultChunk( + message = choice["message"] + + # parse tool calls + tool_calls = [] + if message["tool_calls"]: + for call in message["tool_calls"]: + tool_call = AssistantPromptMessage.ToolCall( + id=call["function"]["name"], + type=call["type"], + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=call["function"]["name"], arguments=call["function"]["arguments"] + ), + ) + tool_calls.append(tool_call) + + usage = resp["usage"] + return LLMResult( model=model, prompt_messages=prompt_messages, - delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=[] - ), - usage=usage, - finish_reason=choice.get('finish_reason'), + message=AssistantPromptMessage( + content=message["content"] or "", + tool_calls=tool_calls, + ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage["prompt_tokens"], + completion_tokens=usage["completion_tokens"], ), ) - def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult: - choices = resp['choices'] - if not choices: - return - choice = choices[0] - message = choice['message'] - - # parse tool calls - tool_calls = [] - if message['tool_calls']: - for call in message['tool_calls']: - tool_call = AssistantPromptMessage.ToolCall( - id=call['function']['name'], - type=call['type'], - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=call['function']['name'], - arguments=call['function']['arguments'] - ) + if not stream: + return _handle_chat_response() + return _handle_stream_chat_response() + + def _generate_v3( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + client = ArkClientV3.from_credentials(credentials) + req_params = get_v3_req_params(credentials, model_parameters, stop) + if tools: + req_params["tools"] = tools + + def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator: 
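# Each ChatCompletionChunk from the Ark v3 SDK carries an incremental
# delta; the SDK reports usage only on some chunks, so the usage field of
# the yielded LLMResultChunk is populated conditionally and left as None
# otherwise.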
+ for chunk in chunks: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=chunk.choices[0].delta.content if chunk.choices else "", tool_calls=[] + ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + ) + if chunk.usage + else None, + finish_reason=chunk.choices[0].finish_reason if chunk.choices else None, + ), ) - tool_calls.append(tool_call) - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=message['content'] if message['content'] else '', - tool_calls=tool_calls, - ), - usage=self._calc_usage(model, credentials, resp['usage']), - ) + def _handle_chat_response(resp: ChatCompletion) -> LLMResult: + choice = resp.choices[0] + message = choice.message + # parse tool calls + tool_calls = [] + if message.tool_calls: + for call in message.tool_calls: + tool_call = AssistantPromptMessage.ToolCall( + id=call.id, + type=call.type, + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=call.function.name, arguments=call.function.arguments + ), + ) + tool_calls.append(tool_call) - def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage: - return self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=usage['prompt_tokens'], - completion_tokens=usage['completion_tokens'] - ) + usage = resp.usage + return LLMResult( + model=model, + prompt_messages=prompt_messages, + message=AssistantPromptMessage( + content=message.content or "", + tool_calls=tool_calls, + ), + usage=self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + ), + ) + + if not stream: + resp = client.chat(prompt_messages, **req_params) + return _handle_chat_response(resp) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + chunks = client.stream_chat(prompt_messages, **req_params) + return _handle_stream_chat_response(chunks) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ - max_tokens = ModelConfigs.get( - credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens') - if credentials.get('max_tokens'): - max_tokens = int(credentials.get('max_tokens')) + model_config = get_model_config(credentials) + rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ) + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='top_k', - type=ParameterType.INT, - min=1, - default=1, - label=I18nObject( - zh_Hans='Top K', - en_US='Top K' - ) + name="top_k", type=ParameterType.INT, min=1, default=1, label=I18nObject(zh_Hans="Top K", en_US="Top K") ), ParameterRule( - name='presence_penalty', + name="presence_penalty", type=ParameterType.FLOAT, - use_template='presence_penalty', - label={ - 
'en_US': 'Presence Penalty', - 'zh_Hans': '存在惩罚', - }, + use_template="presence_penalty", + label=I18nObject( + en_US="Presence Penalty", + zh_Hans="存在惩罚", + ), min=-2.0, max=2.0, ), ParameterRule( - name='frequency_penalty', + name="frequency_penalty", type=ParameterType.FLOAT, - use_template='frequency_penalty', - label={ - 'en_US': 'Frequency Penalty', - 'zh_Hans': '频率惩罚', - }, + use_template="frequency_penalty", + label=I18nObject( + en_US="Frequency Penalty", + zh_Hans="频率惩罚", + ), min=-2.0, max=2.0, ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=max_tokens, + max=model_config.properties.max_tokens, default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ] - model_properties = ModelConfigs.get( - credentials['base_model_name'], {}).get('model_properties', {}).copy() - if credentials.get('mode'): - model_properties[ModelPropertyKey.MODE] = credentials.get('mode') - if credentials.get('context_size'): - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( - credentials.get('context_size', 4096)) - - model_features = ModelConfigs.get( - credentials['base_model_name'], {}).get('features', []) + model_properties = {} + model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size + model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties=model_properties, parameter_rules=rules, - features=model_features, + features=model_config.features, ) return entity diff --git a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py index 3e5938f3b494af..d8be14b0247698 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/llm/models.py @@ -1,181 +1,142 @@ +from pydantic import BaseModel + +from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.model_entities import ModelFeature -ModelConfigs = { - 'Doubao-pro-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-lite-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-pro-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 32768, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-lite-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 32768, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-pro-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 'max_new_tokens': 131072, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Doubao-lite-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 
'max_new_tokens': 131072, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [ - ModelFeature.TOOL_CALL - ], - }, - 'Skylark2-pro-4k': { - 'req_params': { - 'max_prompt_tokens': 4096, - 'max_new_tokens': 4000, - }, - 'model_properties': { - 'context_size': 4096, - 'mode': 'chat', - }, - 'features': [], - }, - 'Llama3-8B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 8192, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Llama3-70B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 8192, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-8k': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-32k': { - 'req_params': { - 'max_prompt_tokens': 32768, - 'max_new_tokens': 16384, - }, - 'model_properties': { - 'context_size': 32768, - 'mode': 'chat', - }, - 'features': [], - }, - 'Moonshot-v1-128k': { - 'req_params': { - 'max_prompt_tokens': 131072, - 'max_new_tokens': 65536, - }, - 'model_properties': { - 'context_size': 131072, - 'mode': 'chat', - }, - 'features': [], - }, - 'GLM3-130B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'GLM3-130B-Fin': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 4096, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - }, - 'Mistral-7B': { - 'req_params': { - 'max_prompt_tokens': 8192, - 'max_new_tokens': 2048, - }, - 'model_properties': { - 'context_size': 8192, - 'mode': 'chat', - }, - 'features': [], - } + +class ModelProperties(BaseModel): + context_size: int + max_tokens: int + mode: LLMMode + + +class ModelConfig(BaseModel): + properties: ModelProperties + features: list[ModelFeature] + + +configs: dict[str, ModelConfig] = { + "Doubao-pro-4k": ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "Doubao-lite-4k": ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "Doubao-pro-32k": ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "Doubao-lite-32k": ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "Doubao-pro-128k": ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "Doubao-lite-128k": ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT), features=[] + ), + "Skylark2-pro-4k": ModelConfig( + properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT), features=[] + ), + "Llama3-8B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] + ), + "Llama3-70B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=8192, mode=LLMMode.CHAT), features=[] + ), + "Moonshot-v1-8k": ModelConfig( + properties=ModelProperties(context_size=8192, 
max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "Moonshot-v1-32k": ModelConfig( + properties=ModelProperties(context_size=32768, max_tokens=16384, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "Moonshot-v1-128k": ModelConfig( + properties=ModelProperties(context_size=131072, max_tokens=65536, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "GLM3-130B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "GLM3-130B-Fin": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=4096, mode=LLMMode.CHAT), + features=[ModelFeature.TOOL_CALL], + ), + "Mistral-7B": ModelConfig( + properties=ModelProperties(context_size=8192, max_tokens=2048, mode=LLMMode.CHAT), features=[] + ), } + + +def get_model_config(credentials: dict) -> ModelConfig: + base_model = credentials.get("base_model_name", "") + model_configs = configs.get(base_model) + if not model_configs: + return ModelConfig( + properties=ModelProperties( + context_size=int(credentials.get("context_size", 0)), + max_tokens=int(credentials.get("max_tokens", 0)), + mode=LLMMode.value_of(credentials.get("mode", "chat")), + ), + features=[], + ) + return model_configs + + +def get_v2_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): + req_params = {} + # predefined properties + model_configs = get_model_config(credentials) + if model_configs: + req_params["max_prompt_tokens"] = model_configs.properties.context_size + req_params["max_new_tokens"] = model_configs.properties.max_tokens + + # model parameters + if model_parameters.get("max_tokens"): + req_params["max_new_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("top_k"): + req_params["top_k"] = model_parameters.get("top_k") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") + + if stop: + req_params["stop"] = stop + + return req_params + + +def get_v3_req_params(credentials: dict, model_parameters: dict, stop: list[str] | None = None): + req_params = {} + # predefined properties + model_configs = get_model_config(credentials) + if model_configs: + req_params["max_tokens"] = model_configs.properties.max_tokens + + # model parameters + if model_parameters.get("max_tokens"): + req_params["max_tokens"] = model_parameters.get("max_tokens") + if model_parameters.get("temperature"): + req_params["temperature"] = model_parameters.get("temperature") + if model_parameters.get("top_p"): + req_params["top_p"] = model_parameters.get("top_p") + if model_parameters.get("presence_penalty"): + req_params["presence_penalty"] = model_parameters.get("presence_penalty") + if model_parameters.get("frequency_penalty"): + req_params["frequency_penalty"] = model_parameters.get("frequency_penalty") + + if stop: + req_params["stop"] = stop + + return req_params diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index 569f89e9754454..ce4f0c3ab1960e 100644 --- 
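A brief sketch, under assumed credentials, of how the helpers above resolve model configuration and request parameters:

creds = {"base_model_name": "Doubao-pro-32k"}
config = get_model_config(creds)
assert config.properties.context_size == 32768
assert config.properties.max_tokens == 4096

# Explicit model parameters override the predefined max_tokens.
params = get_v3_req_params(creds, {"temperature": 0.7, "max_tokens": 256}, stop=["\n\n"])
assert params == {"max_tokens": 256, "temperature": 0.7, "stop": ["\n\n"]}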
a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -1,9 +1,28 @@ +from pydantic import BaseModel + + +class ModelProperties(BaseModel): + context_size: int + max_chunks: int + + +class ModelConfig(BaseModel): + properties: ModelProperties + + ModelConfigs = { - 'Doubao-embedding': { - 'req_params': {}, - 'model_properties': { - 'context_size': 4096, - 'max_chunks': 1, - } - }, + "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)), } + + +def get_model_config(credentials: dict) -> ModelConfig: + base_model = credentials.get("base_model_name", "") + model_configs = ModelConfigs.get(base_model) + if not model_configs: + return ModelConfig( + properties=ModelProperties( + context_size=int(credentials.get("context_size", 0)), + max_chunks=int(credentials.get("max_chunks", 0)), + ) + ) + return model_configs diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 10b01c0d0d6401..4d13e4708b6004 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -2,6 +2,7 @@ from decimal import Decimal from typing import Optional +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, @@ -22,16 +23,17 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient -from core.model_runtime.model_providers.volcengine_maas.errors import ( +from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3 +from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient +from core.model_runtime.model_providers.volcengine_maas.legacy.errors import ( AuthErrors, BadRequestErrors, ConnectionErrors, + MaasError, RateLimitErrors, ServerUnavailableErrors, ) -from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import ModelConfigs -from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException +from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): @@ -39,9 +41,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel): Model class for VolcengineMaaS text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -49,19 +56,35 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ + if ArkClientV3.is_legacy(credentials): + return self._generate_v2(model, credentials, texts, user) + + return self._generate_v3(model, credentials, texts, user) + + def _generate_v2( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: client = MaaSClient.from_credential(credentials) resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts)) - usage = self._calc_response_usage( - model=model, credentials=credentials, tokens=resp['usage']['total_tokens']) + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp["usage"]["total_tokens"]) - result = TextEmbeddingResult( - model=model, - embeddings=[v['embedding'] for v in resp['data']], - usage=usage - ) + result = TextEmbeddingResult(model=model, embeddings=[v["embedding"] for v in resp["data"]], usage=usage) + + return result + + def _generate_v3( + self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None + ) -> TextEmbeddingResult: + client = ArkClientV3.from_credentials(credentials) + resp = client.embeddings(texts) + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=resp.usage.total_tokens) + + result = TextEmbeddingResult(model=model, embeddings=[v.embedding for v in resp.data], usage=usage) return result @@ -88,11 +111,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ + if ArkClientV3.is_legacy(credentials): + return self._validate_credentials_v2(model, credentials) + return self._validate_credentials_v3(model, credentials) + + def _validate_credentials_v2(self, model: str, credentials: dict) -> None: try: - self._invoke(model=model, credentials=credentials, texts=['ping']) - except MaasException as e: + self._invoke(model=model, credentials=credentials, texts=["ping"]) + except MaasError as e: raise CredentialsValidateFailedError(e.message) + def _validate_credentials_v3(self, model: str, credentials: dict) -> None: + try: + self._invoke(model=model, credentials=credentials, texts=["ping"]) + except Exception as e: + raise CredentialsValidateFailedError(e) + @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: """ @@ -113,16 +147,13 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ - generate custom model entities from credentials + generate custom model entities from credentials """ - model_properties = ModelConfigs.get( - credentials['base_model_name'], {}).get('model_properties', {}).copy() - if credentials.get('context_size'): - model_properties[ModelPropertyKey.CONTEXT_SIZE] = int( - credentials.get('context_size', 4096)) - if credentials.get('max_chunks'): - model_properties[ModelPropertyKey.MAX_CHUNKS] = int( - credentials.get('max_chunks', 4096)) + 
model_config = get_model_config(credentials) + model_properties = { + ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size, + ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks, + } entity = AIModelEntity( model=model, label=I18nObject(en_US=model), @@ -131,10 +162,10 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode model_properties=model_properties, parameter_rules=[], pricing=PriceConfig( - input=Decimal(credentials.get('input_price', 0)), - unit=Decimal(credentials.get('unit', 0)), - currency=credentials.get('currency', "USD") - ) + input=Decimal(credentials.get("input_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), ) return entity @@ -150,10 +181,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -164,7 +192,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py deleted file mode 100644 index 64f342f16e936b..00000000000000 --- a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .common import ChatRole -from .maas import MaasException, MaasService - -__all__ = ['MaasService', 'ChatRole', 'MaasException'] diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py deleted file mode 100644 index 053432a089ee46..00000000000000 --- a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding : utf-8 -import datetime - -import pytz - -from .util import Util - - -class MetaData: - def __init__(self): - self.algorithm = '' - self.credential_scope = '' - self.signed_headers = '' - self.date = '' - self.region = '' - self.service = '' - - def set_date(self, date): - self.date = date - - def set_service(self, service): - self.service = service - - def set_region(self, region): - self.region = region - - def set_algorithm(self, algorithm): - self.algorithm = algorithm - - def set_credential_scope(self, credential_scope): - self.credential_scope = credential_scope - - def set_signed_headers(self, signed_headers): - self.signed_headers = signed_headers - - -class SignResult: - def __init__(self): - self.xdate = '' - self.xCredential = '' - self.xAlgorithm = '' - self.xSignedHeaders = '' - self.xSignedQueries = '' - self.xSignature = '' - self.xContextSha256 = '' - self.xSecurityToken = '' - - self.authorization = '' - - def __str__(self): - return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()]) - - -class Credentials: - def __init__(self, ak, sk, service, region, session_token=''): - self.ak = ak - self.sk = sk - self.service = service - self.region = region - self.session_token = session_token - - def set_ak(self, ak): - self.ak = 
ak - - def set_sk(self, sk): - self.sk = sk - - def set_session_token(self, session_token): - self.session_token = session_token - - -class Signer: - @staticmethod - def sign(request, credentials): - if request.path == '': - request.path = '/' - if request.method != 'GET' and not ('Content-Type' in request.headers): - request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8' - - format_date = Signer.get_current_format_date() - request.headers['X-Date'] = format_date - if credentials.session_token != '': - request.headers['X-Security-Token'] = credentials.session_token - - md = MetaData() - md.set_algorithm('HMAC-SHA256') - md.set_service(credentials.service) - md.set_region(credentials.region) - md.set_date(format_date[:8]) - - hashed_canon_req = Signer.hashed_canonical_request_v4(request, md) - md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request'])) - - signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req]) - signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) - sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) - request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials) - return - - @staticmethod - def hashed_canonical_request_v4(request, meta): - body_hash = Util.sha256(request.body) - request.headers['X-Content-Sha256'] = body_hash - - signed_headers = {} - for key in request.headers: - if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'): - signed_headers[key.lower()] = request.headers[key] - - if 'host' in signed_headers: - v = signed_headers['host'] - if v.find(':') != -1: - split = v.split(':') - port = split[1] - if str(port) == '80' or str(port) == '443': - signed_headers['host'] = split[0] - - signed_str = '' - for key in sorted(signed_headers.keys()): - signed_str += key + ':' + signed_headers[key] + '\n' - - meta.set_signed_headers(';'.join(sorted(signed_headers.keys()))) - - canonical_request = '\n'.join( - [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str, - meta.signed_headers, body_hash]) - - return Util.sha256(canonical_request) - - @staticmethod - def get_signing_secret_key_v4(sk, date, region, service): - date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date) - region = Util.hmac_sha256(date, region) - service = Util.hmac_sha256(region, service) - return Util.hmac_sha256(service, 'request') - - @staticmethod - def build_auth_header_v4(signature, meta, credentials): - credential = credentials.ak + '/' + meta.credential_scope - return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature - - @staticmethod - def get_current_format_date(): - return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ") diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py deleted file mode 100644 index 7271ae63fd7309..00000000000000 --- a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py +++ /dev/null @@ -1,207 +0,0 @@ -import json -from collections import OrderedDict -from urllib.parse import urlencode - -import requests - -from .auth import Signer - -VERSION = 'v1.0.137' - - -class Service: - def __init__(self, service_info, api_info): - self.service_info = service_info - self.api_info = api_info - self.session = 
requests.session() - - def set_ak(self, ak): - self.service_info.credentials.set_ak(ak) - - def set_sk(self, sk): - self.service_info.credentials.set_sk(sk) - - def set_session_token(self, session_token): - self.service_info.credentials.set_session_token(session_token) - - def set_host(self, host): - self.service_info.host = host - - def set_scheme(self, scheme): - self.service_info.scheme = scheme - - def get(self, api, params, doseq=0): - if not (api in self.api_info): - raise Exception("no such api") - api_info = self.api_info[api] - - r = self.prepare_request(api_info, params, doseq) - - Signer.sign(r, self.service_info.credentials) - - url = r.build(doseq) - resp = self.session.get(url, headers=r.headers, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) - if resp.status_code == 200: - return resp.text - else: - raise Exception(resp.text) - - def post(self, api, params, form): - if not (api in self.api_info): - raise Exception("no such api") - api_info = self.api_info[api] - r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/x-www-form-urlencoded' - r.form = self.merge(api_info.form, form) - r.body = urlencode(r.form, True) - Signer.sign(r, self.service_info.credentials) - - url = r.build() - - resp = self.session.post(url, headers=r.headers, data=r.form, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) - if resp.status_code == 200: - return resp.text - else: - raise Exception(resp.text) - - def json(self, api, params, body): - if not (api in self.api_info): - raise Exception("no such api") - api_info = self.api_info[api] - r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/json' - r.body = body - - Signer.sign(r, self.service_info.credentials) - - url = r.build() - resp = self.session.post(url, headers=r.headers, data=r.body, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) - if resp.status_code == 200: - return json.dumps(resp.json()) - else: - raise Exception(resp.text.encode("utf-8")) - - def put(self, url, file_path, headers): - with open(file_path, 'rb') as f: - resp = self.session.put(url, headers=headers, data=f) - if resp.status_code == 200: - return True, resp.text.encode("utf-8") - else: - return False, resp.text.encode("utf-8") - - def put_data(self, url, data, headers): - resp = self.session.put(url, headers=headers, data=data) - if resp.status_code == 200: - return True, resp.text.encode("utf-8") - else: - return False, resp.text.encode("utf-8") - - def prepare_request(self, api_info, params, doseq=0): - for key in params: - if type(params[key]) == int or type(params[key]) == float or type(params[key]) == bool: - params[key] = str(params[key]) - elif type(params[key]) == list: - if not doseq: - params[key] = ','.join(params[key]) - - connection_timeout = self.service_info.connection_timeout - socket_timeout = self.service_info.socket_timeout - - r = Request() - r.set_schema(self.service_info.scheme) - r.set_method(api_info.method) - r.set_connection_timeout(connection_timeout) - r.set_socket_timeout(socket_timeout) - - headers = self.merge(api_info.header, self.service_info.header) - headers['Host'] = self.service_info.host - headers['User-Agent'] = 'volc-sdk-python/' + VERSION - r.set_headers(headers) - - query = self.merge(api_info.query, params) - r.set_query(query) - - r.set_host(self.service_info.host) - r.set_path(api_info.path) - - return r - - @staticmethod - def merge(param1, param2): - od = 
OrderedDict() - for key in param1: - od[key] = param1[key] - - for key in param2: - od[key] = param2[key] - - return od - - -class Request: - def __init__(self): - self.schema = '' - self.method = '' - self.host = '' - self.path = '' - self.headers = OrderedDict() - self.query = OrderedDict() - self.body = '' - self.form = {} - self.connection_timeout = 0 - self.socket_timeout = 0 - - def set_schema(self, schema): - self.schema = schema - - def set_method(self, method): - self.method = method - - def set_host(self, host): - self.host = host - - def set_path(self, path): - self.path = path - - def set_headers(self, headers): - self.headers = headers - - def set_query(self, query): - self.query = query - - def set_body(self, body): - self.body = body - - def set_connection_timeout(self, connection_timeout): - self.connection_timeout = connection_timeout - - def set_socket_timeout(self, socket_timeout): - self.socket_timeout = socket_timeout - - def build(self, doseq=0): - return self.schema + '://' + self.host + self.path + '?' + urlencode(self.query, doseq) - - -class ServiceInfo: - def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'): - self.host = host - self.header = header - self.credentials = credentials - self.connection_timeout = connection_timeout - self.socket_timeout = socket_timeout - self.scheme = scheme - - -class ApiInfo: - def __init__(self, method, path, query, form, header): - self.method = method - self.path = path - self.query = query - self.form = form - self.header = header - - def __str__(self): - return 'method: ' + self.method + ', path: ' + self.path diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py deleted file mode 100644 index 7eb5fdfa9122a0..00000000000000 --- a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py +++ /dev/null @@ -1,43 +0,0 @@ -import hashlib -import hmac -from functools import reduce -from urllib.parse import quote - - -class Util: - @staticmethod - def norm_uri(path): - return quote(path).replace('%2F', '/').replace('+', '%20') - - @staticmethod - def norm_query(params): - query = '' - for key in sorted(params.keys()): - if type(params[key]) == list: - for k in params[key]: - query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&' - else: - query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&' - query = query[:-1] - return query.replace('+', '%20') - - @staticmethod - def hmac_sha256(key, content): - return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest() - - @staticmethod - def sha256(content): - if isinstance(content, str) is True: - return hashlib.sha256(content.encode('utf-8')).hexdigest() - else: - return hashlib.sha256(content).hexdigest() - - @staticmethod - def to_hex(content): - lst = [] - for ch in content: - hv = hex(ch).replace('0x', '') - if len(hv) == 1: - hv = '0' + hv - lst.append(hv) - return reduce(lambda x, y: x + y, lst) diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py b/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py deleted file mode 100644 index 3cbe9d9f099e83..00000000000000 --- a/api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py +++ /dev/null @@ -1,213 +0,0 @@ -import copy -import json -from collections.abc import Iterator - -from .base.auth import Credentials, Signer -from 
.base.service import ApiInfo, Service, ServiceInfo -from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object - - -class MaasService(Service): - def __init__(self, host, region, connection_timeout=60, socket_timeout=60): - service_info = self.get_service_info( - host, region, connection_timeout, socket_timeout - ) - self._apikey = None - api_info = self.get_api_info() - super().__init__(service_info, api_info) - - def set_apikey(self, apikey): - self._apikey = apikey - - @staticmethod - def get_service_info(host, region, connection_timeout, socket_timeout): - service_info = ServiceInfo( - host, - {"Accept": "application/json"}, - Credentials("", "", "ml_maas", region), - connection_timeout, - socket_timeout, - "https", - ) - return service_info - - @staticmethod - def get_api_info(): - api_info = { - "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}), - "embeddings": ApiInfo( - "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {} - ), - } - return api_info - - def chat(self, endpoint_id, req): - req["stream"] = False - return self._request(endpoint_id, "chat", req) - - def stream_chat(self, endpoint_id, req): - req_id = gen_req_id() - self._validate("chat", req_id) - apikey = self._apikey - - try: - req["stream"] = True - res = self._call( - endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True - ) - - decoder = SSEDecoder(res) - - def iter_fn(): - for data in decoder.next(): - if data == b"[DONE]": - return - - try: - res = json_to_object( - str(data, encoding="utf-8"), req_id=req_id) - except Exception: - raise - - if res.error is not None and res.error.code_n != 0: - raise MaasException( - res.error.code_n, - res.error.code, - res.error.message, - req_id, - ) - yield res - - return iter_fn() - except MaasException: - raise - except Exception as e: - raise new_client_sdk_request_error(str(e)) - - def embeddings(self, endpoint_id, req): - return self._request(endpoint_id, "embeddings", req) - - def _request(self, endpoint_id, api, req, params={}): - req_id = gen_req_id() - - self._validate(api, req_id) - - apikey = self._apikey - - try: - res = self._call(endpoint_id, api, req_id, params, - json.dumps(req).encode("utf-8"), apikey) - resp = dict_to_object(res.json()) - if resp and isinstance(resp, dict): - resp["req_id"] = req_id - return resp - - except MaasException as e: - raise e - except Exception as e: - raise new_client_sdk_request_error(str(e), req_id) - - def _validate(self, api, req_id): - credentials_exist = ( - self.service_info.credentials is not None and - self.service_info.credentials.sk is not None and - self.service_info.credentials.ak is not None - ) - - if not self._apikey and not credentials_exist: - raise new_client_sdk_request_error("no valid credential", req_id) - - if not (api in self.api_info): - raise new_client_sdk_request_error("no such api", req_id) - - def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False): - api_info = copy.deepcopy(self.api_info[api]) - api_info.path = api_info.path.format(endpoint_id=endpoint_id) - - r = self.prepare_request(api_info, params) - r.headers["x-tt-logid"] = req_id - r.headers["Content-Type"] = "application/json" - r.body = body - - if apikey is None: - Signer.sign(r, self.service_info.credentials) - elif apikey is not None: - r.headers["Authorization"] = "Bearer " + apikey - - url = r.build() - res = self.session.post( - url, - headers=r.headers, - data=r.body, - timeout=( - self.service_info.connection_timeout, - 
self.service_info.socket_timeout, - ), - stream=stream, - ) - - if res.status_code != 200: - raw = res.text.encode() - res.close() - try: - resp = json_to_object( - str(raw, encoding="utf-8"), req_id=req_id) - except Exception: - raise new_client_sdk_request_error(raw, req_id) - - if resp.error: - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, req_id - ) - else: - raise new_client_sdk_request_error(resp, req_id) - - return res - - -class MaasException(Exception): - def __init__(self, code_n, code, message, req_id): - self.code_n = code_n - self.code = code - self.message = message - self.req_id = req_id - - def __str__(self): - return ("Detailed exception information is listed below.\n" + - "req_id: {}\n" + - "code_n: {}\n" + - "code: {}\n" + - "message: {}").format(self.req_id, self.code_n, self.code, self.message) - - -def new_client_sdk_request_error(raw, req_id=""): - return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) - - -class BinaryResponseContent: - def __init__(self, response, request_id) -> None: - self.response = response - self.request_id = request_id - - def stream_to_file( - self, - file: str - ) -> None: - is_first = True - error_bytes = b'' - with open(file, mode="wb") as f: - for data in self.response: - if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)): - error_bytes += data - else: - f.write(data) - - if len(error_bytes) > 0: - resp = json_to_object( - str(error_bytes, encoding="utf-8"), req_id=self.request_id) - raise MaasException( - resp.error.code_n, resp.error.code, resp.error.message, self.request_id - ) - - def iter_bytes(self) -> Iterator[bytes]: - yield from self.response diff --git a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml index a00c1b79944b5a..13e00da76fb149 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml +++ b/api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml @@ -30,8 +30,28 @@ model_credential_schema: en_US: Enter your Model Name zh_Hans: 输入模型名称 credential_form_schemas: + - variable: auth_method + required: true + label: + en_US: Authentication Method + zh_Hans: 鉴权方式 + type: select + default: aksk + options: + - label: + en_US: API Key + value: api_key + - label: + en_US: Access Key / Secret Access Key + value: aksk + placeholder: + en_US: Enter your Authentication Method + zh_Hans: 选择鉴权方式 - variable: volc_access_key_id required: true + show_on: + - variable: auth_method + value: aksk label: en_US: Access Key zh_Hans: Access Key @@ -41,6 +61,9 @@ model_credential_schema: zh_Hans: 输入您的 Access Key - variable: volc_secret_access_key required: true + show_on: + - variable: auth_method + value: aksk label: en_US: Secret Access Key zh_Hans: Secret Access Key @@ -48,6 +71,17 @@ model_credential_schema: placeholder: en_US: Enter your Secret Access Key zh_Hans: 输入您的 Secret Access Key + - variable: volc_api_key + required: true + show_on: + - variable: auth_method + value: api_key + label: + en_US: API Key + type: secret-input + placeholder: + en_US: Enter your API Key + zh_Hans: 输入您的 API Key - variable: volc_region required: true label: @@ -64,7 +98,7 @@ model_credential_schema: en_US: API Endpoint Host zh_Hans: API Endpoint Host type: text-input - default: maas-api.ml-platform-cn-beijing.volces.com + default: https://ark.cn-beijing.volces.com/api/v3 placeholder: en_US: Enter your API 
Endpoint Host zh_Hans: 输入 API Endpoint Host diff --git a/api/core/model_runtime/model_providers/voyage/__init__.py b/api/core/model_runtime/model_providers/voyage/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg new file mode 100644 index 00000000000000..a961f5e4355eea --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_l_en.svg @@ -0,0 +1,21 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg new file mode 100644 index 00000000000000..2c4e121dd71f0b --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/_assets/icon_s_en.svg @@ -0,0 +1,8 @@ + + + voyage + + + + + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/voyage/rerank/__init__.py b/api/core/model_runtime/model_providers/voyage/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/_position.yaml b/api/core/model_runtime/model_providers/voyage/rerank/_position.yaml new file mode 100644 index 00000000000000..32afefbe047806 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/_position.yaml @@ -0,0 +1,4 @@ +- rerank-2 +- rerank-lite-2 +- rerank-1 +- rerank-lite-1 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml new file mode 100644 index 00000000000000..9c894eda85203b --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-1.yaml @@ -0,0 +1,4 @@ +model: rerank-1 +model_type: rerank +model_properties: + context_size: 8000 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-2.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-2.yaml new file mode 100644 index 00000000000000..b760d3c41894a9 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-2.yaml @@ -0,0 +1,4 @@ +model: rerank-2 +model_type: rerank +model_properties: + context_size: 16000 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml new file mode 100644 index 00000000000000..b052d6f00028cb --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-1.yaml @@ -0,0 +1,4 @@ +model: rerank-lite-1 +model_type: rerank +model_properties: + context_size: 4000 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-2.yaml b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-2.yaml new file mode 100644 index 00000000000000..b6fa37a25bf18a --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank-lite-2.yaml @@ -0,0 +1,4 @@ +model: rerank-lite-2 +model_type: rerank +model_properties: + context_size: 8000 diff --git a/api/core/model_runtime/model_providers/voyage/rerank/rerank.py b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py new file mode 100644 index 00000000000000..33fdebbb45ef36 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/rerank/rerank.py @@ -0,0 +1,123 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from 
core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class VoyageRerankModel(RerankModel): + """ + Model class for Voyage rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + base_url = credentials.get("base_url", "https://api.voyageai.com/v1") + base_url = base_url.removesuffix("/") + + try: + response = httpx.post( + base_url + "/rerank", + json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True}, + headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"}, + ) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["data"]: + rerank_document = RerankDocument( + index=result["index"], + text=result["document"], + score=result["relevance_score"], + ) + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py b/api/core/model_runtime/model_providers/voyage/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/_position.yaml new file mode 100644 index 00000000000000..595663990f5945 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/_position.yaml @@ -0,0 +1,6 @@ +- voyage-3 +- voyage-3-lite +- voyage-finance-2 +- voyage-multilingual-2 +- voyage-law-2 +- voyage-code-2 diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..e69c9fccba97ed --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -0,0 +1,172 @@ +import time +from json import JSONDecodeError, dumps +from typing import Optional + +import requests + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + + +class VoyageTextEmbeddingModel(TextEmbeddingModel): + """ + Model class for Voyage text embedding model. 
+ """ + + api_base: str = "https://api.voyageai.com/v1" + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + api_key = credentials["api_key"] + if not api_key: + raise CredentialsValidateFailedError("api_key is required") + + base_url = credentials.get("base_url", self.api_base) + base_url = base_url.removesuffix("/") + + url = base_url + "/embeddings" + headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} + voyage_input_type = "null" + if input_type is not None: + voyage_input_type = input_type.value + data = {"model": model, "input": texts, "input_type": voyage_input_type} + + try: + response = requests.post(url, headers=headers, data=dumps(data)) + except Exception as e: + raise InvokeConnectionError(str(e)) + + if response.status_code != 200: + try: + resp = response.json() + msg = resp["detail"] + if response.status_code == 401: + raise InvokeAuthorizationError(msg) + elif response.status_code == 429: + raise InvokeRateLimitError(msg) + elif response.status_code == 500: + raise InvokeServerUnavailableError(msg) + else: + raise InvokeBadRequestError(msg) + except JSONDecodeError as e: + raise InvokeServerUnavailableError( + f"Failed to convert response to json: {e} with text: {response.text}" + ) + + try: + resp = response.json() + embeddings = resp["data"] + usage = resp["usage"] + except Exception as e: + raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}") + + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) + + result = TextEmbeddingResult( + model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage + ) + + return result + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke(model=model, credentials=credentials, texts=["ping"]) + except Exception as e: + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError, InvokeBadRequestError], + } + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, 
credentials=credentials, price_type=PriceType.INPUT, tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml new file mode 100644 index 00000000000000..a06bb7639feacd --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3-lite.yaml @@ -0,0 +1,8 @@ +model: voyage-3-lite +model_type: text-embedding +model_properties: + context_size: 32000 +pricing: + input: '0.00002' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml new file mode 100644 index 00000000000000..117afbcaf3c808 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-3.yaml @@ -0,0 +1,8 @@ +model: voyage-3 +model_type: text-embedding +model_properties: + context_size: 32000 +pricing: + input: '0.00006' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-code-2.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-code-2.yaml new file mode 100644 index 00000000000000..693669c82c7ba3 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-code-2.yaml @@ -0,0 +1,8 @@ +model: voyage-code-2 +model_type: text-embedding +model_properties: + context_size: 16000 +pricing: + input: '0.00012' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-finance-2.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-finance-2.yaml new file mode 100644 index 00000000000000..555e11002ade3b --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-finance-2.yaml @@ -0,0 +1,8 @@ +model: voyage-finance-2 +model_type: text-embedding +model_properties: + context_size: 32000 +pricing: + input: '0.00012' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-law-2.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-law-2.yaml new file mode 100644 index 00000000000000..032693286f6e88 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-law-2.yaml @@ -0,0 +1,8 @@ +model: voyage-law-2 +model_type: text-embedding +model_properties: + context_size: 16000 +pricing: + input: '0.00012' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-multilingual-2.yaml b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-multilingual-2.yaml new file mode 100644 index 00000000000000..9ecf4d50098338 
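Worked example for the pricing fields in the Voyage embedding YAMLs above. This is illustrative only: it assumes the usual Dify get_price convention of total = tokens × unit × input price, which is consistent with how _calc_response_usage consumes input_price_info.total_amount.

```python
from decimal import Decimal

# voyage-3 pricing from the YAML above: input '0.00006' at unit '0.001',
# i.e. the price is quoted per 1,000 tokens.
tokens = Decimal(2000)             # example token count for one request
unit = Decimal("0.001")
input_price = Decimal("0.00006")
total = tokens * unit * input_price
print(total)  # 0.00012 USD, equivalent to $0.06 per 1M tokens
```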
--- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/voyage-multilingual-2.yaml @@ -0,0 +1,8 @@ +model: voyage-multilingual-2 +model_type: text-embedding +model_properties: + context_size: 32000 +pricing: + input: '0.00012' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/voyage/voyage.py b/api/core/model_runtime/model_providers/voyage/voyage.py new file mode 100644 index 00000000000000..3e33b45e110d56 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/voyage.py @@ -0,0 +1,28 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VoyageProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING) + + # Use `voyage-3` model for validate, + # no matter what model you pass in, text completion model or chat model + model_instance.validate_credentials(model="voyage-3", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/voyage/voyage.yaml b/api/core/model_runtime/model_providers/voyage/voyage.yaml new file mode 100644 index 00000000000000..c64707800eebe0 --- /dev/null +++ b/api/core/model_runtime/model_providers/voyage/voyage.yaml @@ -0,0 +1,31 @@ +provider: voyage +label: + en_US: Voyage +description: + en_US: Embedding and Rerank Model Supported +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#EFFDFD" +help: + title: + en_US: Get your API key from Voyage AI + zh_Hans: 从 Voyage 获取 API Key + url: + en_US: https://dash.voyageai.com/ +supported_model_types: + - text-embedding + - rerank +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py new file mode 100644 index 00000000000000..c77a499982e98b --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -0,0 +1,196 @@ +from datetime import datetime, timedelta +from threading import Lock + +from requests import post + +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + BadRequestError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) + +baidu_access_tokens: dict[str, "BaiduAccessToken"] = {} +baidu_access_tokens_lock = Lock() + + +class BaiduAccessToken: + api_key: str + access_token: str + expires: datetime + + def __init__(self, api_key: str) -> None: + self.api_key = api_key + self.access_token = "" + self.expires = datetime.now() + timedelta(days=3) + + @staticmethod + def _get_access_token(api_key: str, secret_key: str) -> str: + """ + request access token from Baidu + """ + 
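Editorial sketch of the client-credentials exchange that _get_access_token performs just below. The key values are placeholders, and a real call needs valid Baidu credentials and network access; passing the parameters via `params=` is equivalent to the inline query string the code builds.

```python
import requests

resp = requests.post(
    "https://aip.baidubce.com/oauth/2.0/token",
    params={
        "grant_type": "client_credentials",
        "client_id": "my-api-key",         # placeholder for the real API key
        "client_secret": "my-secret-key",  # placeholder for the real secret key
    },
    headers={"Content-Type": "application/json", "Accept": "application/json"},
)
data = resp.json()
# On success the payload carries "access_token"; on failure an "error" /
# "error_description" pair, which the code below maps to typed exceptions.
access_token = data.get("access_token")
```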
try: + response = post( + url=f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}", + headers={"Content-Type": "application/json", "Accept": "application/json"}, + ) + except Exception as e: + raise InvalidAuthenticationError(f"Failed to get access token from Baidu: {e}") + + resp = response.json() + if "error" in resp: + if resp["error"] == "invalid_client": + raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') + elif resp["error"] == "unknown_error": + raise InternalServerError(f'Internal server error: {resp["error_description"]}') + elif resp["error"] == "invalid_request": + raise BadRequestError(f'Bad request: {resp["error_description"]}') + elif resp["error"] == "rate_limit_exceeded": + raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') + else: + raise Exception(f'Unknown error: {resp["error_description"]}') + + return resp["access_token"] + + @staticmethod + def get_access_token(api_key: str, secret_key: str) -> "BaiduAccessToken": + """ + LLM from Baidu requires access token to invoke the API. + however, we have api_key and secret_key, and access token is valid for 30 days. + so we can cache the access token for 3 days. (avoid memory leak) + + it may be more efficient to use a ticker to refresh access token, but it will cause + more complexity, so we just refresh access tokens when get_access_token is called. + """ + + # loop up cache, remove expired access token + baidu_access_tokens_lock.acquire() + now = datetime.now() + for key in list(baidu_access_tokens.keys()): + token = baidu_access_tokens[key] + if token.expires < now: + baidu_access_tokens.pop(key) + + if api_key not in baidu_access_tokens: + # if access token not in cache, request it + token = BaiduAccessToken(api_key) + baidu_access_tokens[api_key] = token + try: + # try to get access token + token_str = BaiduAccessToken._get_access_token(api_key, secret_key) + finally: + # release it to enhance performance + # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock + baidu_access_tokens_lock.release() + token.access_token = token_str + token.expires = now + timedelta(days=3) + return token + else: + # if access token in cache, return it + token = baidu_access_tokens[api_key] + baidu_access_tokens_lock.release() + return token + + +class _CommonWenxin: + api_bases = { + "ernie-bot": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-bot-4": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-bot-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-bot-turbo": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-3.5-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions", + "ernie-3.5-8k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205", + "ernie-3.5-8k-1222": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222", + "ernie-3.5-4k-0205": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205", + "ernie-3.5-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k", + "ernie-4.0-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-4.0-8k-latest": 
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro", + "ernie-speed-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed", + "ernie-speed-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k", + "ernie-speed-appbuilder": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas", + "ernie-lite-8k-0922": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant", + "ernie-lite-8k-0308": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k", + "ernie-character-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-character-8k-0321": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", + "ernie-4.0-turbo-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview", + "ernie-4.0-turbo-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-128k", + "yi_34b_chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat", + "embedding-v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1", + "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en", + "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh", + "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k", + "bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base", + } + + function_calling_supports = [ + "ernie-bot", + "ernie-bot-8k", + "ernie-3.5-8k", + "ernie-3.5-8k-0205", + "ernie-3.5-8k-1222", + "ernie-3.5-4k-0205", + "ernie-3.5-128k", + "ernie-4.0-8k", + "ernie-4.0-turbo-8k", + "ernie-4.0-turbo-8k-preview", + "yi_34b_chat", + ] + + api_key: str = "" + secret_key: str = "" + + def __init__(self, api_key: str, secret_key: str): + self.api_key = api_key + self.secret_key = secret_key + + @staticmethod + def _to_credential_kwargs(credentials: dict) -> dict: + credentials_kwargs = {"api_key": credentials["api_key"], "secret_key": credentials["secret_key"]} + return credentials_kwargs + + def _handle_error(self, code: int, msg: str): + error_map = { + 1: InternalServerError, + 2: InternalServerError, + 3: BadRequestError, + 4: RateLimitReachedError, + 6: InvalidAuthenticationError, + 13: InvalidAPIKeyError, + 14: InvalidAPIKeyError, + 15: InvalidAPIKeyError, + 17: RateLimitReachedError, + 18: RateLimitReachedError, + 19: RateLimitReachedError, + 100: InvalidAPIKeyError, + 111: InvalidAPIKeyError, + 200: InternalServerError, + 336000: InternalServerError, + 336001: BadRequestError, + 336002: BadRequestError, + 336003: BadRequestError, + 336004: InvalidAuthenticationError, + 336005: InvalidAPIKeyError, + 336006: BadRequestError, + 336007: BadRequestError, + 336008: BadRequestError, + 336100: InternalServerError, + 336101: BadRequestError, + 336102: BadRequestError, + 336103: BadRequestError, + 336104: BadRequestError, + 336105: BadRequestError, + 336200: InternalServerError, + 336303: BadRequestError, + 337006: BadRequestError, + } + + if code in error_map: + raise error_map[code](msg) + else: + raise InternalServerError(f"Unknown error: {msg}") + + def 
_get_access_token(self) -> str: + token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) + return token.access_token diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-128k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-128k.yaml new file mode 100644 index 00000000000000..f8d56406d91687 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-128k.yaml @@ -0,0 +1,40 @@ +model: ernie-4.0-turbo-128k +label: + en_US: Ernie-4.0-turbo-128K +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.8 + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 2 + max: 4096 + - name: presence_penalty + use_template: presence_penalty + default: 1.0 + min: 1.0 + max: 2.0 + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + use_template: response_format + - name: disable_search + label: + zh_Hans: 禁用搜索 + en_US: Disable Search + type: boolean + help: + zh_Hans: 禁用模型自行进行外部搜索。 + en_US: Disable the model to perform external search. + required: false diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index e345663d36efcb..07b970f8104c8f 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,176 +1,53 @@ from collections.abc import Generator -from datetime import datetime, timedelta from enum import Enum from json import dumps, loads -from threading import Lock from typing import Any, Union from requests import Response, post from core.model_runtime.entities.message_entities import PromptMessageTool -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( +from core.model_runtime.model_providers.wenxin._common import _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( BadRequestError, InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError, ) -# map api_key to access_token -baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} -baidu_access_tokens_lock = Lock() - -class BaiduAccessToken: - api_key: str - access_token: str - expires: datetime - - def __init__(self, api_key: str) -> None: - self.api_key = api_key - self.access_token = '' - self.expires = datetime.now() + timedelta(days=3) - - def _get_access_token(api_key: str, secret_key: str) -> str: - """ - request access token from Baidu - """ - try: - response = post( - url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}', - headers={ - 'Content-Type': 'application/json', - 'Accept': 'application/json' - }, - ) - except Exception as e: - raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') - - resp = response.json() - if 'error' in resp: - if resp['error'] == 'invalid_client': - raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') - elif resp['error'] == 'unknown_error': - raise InternalServerError(f'Internal server error: {resp["error_description"]}') - elif resp['error'] == 'invalid_request': - raise BadRequestError(f'Bad request: {resp["error_description"]}') - elif resp['error'] == 'rate_limit_exceeded': - raise 
RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') - else: - raise Exception(f'Unknown error: {resp["error_description"]}') - - return resp['access_token'] - - @staticmethod - def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken': - """ - LLM from Baidu requires access token to invoke the API. - however, we have api_key and secret_key, and access token is valid for 30 days. - so we can cache the access token for 3 days. (avoid memory leak) - - it may be more efficient to use a ticker to refresh access token, but it will cause - more complexity, so we just refresh access tokens when get_access_token is called. - """ - - # loop up cache, remove expired access token - baidu_access_tokens_lock.acquire() - now = datetime.now() - for key in list(baidu_access_tokens.keys()): - token = baidu_access_tokens[key] - if token.expires < now: - baidu_access_tokens.pop(key) - - if api_key not in baidu_access_tokens: - # if access token not in cache, request it - token = BaiduAccessToken(api_key) - baidu_access_tokens[api_key] = token - # release it to enhance performance - # btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock - baidu_access_tokens_lock.release() - # try to get access token - token_str = BaiduAccessToken._get_access_token(api_key, secret_key) - token.access_token = token_str - token.expires = now + timedelta(days=3) - return token - else: - # if access token in cache, return it - token = baidu_access_tokens[api_key] - baidu_access_tokens_lock.release() - return token - class ErnieMessage: class Role(Enum): - USER = 'user' - ASSISTANT = 'assistant' - FUNCTION = 'function' - SYSTEM = 'system' + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + SYSTEM = "system" role: str = Role.USER.value content: str usage: dict[str, int] = None - stop_reason: str = '' + stop_reason: str = "" def to_dict(self) -> dict[str, Any]: return { - 'role': self.role, - 'content': self.content, + "role": self.role, + "content": self.content, } - def __init__(self, content: str, role: str = 'user') -> None: + def __init__(self, content: str, role: str = "user") -> None: self.content = content self.role = role -class ErnieBotModel: - api_bases = { - 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', - 'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205', - 'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k', - 'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', - 'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed', - 'ernie-speed-128k': 
'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k', - 'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas', - 'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant', - 'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k', - 'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k', - 'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview', - 'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat', - } - - function_calling_supports = [ - 'ernie-bot', - 'ernie-bot-8k', - 'ernie-3.5-8k', - 'ernie-3.5-8k-0205', - 'ernie-3.5-8k-1222', - 'ernie-3.5-4k-0205', - 'ernie-3.5-128k', - 'ernie-4.0-8k', - 'ernie-4.0-turbo-8k', - 'ernie-4.0-turbo-8k-preview', - 'yi_34b_chat' - ] - - api_key: str = '' - secret_key: str = '' - - def __init__(self, api_key: str, secret_key: str): - self.api_key = api_key - self.secret_key = secret_key - - def generate(self, model: str, stream: bool, messages: list[ErnieMessage], - parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ - stop: list[str], user: str) \ - -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: +class ErnieBotModel(_CommonWenxin): + def generate( + self, + model: str, + stream: bool, + messages: list[ErnieMessage], + parameters: dict[str, Any], + timeout: int, + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: # check parameters self._check_parameters(model, parameters, tools, stop) @@ -178,79 +55,36 @@ def generate(self, model: str, stream: bool, messages: list[ErnieMessage], access_token = self._get_access_token() # generate request body - url = f'{self.api_bases[model]}?access_token={access_token}' + url = f"{self.api_bases[model]}?access_token={access_token}" # clone messages messages_cloned = self._copy_messages(messages=messages) # build body - body = self._build_request_body(model, messages=messages_cloned, stream=stream, - parameters=parameters, tools=tools, stop=stop, user=user) + body = self._build_request_body( + model, messages=messages_cloned, stream=stream, parameters=parameters, tools=tools, stop=stop, user=user + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } resp = post(url=url, data=dumps(body), headers=headers, stream=stream) if resp.status_code != 200: - raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}') + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") if stream: return self._handle_chat_stream_generate_response(resp) return self._handle_chat_generate_response(resp) - def _handle_error(self, code: int, msg: str): - error_map = { - 1: InternalServerError, - 2: InternalServerError, - 3: BadRequestError, - 4: RateLimitReachedError, - 6: InvalidAuthenticationError, - 13: InvalidAPIKeyError, - 14: InvalidAPIKeyError, - 15: InvalidAPIKeyError, - 17: RateLimitReachedError, - 18: RateLimitReachedError, - 19: RateLimitReachedError, - 100: InvalidAPIKeyError, - 111: InvalidAPIKeyError, - 200: 
InternalServerError, - 336000: InternalServerError, - 336001: BadRequestError, - 336002: BadRequestError, - 336003: BadRequestError, - 336004: InvalidAuthenticationError, - 336005: InvalidAPIKeyError, - 336006: BadRequestError, - 336007: BadRequestError, - 336008: BadRequestError, - 336100: InternalServerError, - 336101: BadRequestError, - 336102: BadRequestError, - 336103: BadRequestError, - 336104: BadRequestError, - 336105: BadRequestError, - 336200: InternalServerError, - 336303: BadRequestError, - 337006: BadRequestError - } - - if code in error_map: - raise error_map[code](msg) - else: - raise InternalServerError(f'Unknown error: {msg}') - - def _get_access_token(self) -> str: - token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key) - return token.access_token - def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] - def _check_parameters(self, model: str, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str]) -> None: + def _check_parameters( + self, model: str, parameters: dict[str, Any], tools: list[PromptMessageTool], stop: list[str] + ) -> None: if model not in self.api_bases: - raise BadRequestError(f'Invalid model: {model}') + raise BadRequestError(f"Invalid model: {model}") # if model not in self.function_calling_supports and tools is not None and len(tools) > 0: # raise BadRequestError(f'Model {model} does not support calling function.') @@ -259,86 +93,106 @@ def _check_parameters(self, model: str, parameters: dict[str, Any], # so, we just disable function calling for now. if tools is not None and len(tools) > 0: - raise BadRequestError('function calling is not supported yet.') + raise BadRequestError("function calling is not supported yet.") if stop is not None: if len(stop) > 4: - raise BadRequestError('stop list should not exceed 4 items.') + raise BadRequestError("stop list should not exceed 4 items.") for s in stop: if len(s) > 20: - raise BadRequestError('stop item should not exceed 20 characters.') - - def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], - tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: + raise BadRequestError("stop item should not exceed 20 characters.") + + def _build_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], tools: list[PromptMessageTool], - stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_function_calling_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + tools: list[PromptMessageTool], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role == 'function': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role 
== "function": + raise BadRequestError("The first message should be user message.") """ TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, - parameters: dict[str, Any], stop: list[str], user: str) \ - -> dict[str, Any]: + def _build_chat_request_body( + self, + model: str, + messages: list[ErnieMessage], + stream: bool, + parameters: dict[str, Any], + stop: list[str], + user: str, + ) -> dict[str, Any]: if len(messages) == 0: - raise BadRequestError('The number of messages should not be zero.') + raise BadRequestError("The number of messages should not be zero.") # check if the first element is system, shift it - system_message = '' - if messages[0].role == 'system': + system_message = "" + if messages[0].role == "system": message = messages.pop(0) system_message = message.content if len(messages) % 2 == 0: - raise BadRequestError('The number of messages should be odd.') - if messages[0].role != 'user': - raise BadRequestError('The first message should be user message.') + raise BadRequestError("The number of messages should be odd.") + if messages[0].role != "user": + raise BadRequestError("The first message should be user message.") body = { - 'messages': [message.to_dict() for message in messages], - 'stream': stream, - 'stop': stop, - 'user_id': user, - **parameters + "messages": [message.to_dict() for message in messages], + "stream": stream, + "stop": stop, + "user_id": user, + **parameters, } - if 'max_tokens' in parameters and type(parameters['max_tokens']) == int: - body['max_output_tokens'] = parameters['max_tokens'] + if "max_tokens" in parameters and type(parameters["max_tokens"]) == int: + body["max_output_tokens"] = parameters["max_tokens"] - if 'presence_penalty' in parameters and type(parameters['presence_penalty']) == float: - body['penalty_score'] = parameters['presence_penalty'] + if "presence_penalty" in parameters and type(parameters["presence_penalty"]) == float: + body["penalty_score"] = parameters["presence_penalty"] if system_message: - body['system'] = system_message + body["system"] = system_message return body def _handle_chat_generate_response(self, response: Response) -> ErnieMessage: data = response.json() - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) - result = data['result'] - usage = data['usage'] + result = data["result"] + usage = data["usage"] - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } return message @@ -347,19 +201,19 @@ def _handle_chat_stream_generate_response(self, response: Response) -> Generator for line in response.iter_lines(): if len(line) == 0: continue - line = line.decode('utf-8') - if line[0] == '{': + line = line.decode("utf-8") + if line[0] == "{": try: data = loads(line) - if 'error_code' in data: - code = data['error_code'] - msg = data['error_msg'] + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] # raise error self._handle_error(code, msg) except Exception as e: - raise InternalServerError(f'Failed to parse 
response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - if line.startswith('data:'): + if line.startswith("data:"): line = line[5:].strip() else: continue @@ -369,23 +223,23 @@ def _handle_chat_stream_generate_response(self, response: Response) -> Generator try: data = loads(line) except Exception as e: - raise InternalServerError(f'Failed to parse response: {e}') + raise InternalServerError(f"Failed to parse response: {e}") - result = data['result'] - is_end = data['is_end'] + result = data["result"] + is_end = data["is_end"] if is_end: - usage = data['usage'] - finish_reason = data.get('finish_reason', None) - message = ErnieMessage(content=result, role='assistant') + usage = data["usage"] + finish_reason = data.get("finish_reason", None) + message = ErnieMessage(content=result, role="assistant") message.usage = { - 'prompt_tokens': usage['prompt_tokens'], - 'completion_tokens': usage['completion_tokens'], - 'total_tokens': usage['total_tokens'] + "prompt_tokens": usage["prompt_tokens"], + "completion_tokens": usage["completion_tokens"], + "total_tokens": usage["total_tokens"], } message.stop_reason = finish_reason yield message else: - message = ErnieMessage(content=result, role='assistant') + message = ErnieMessage(content=result, role="assistant") yield message diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py deleted file mode 100644 index 67d76b4a291c06..00000000000000 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py +++ /dev/null @@ -1,17 +0,0 @@ -class InvalidAuthenticationError(Exception): - pass - -class InvalidAPIKeyError(Exception): - pass - -class RateLimitReachedError(Exception): - pass - -class InsufficientAccountBalance(Exception): - pass - -class InternalServerError(Exception): - pass - -class BadRequestError(Exception): - pass \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index d39d63deeeae7d..952cbb33f4b02c 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -11,24 +11,13 @@ UserPromptMessage, ) from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( - BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError, -) +from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken +from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage +from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. 
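# --- Illustrative aside (not part of the patch) ---
# A self-contained sketch of the streaming protocol parsed above: Ernie Bot
# streams SSE-style lines where a bare JSON object may carry an error
# payload, "data:"-prefixed lines carry incremental results, and the chunk
# with is_end=true also carries the usage totals. Error handling is
# collapsed to RuntimeError for brevity.

from json import loads

def parse_stream(lines: list[str]):
    for line in lines:
        if not line:
            continue
        if line.startswith("{"):
            data = loads(line)
            if "error_code" in data:
                raise RuntimeError(f'{data["error_code"]}: {data["error_msg"]}')
            continue
        if not line.startswith("data:"):
            continue
        data = loads(line[5:].strip())
        yield data["result"], data["is_end"]

chunks = [
    'data: {"result": "Hel", "is_end": false}',
    'data: {"result": "lo", "is_end": true, "usage": {"total_tokens": 5}}',
]
print(list(parse_stream(chunks)))  # [('Hel', False), ('lo', True)]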
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure @@ -39,44 +28,84 @@ You should also complete the text started with ``` but not tell ``` directly. -""" +""" # noqa: E501 + class ErnieBotLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: - return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return self._generate( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) - def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, - callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + def _code_block_mode_wrapper( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + callbacks: Optional[list[Callback]] = None, + ) -> Union[LLMResult, Generator]: """ Code block mode wrapper for invoking large language model """ - if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: - response_format = model_parameters['response_format'] + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: + response_format = model_parameters["response_format"] stop = stop or [] - self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format) - model_parameters.pop('response_format') + self._transform_json_prompts( + model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format + ) + model_parameters.pop("response_format") if stream: return self._code_block_mode_stream_processor( model=model, prompt_messages=prompt_messages, - input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages, - model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + input_generator=self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + ), ) - + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - stream: bool = True, user: str | None 
= None, response_format: str = 'JSON') \ - -> None: + def _transform_json_prompts( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + response_format: str = "JSON", + ) -> None: """ Transform json prompts to model prompts """ @@ -85,34 +114,44 @@ def _transform_json_prompts(self, model: str, credentials: dict, if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): # override the system message prompt_messages[0] = SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", prompt_messages[0].content) - .replace("{{block}}", response_format) + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace( + "{{block}}", response_format + ) ) else: # insert the system message - prompt_messages.insert(0, SystemPromptMessage( - content=ERNIE_BOT_BLOCK_MODE_PROMPT - .replace("{{instructions}}", f"Please output a valid {response_format} object.") - .replace("{{block}}", response_format) - )) + prompt_messages.insert( + 0, + SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT.replace( + "{{instructions}}", f"Please output a valid {response_format} object." + ).replace("{{block}}", response_format) + ), + ) if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): # add ```JSON\n to the last message prompt_messages[-1].content += "\n```JSON\n{\n" else: # append a user message - prompt_messages.append(UserPromptMessage( - content="```JSON\n{\n" - )) + prompt_messages.append(UserPromptMessage(content="```JSON\n{\n")) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: # tools is not supported yet return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: + def _num_tokens_from_messages( + self, + messages: list[PromptMessage], + ) -> int: """Calculate num tokens for baichuan model""" + def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -124,10 +163,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -137,36 +176,53 @@ def tokens(text: str): return num_tokens def validate_credentials(self, model: str, credentials: dict) -> None: - api_key = credentials['api_key'] - secret_key = credentials['secret_key'] + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] try: - BaiduAccessToken._get_access_token(api_key, secret_key) + BaiduAccessToken.get_access_token(api_key, secret_key) except Exception as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | 
None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: instance = ErnieBotModel( - api_key=credentials['api_key'], - secret_key=credentials['secret_key'], + api_key=credentials["api_key"], + secret_key=credentials["secret_key"], ) - user = user if user else 'ErnieBotDefault' + user = user or "ErnieBotDefault" # convert prompt messages to baichuan messages messages = [ ErnieMessage( - content=message.content if isinstance(message.content, str) else ''.join([ - content.data for content in message.content - ]), - role=message.role.value - ) for message in prompt_messages + content=message.content + if isinstance(message.content, str) + else "".join([content.data for content in message.content]), + role=message.role.value, + ) + for message in prompt_messages ] # invoke model - response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60, tools=tools, stop=stop, user=user) + response = instance.generate( + model=model, + stream=stream, + messages=messages, + parameters=model_parameters, + timeout=60, + tools=tools, + stop=stop, + user=user, + ) if stream: return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response) @@ -191,43 +247,49 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: message_dict = {"role": "system", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") - + return message_dict - def _handle_chat_generate_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: ErnieMessage) -> LLMResult: + def _handle_chat_generate_response( + self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: ErnieMessage + ) -> LLMResult: # convert baichuan message to llm result - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=response.usage["prompt_tokens"], + completion_tokens=response.usage["completion_tokens"], + ) return LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=response.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=response.content, tool_calls=[]), usage=usage, ) - def _handle_chat_generate_stream_response(self, model: str, - prompt_messages: list[PromptMessage], - credentials: dict, - response: Generator[ErnieMessage, None, None]) -> Generator: + def _handle_chat_generate_stream_response( + self, + model: str, + prompt_messages: list[PromptMessage], + credentials: dict, + response: Generator[ErnieMessage, None, None], + ) -> Generator: for message in response: if message.usage: - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens']) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=message.usage["prompt_tokens"], + completion_tokens=message.usage["completion_tokens"], + ) yield LLMResultChunk( 
model=model, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), + message=AssistantPromptMessage(content=message.content, tool_calls=[]), usage=usage, - finish_reason=message.stop_reason if message.stop_reason else None, + finish_reason=message.stop_reason or None, ), ) else: @@ -236,11 +298,8 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, - message=AssistantPromptMessage( - content=message.content, - tool_calls=[] - ), - finish_reason=message.stop_reason if message.stop_reason else None, + message=AssistantPromptMessage(content=message.content, tool_calls=[]), + finish_reason=message.stop_reason or None, ), ) @@ -254,22 +313,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ - return { - InvokeConnectionError: [ - ], - InvokeServerUnavailableError: [ - InternalServerError - ], - InvokeRateLimitError: [ - RateLimitReachedError - ], - InvokeAuthorizationError: [ - InvalidAuthenticationError, - InsufficientAccountBalance, - InvalidAPIKeyError, - ], - InvokeBadRequestError: [ - BadRequestError, - KeyError - ] - } + return invoke_error_mapping() diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/__init__.py b/api/core/model_runtime/model_providers/wenxin/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml b/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml new file mode 100644 index 00000000000000..ef4b07d7678702 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml @@ -0,0 +1,8 @@ +model: bce-reranker-base_v1 +model_type: rerank +model_properties: + context_size: 4096 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py new file mode 100644 index 00000000000000..9e6a7dd99e10a5 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py @@ -0,0 +1,122 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import InvokeError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.wenxin._common import _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + InternalServerError, + invoke_error_mapping, +) + + +class WenxinRerank(_CommonWenxin): + def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None): + access_token = self._get_access_token() + url = f"{self.api_bases[model]}?access_token={access_token}" + + try: + response = httpx.post( + url, + json={"model": model, "query": query, "documents": docs, "top_n": top_n}, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise InternalServerError(str(e)) + + +class WenxinRerankModel(RerankModel): + """ + Model class for wenxin rerank model. 
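# --- Illustrative aside (not part of the patch) ---
# A minimal sketch of the new WenxinRerank HTTP call above. As the patch
# shows, the Qianfan endpoint authenticates through an access_token query
# parameter rather than a header; endpoint and token below are placeholders
# you would obtain from the model's api_bases entry and BaiduAccessToken.

import httpx

def rerank(endpoint: str, access_token: str, model: str, query: str,
           docs: list[str], top_n: int | None = None) -> dict:
    resp = httpx.post(
        f"{endpoint}?access_token={access_token}",
        json={"model": model, "query": query, "documents": docs, "top_n": top_n},
        headers={"Content-Type": "application/json"},
    )
    resp.raise_for_status()  # the patch wraps HTTPStatusError in InternalServerError
    return resp.json()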
+ """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] + + wenxin_rerank: WenxinRerank = WenxinRerank(api_key, secret_key) + + try: + results = wenxin_rerank.rerank(model, query, docs, top_n) + + rerank_documents = [] + for result in results["results"]: + index = result["index"] + if "document" in result: + text = result["document"] + else: + # llama.cpp rerank maynot return original documents + text = docs[index] + + rerank_document = RerankDocument( + index=index, + text=text, + score=result["relevance_score"], + ) + + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InternalServerError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return invoke_error_mapping() diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml new file mode 100644 index 00000000000000..74fadb7f9de60f --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml @@ -0,0 +1,9 @@ +model: bge-large-en +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml new file mode 100644 index 00000000000000..d4af27ec389a66 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml @@ -0,0 +1,9 @@ +model: bge-large-zh +model_type: text-embedding +model_properties: + context_size: 512 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml new file mode 100644 index 00000000000000..eda48d965533e5 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/embedding-v1.yaml @@ -0,0 +1,9 @@ +model: embedding-v1 +model_type: text-embedding +model_properties: + context_size: 384 + max_chunks: 16 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml new file mode 100644 index 00000000000000..e28f253eb6b861 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml @@ -0,0 +1,9 @@ +model: tao-8k +model_type: text-embedding +model_properties: + context_size: 8192 + max_chunks: 1 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..19135deb27380d --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -0,0 +1,187 @@ +import time +from abc import abstractmethod +from collections.abc import Mapping +from json import dumps +from typing import Any, Optional + +import numpy as np +from requests import Response, post + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import InvokeError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import 
TextEmbeddingModel +from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin +from core.model_runtime.model_providers.wenxin.wenxin_errors import ( + BadRequestError, + InternalServerError, + invoke_error_mapping, +) + + +class TextEmbedding: + @abstractmethod + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + raise NotImplementedError + + +class WenxinTextEmbedding(_CommonWenxin, TextEmbedding): + def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int): + access_token = self._get_access_token() + url = f"{self.api_bases[model]}?access_token={access_token}" + body = self._build_embed_request_body(model, texts, user) + headers = { + "Content-Type": "application/json", + } + + resp = post(url, data=dumps(body), headers=headers) + if resp.status_code != 200: + raise InternalServerError(f"Failed to invoke ernie bot: {resp.text}") + return self._handle_embed_response(model, resp) + + def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]: + if len(texts) == 0: + raise BadRequestError("The number of texts should not be zero.") + body = { + "input": texts, + "user_id": user, + } + return body + + def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int): + data = response.json() + if "error_code" in data: + code = data["error_code"] + msg = data["error_msg"] + # raise error + self._handle_error(code, msg) + + embeddings = [v["embedding"] for v in data["data"]] + _usage = data["usage"] + tokens = _usage["prompt_tokens"] + total_tokens = _usage["total_tokens"] + + return embeddings, tokens, total_tokens + + +class WenxinTextEmbeddingModel(TextEmbeddingModel): + def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding: + return WenxinTextEmbedding(api_key, secret_key) + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :param input_type: input type + :return: embeddings result + """ + + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] + embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key) + user = user or "ErnieBotDefault" + + context_size = self._get_context_size(model, credentials) + max_chunks = self._get_max_chunks(model, credentials) + inputs = [] + indices = [] + used_tokens = 0 + used_total_tokens = 0 + + for i, text in enumerate(texts): + # Here token count is only an approximation based on the GPT2 tokenizer + num_tokens = self._get_num_tokens_by_gpt2(text) + + if num_tokens >= context_size: + cutoff = int(np.floor(len(text) * (context_size / num_tokens))) + # if num tokens is larger than context length, only use the start + inputs.append(text[0:cutoff]) + else: + inputs.append(text) + indices += [i] + + batched_embeddings = [] + _iter = range(0, len(inputs), max_chunks) + for i in _iter: + embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents( + model, inputs[i : i + max_chunks], user + ) + used_tokens += _used_tokens + used_total_tokens += _total_used_tokens + batched_embeddings += embeddings_batch + + usage = self._calc_response_usage(model, credentials, used_tokens, 
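# --- Illustrative aside (not part of the patch) ---
# A self-contained sketch of the batching strategy above: texts whose
# approximate token count reaches the model's context_size are truncated
# proportionally by characters, then the inputs are embedded in batches of
# at most max_chunks items. count_tokens stands in for
# _get_num_tokens_by_gpt2, which the patch itself calls an approximation
# for Ernie models.

from collections.abc import Callable

def batch_inputs(texts: list[str], context_size: int, max_chunks: int,
                 count_tokens: Callable[[str], int]) -> list[list[str]]:
    prepared = []
    for text in texts:
        n = count_tokens(text)
        if n >= context_size:
            cutoff = int(len(text) * (context_size / n))
            prepared.append(text[:cutoff])  # keep only the start, as above
        else:
            prepared.append(text)
    return [prepared[i:i + max_chunks] for i in range(0, len(prepared), max_chunks)]

# tao-8k allows a single chunk per request (max_chunks: 1 in its yaml),
# while the bge models accept up to 16:
batches = batch_inputs(["hello world"] * 5, context_size=512, max_chunks=16,
                       count_tokens=lambda t: len(t.split()))
print(len(batches), [len(b) for b in batches])  # 1 [5]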
used_total_tokens) + return TextEmbeddingResult( + model=model, + embeddings=batched_embeddings, + usage=usage, + ) + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + if len(texts) == 0: + return 0 + total_num_tokens = 0 + for text in texts: + total_num_tokens += self._get_num_tokens_by_gpt2(text) + + return total_num_tokens + + def validate_credentials(self, model: str, credentials: Mapping) -> None: + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] + try: + BaiduAccessToken.get_access_token(api_key, secret_key) + except Exception as e: + raise CredentialsValidateFailedError(f"Credentials validation failed: {e}") + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return invoke_error_mapping() + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=total_tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at, + ) + + return usage diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.py b/api/core/model_runtime/model_providers/wenxin/wenxin.py index 04845d06bcf1bc..895af20bc8541d 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class WenxinProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ @@ -19,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: model_instance = self.get_model_instance(ModelType.LLM) # Use `ernie-bot` model for validate, - model_instance.validate_credentials( - model='ernie-bot', - credentials=credentials - ) + model_instance.validate_credentials(model="ernie-bot", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml index b3a1f608249bbf..d8acfd8120a954 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml @@ -17,6 +17,8 @@ help: en_US: https://cloud.baidu.com/wenxin.html supported_model_types: - llm + - text-embedding + - rerank configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py new file mode 100644 index 00000000000000..bd074e047717ac --- /dev/null +++ 
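# --- Illustrative aside (not part of the patch) ---
# How the yaml pricing above is usually read: input '0.0005' RMB with unit
# '0.001' means 0.0005 RMB per 1,000 tokens. This arithmetic sketch assumes
# the common tokens * unit_price * price_unit convention; in the patch the
# actual numbers come from get_price and are carried on EmbeddingUsage.

from decimal import Decimal

tokens = 1536
unit_price = Decimal("0.0005")  # RMB per pricing unit, from the yaml
price_unit = Decimal("0.001")   # i.e. priced per 1,000 tokens

total_price = tokens * unit_price * price_unit
print(total_price)  # 0.0007680 RMB for 1536 input tokens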
b/api/core/model_runtime/model_providers/wenxin/wenxin_errors.py @@ -0,0 +1,54 @@ +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [], + InvokeServerUnavailableError: [InternalServerError], + InvokeRateLimitError: [RateLimitReachedError], + InvokeAuthorizationError: [ + InvalidAuthenticationError, + InsufficientAccountBalanceError, + InvalidAPIKeyError, + ], + InvokeBadRequestError: [BadRequestError, KeyError], + } + + +class InvalidAuthenticationError(Exception): + pass + + +class InvalidAPIKeyError(Exception): + pass + + +class RateLimitReachedError(Exception): + pass + + +class InsufficientAccountBalanceError(Exception): + pass + + +class InternalServerError(Exception): + pass + + +class BadRequestError(Exception): + pass diff --git a/api/core/model_runtime/model_providers/x/__init__.py b/api/core/model_runtime/model_providers/x/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg new file mode 100644 index 00000000000000..f8b745cb13defc --- /dev/null +++ b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/llm/__init__.py b/api/core/model_runtime/model_providers/x/llm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml new file mode 100644 index 00000000000000..7c305735b99e33 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml @@ -0,0 +1,63 @@ +model: grok-beta +label: + en_US: Grok beta +model_type: llm +features: + - multi-tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. 
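# --- Illustrative aside (not part of the patch) ---
# A sketch of how a mapping like invoke_error_mapping() is typically
# consumed: a provider-specific exception is translated into the unified
# InvokeError subclass whose bucket lists its type. The translate() helper
# and the two-entry mapping below are hypothetical; the real lookup lives
# in the runtime's base model classes.

class InvokeError(Exception): ...
class InvokeRateLimitError(InvokeError): ...
class RateLimitReachedError(Exception): ...

MAPPING: dict[type[InvokeError], list[type[Exception]]] = {
    InvokeRateLimitError: [RateLimitReachedError],
}

def translate(err: Exception) -> InvokeError:
    for unified, provider_errors in MAPPING.items():
        if any(isinstance(err, t) for t in provider_errors):
            return unified(str(err))
    return InvokeError(str(err))

print(type(translate(RateLimitReachedError("qps limit"))).__name__)  # InvokeRateLimitError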
It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." + zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py new file mode 100644 index 00000000000000..3f5325a857dc92 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/llm.py @@ -0,0 +1,37 @@ +from collections.abc import Generator +from typing import Optional, Union + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1" + credentials["mode"] = LLMMode.CHAT.value + credentials["function_calling_type"] = "tool_call" diff --git a/api/core/model_runtime/model_providers/x/x.py b/api/core/model_runtime/model_providers/x/x.py new file mode 100644 index 00000000000000..e3f2b8eeba3ead --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.py @@ -0,0 +1,25 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class XAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
+ """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials(model="grok-beta", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/x/x.yaml b/api/core/model_runtime/model_providers/x/x.yaml new file mode 100644 index 00000000000000..90d1cbfe7e6983 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.yaml @@ -0,0 +1,38 @@ +provider: x +label: + en_US: xAI +description: + en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe. +icon_small: + en_US: x-ai-logo.svg +icon_large: + en_US: x-ai-logo.svg +help: + title: + en_US: Get your token from xAI + zh_Hans: 从 xAI 获取 token + url: + en_US: https://x.ai/api +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: endpoint_url + label: + en_US: API Base + type: text-input + required: false + default: https://api.x.ai/v1 + placeholder: + zh_Hans: 在此输入您的 API Base + en_US: Enter your API Base diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 988bb0ce4432df..b82f0430c5f782 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Iterator -from typing import cast +from typing import Optional, cast from openai import ( APIConnectionError, @@ -19,7 +19,6 @@ from openai.types.completion import Completion from xinference_client.client.restful.restful_client import ( Client, - RESTfulChatglmCppChatModelHandle, RESTfulChatModelHandle, RESTfulGenerateModelHandle, ) @@ -60,91 +59,115 @@ from core.model_runtime.model_providers.xinference.xinference_helper import ( XinferenceHelper, XinferenceModelExtraParameter, + validate_model_uid, ) from core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - invoke LLM + invoke LLM - see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ - if 'temperature' in model_parameters: - if model_parameters['temperature'] < 0.01: - model_parameters['temperature'] = 0.01 - elif model_parameters['temperature'] > 1.0: - model_parameters['temperature'] = 0.99 + if "temperature" in model_parameters: + if model_parameters["temperature"] < 
0.01: + model_parameters["temperature"] = 0.01 + elif model_parameters["temperature"] > 1.0: + model_parameters["temperature"] = 0.99 return self._generate( - model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, - tools=tools, stop=stop, stream=stream, user=user, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] - ) + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), + ), ) def validate_credentials(self, model: str, credentials: dict) -> None: """ - validate credentials - - credentials should be like: - { - 'model_type': 'text-generation', - 'server_url': 'server url', - 'model_uid': 'model uid', - } + validate credentials + + credentials should be like: + { + 'model_type': 'text-generation', + 'server_url': 'server url', + 'model_uid': 'model uid', + } """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'completion_type' not in credentials: - if 'chat' in extra_param.model_ability: - credentials['completion_type'] = 'chat' - elif 'generate' in extra_param.model_ability: - credentials['completion_type'] = 'completion' + if "completion_type" not in credentials: + if "chat" in extra_param.model_ability: + credentials["completion_type"] = "chat" + elif "generate" in extra_param.model_ability: + credentials["completion_type"] = "completion" else: raise ValueError( - f'xinference model ability {extra_param.model_ability} is not supported, check if you have the right model type') + f"xinference model ability {extra_param.model_ability} is not supported," + f" check if you have the right model type" + ) if extra_param.support_function_call: - credentials['support_function_call'] = True + credentials["support_function_call"] = True if extra_param.support_vision: - credentials['support_vision'] = True + credentials["support_vision"] = True if extra_param.context_length: - credentials['context_length'] = extra_param.context_length + credentials["context_length"] = extra_param.context_length except RuntimeError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except KeyError as e: - raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') + raise CredentialsValidateFailedError(f"Xinference credentials validate failed: {e}") except Exception as e: raise e - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool] | None = None, + ) -> int: """ - get number of tokens + get number of tokens - cause 
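# --- Illustrative aside (not part of the patch) ---
# The refactor above replaces an inline check with validate_model_uid();
# judging by the old condition and the unchanged error message, a model_uid
# is valid only if it contains none of '/', '?' or '#', presumably because
# the uid is interpolated into request URLs. A standalone approximation:

def validate_model_uid(credentials: dict) -> bool:
    return not any(ch in credentials.get("model_uid", "") for ch in "/?#")

print(validate_model_uid({"model_uid": "qwen2-instruct"}))  # True
print(validate_model_uid({"model_uid": "bad/uid"}))         # False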
XinferenceAI LLM is a customized model, we could net detect which tokenizer to use - so we just take the GPT2 tokenizer as default + because XinferenceAI LLM is a customized model, we could not detect which tokenizer to use + so we just take the GPT2 tokenizer as default """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], - is_completion_model: bool = False) -> int: + def _num_tokens_from_messages( + self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False + ) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -160,10 +183,10 @@ def tokens(text: str): num_tokens += tokens_per_message for key, value in message.items(): if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -215,30 +238,30 @@ def tokens(text: str): num_tokens = 0 for tool in tools: # calculate num tokens for function object - num_tokens += tokens('name') + num_tokens += tokens("name") num_tokens += tokens(tool.name) - num_tokens += tokens('description') + num_tokens += tokens("description") num_tokens += tokens(tool.description) parameters = tool.parameters - num_tokens += tokens('parameters') - num_tokens += tokens('type') + num_tokens += tokens("parameters") + num_tokens += tokens("type") num_tokens += tokens(parameters.get("type")) - if 'properties' in parameters: - num_tokens += tokens('properties') - for key, value in parameters.get('properties').items(): + if "properties" in parameters: + num_tokens += tokens("properties") + for key, value in parameters.get("properties").items(): num_tokens += tokens(key) for field_key, field_value in value.items(): num_tokens += tokens(field_key) - if field_key == 'enum': + if field_key == "enum": for enum_field in field_value: num_tokens += 3 num_tokens += tokens(enum_field) else: num_tokens += tokens(field_key) num_tokens += tokens(str(field_value)) - if 'required' in parameters: - num_tokens += tokens('required') - for required_field in parameters['required']: + if "required" in parameters: + num_tokens += tokens("required") + for required_field in parameters["required"]: num_tokens += 3 num_tokens += tokens(required_field) @@ -246,18 +269,14 @@ def tokens(text: str): def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str: """ - convert prompt message to text + convert prompt message to text """ - text = '' + text = "" for item in message: - if isinstance(item, UserPromptMessage): - text += item.content - elif isinstance(item, SystemPromptMessage): - text += item.content - elif isinstance(item, AssistantPromptMessage): + if isinstance(item, UserPromptMessage | SystemPromptMessage | AssistantPromptMessage): text += item.content else: - raise NotImplementedError(f'PromptMessage type {type(item)} is not supported') + raise NotImplementedError(f"PromptMessage type {type(item)} is not supported") return text def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: @@ -273,19 +292,13 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: for message_content in message.content: if message_content.type == PromptMessageContentType.TEXT: message_content = cast(PromptMessageContent, message_content) - sub_message_dict = { - "type": "text", - "text":
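# --- Illustrative aside (not part of the patch) ---
# The accounting above charges a fixed per-message overhead plus the tokens
# of every key and value, flattening multimodal content lists down to their
# text parts first. A toy version with a whitespace "tokenizer" standing in
# for _get_num_tokens_by_gpt2; the real per-message constant lives in the
# surrounding method.

def num_tokens(messages: list[dict], tokens_per_message: int = 3) -> int:
    def tokens(text: str) -> int:
        return len(text.split())  # stand-in for a GPT-2 token count

    total = 0
    for message in messages:
        total += tokens_per_message
        for key, value in message.items():
            if isinstance(value, list):  # multimodal content: keep text items only
                value = "".join(item["text"] for item in value
                                if isinstance(item, dict) and item.get("type") == "text")
            total += tokens(key) + tokens(str(value))
    return total

print(num_tokens([{"role": "user", "content": "hello there"}]))  # 8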
message_content.data - } + sub_message_dict = {"type": "text", "text": message_content.data} sub_messages.append(sub_message_dict) elif message_content.type == PromptMessageContentType.IMAGE: message_content = cast(ImagePromptMessageContent, message_content) sub_message_dict = { "type": "image_url", - "image_url": { - "url": message_content.data, - "detail": message_content.detail.value - } + "image_url": {"url": message_content.data, "detail": message_content.detail.value}, } sub_messages.append(sub_message_dict) message_dict = {"role": "user", "content": sub_messages} @@ -295,7 +308,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: if message.tool_calls and len(message.tool_calls) > 0: message_dict["function_call"] = { "name": message.tool_calls[0].function.name, - "arguments": message.tool_calls[0].function.arguments + "arguments": message.tool_calls[0].function.arguments, } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) @@ -308,152 +321,147 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: return message_dict - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ rules = [ ParameterRule( - name='temperature', + name="temperature", type=ParameterType.FLOAT, - use_template='temperature', - label=I18nObject( - zh_Hans='温度', - en_US='Temperature' - ), + use_template="temperature", + label=I18nObject(zh_Hans="温度", en_US="Temperature"), ), ParameterRule( - name='top_p', + name="top_p", type=ParameterType.FLOAT, - use_template='top_p', - label=I18nObject( - zh_Hans='Top P', - en_US='Top P' - ) + use_template="top_p", + label=I18nObject(zh_Hans="Top P", en_US="Top P"), ), ParameterRule( - name='max_tokens', + name="max_tokens", type=ParameterType.INT, - use_template='max_tokens', + use_template="max_tokens", min=1, - max=credentials.get('context_length', 2048), + max=credentials.get("context_length", 2048), default=512, - label=I18nObject( - zh_Hans='最大生成长度', - en_US='Max Tokens' - ) + label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"), ), ParameterRule( name=DefaultParameterName.PRESENCE_PENALTY, use_template=DefaultParameterName.PRESENCE_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Presence Penalty', - zh_Hans='存在惩罚', + en_US="Presence Penalty", + zh_Hans="存在惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they ' - 'appear in the text so far, increasing the model\'s likelihood to talk about new topics.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚,从而增加模型谈论新话题的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they " + "appear in the text so far, increasing the model's likelihood to talk about new topics.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词是否已出现在文本中对其进行惩罚," + "从而增加模型谈论新话题的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 + precision=2, ), ParameterRule( name=DefaultParameterName.FREQUENCY_PENALTY, use_template=DefaultParameterName.FREQUENCY_PENALTY, type=ParameterType.FLOAT, label=I18nObject( - en_US='Frequency Penalty', - zh_Hans='频率惩罚', + en_US="Frequency Penalty", + zh_Hans="频率惩罚", ), required=False, help=I18nObject( - en_US='Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their ' - 'existing frequency in the text so far, decreasing the model\'s likelihood to repeat the ' - 'same line verbatim.', - zh_Hans='介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚,从而降低模型逐字重复相同内容的可能性。' + en_US="Number between -2.0 and 2.0. Positive values penalize new tokens based on their " + "existing frequency in the text so far, decreasing the model's likelihood to repeat the " + "same line verbatim.", + zh_Hans="介于 -2.0 和 2.0 之间的数字。正值会根据新词在文本中的现有频率对其进行惩罚," + "从而降低模型逐字重复相同内容的可能性。", ), default=0.0, min=-2.0, max=2.0, - precision=2 - ) + precision=2, + ), ] completion_type = None - if 'completion_type' in credentials: - if credentials['completion_type'] == 'chat': + if "completion_type" in credentials: + if credentials["completion_type"] == "chat": completion_type = LLMMode.CHAT.value - elif credentials['completion_type'] == 'completion': + elif credentials["completion_type"] == "completion": completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') else: extra_args = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'chat' in extra_args.model_ability: + if "chat" in extra_args.model_ability: completion_type = LLMMode.CHAT.value - elif 'generate' in extra_args.model_ability: + elif "generate" in extra_args.model_ability: completion_type = LLMMode.COMPLETION.value else: - raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') + raise ValueError(f"xinference model ability {extra_args.model_ability} is not supported") features = [] - support_function_call = credentials.get('support_function_call', False) + support_function_call = credentials.get("support_function_call", False) if support_function_call: features.append(ModelFeature.TOOL_CALL) - support_vision = credentials.get('support_vision', False) + support_vision = credentials.get("support_vision", False) if support_vision: features.append(ModelFeature.VISION) - context_length = credentials.get('context_length', 2048) + context_length = credentials.get("context_length", 2048) entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, features=features, - model_properties={ - ModelPropertyKey.MODE: completion_type, - ModelPropertyKey.CONTEXT_SIZE: context_length - }, - parameter_rules=rules + model_properties={ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length}, + parameter_rules=rules, ) return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ - -> LLMResult | Generator: + def _generate( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + extra_model_kwargs: XinferenceModelExtraParameter, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: """ - generate text from LLM + generate text from LLM - see 
`core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` + see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._generate` - extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` + extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` """ - if 'server_url' not in credentials: - raise CredentialsValidateFailedError('server_url is required in credentials') + if "server_url" not in credentials: + raise CredentialsValidateFailedError("server_url is required in credentials") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") - api_key = credentials.get('api_key') or "abc" + api_key = credentials.get("api_key") or "abc" client = OpenAI( base_url=f'{credentials["server_url"]}/v1', @@ -463,33 +471,29 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM ) xinference_client = Client( - base_url=credentials['server_url'], + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_model = xinference_client.get_model(credentials['model_uid']) + xinference_model = xinference_client.get_model(credentials["model_uid"]) generate_config = { - 'temperature': model_parameters.get('temperature', 1.0), - 'top_p': model_parameters.get('top_p', 0.7), - 'max_tokens': model_parameters.get('max_tokens', 512), - 'presence_penalty': model_parameters.get('presence_penalty', 0.0), - 'frequency_penalty': model_parameters.get('frequency_penalty', 0.0), + "temperature": model_parameters.get("temperature", 1.0), + "top_p": model_parameters.get("top_p", 0.7), + "max_tokens": model_parameters.get("max_tokens", 512), + "presence_penalty": model_parameters.get("presence_penalty", 0.0), + "frequency_penalty": model_parameters.get("frequency_penalty", 0.0), } if stop: - generate_config['stop'] = stop + generate_config["stop"] = stop if tools and len(tools) > 0: - generate_config['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] - vision = credentials.get('support_vision', False) - if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): + generate_config["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] + vision = credentials.get("support_vision", False) + if isinstance(xinference_model, RESTfulChatModelHandle): resp = client.chat.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], stream=stream, user=user, @@ -497,34 +501,34 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM ) if stream: if tools and len(tools) > 0: - raise InvokeBadRequestError('xinference tool calls does not support stream mode') - return self._handle_chat_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_chat_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + raise InvokeBadRequestError("xinference tool calls does not support stream mode") + return self._handle_chat_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return 
self._handle_chat_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) elif isinstance(xinference_model, RESTfulGenerateModelHandle): resp = client.completions.create( - model=credentials['model_uid'], + model=credentials["model_uid"], prompt=self._convert_prompt_message_to_text(prompt_messages), stream=stream, user=user, **generate_config, ) if stream: - return self._handle_completion_stream_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) - return self._handle_completion_generate_response(model=model, credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, resp=resp) + return self._handle_completion_stream_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) + return self._handle_completion_generate_response( + model=model, credentials=credentials, prompt_messages=prompt_messages, tools=tools, resp=resp + ) else: - raise NotImplementedError(f'xinference model handle type {type(xinference_model)} is not supported') + raise NotImplementedError(f"xinference model handle type {type(xinference_model)} is not supported") - def _extract_response_tool_calls(self, - response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \ - -> list[AssistantPromptMessage.ToolCall]: + def _extract_response_tool_calls( + self, response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall] + ) -> list[AssistantPromptMessage.ToolCall]: """ Extract tool calls from response @@ -535,21 +539,19 @@ def _extract_response_tool_calls(self, if response_tool_calls: for response_tool_call in response_tool_calls: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_tool_call.function.name, - arguments=response_tool_call.function.arguments + name=response_tool_call.function.name, arguments=response_tool_call.function.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_tool_call.id, - type=response_tool_call.type, - function=function + id=response_tool_call.id, type=response_tool_call.type, function=function ) tool_calls.append(tool_call) return tool_calls - def _extract_response_function_call(self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall) \ - -> AssistantPromptMessage.ToolCall: + def _extract_response_function_call( + self, response_function_call: FunctionCall | ChoiceDeltaFunctionCall + ) -> AssistantPromptMessage.ToolCall: """ Extract function call from response @@ -559,23 +561,25 @@ def _extract_response_function_call(self, response_function_call: FunctionCall | tool_call = None if response_function_call: function = AssistantPromptMessage.ToolCall.ToolCallFunction( - name=response_function_call.name, - arguments=response_function_call.arguments + name=response_function_call.name, arguments=response_function_call.arguments ) tool_call = AssistantPromptMessage.ToolCall( - id=response_function_call.name, - type="function", - function=function + id=response_function_call.name, type="function", function=function ) return tool_call - def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: ChatCompletion) -> LLMResult: + def _handle_chat_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: ChatCompletion, + ) -> LLMResult: """ - handle normal chat 
generate response + handle normal chat generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -584,22 +588,22 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_m # convert tool call to assistant message tool call tool_calls = assistant_message.tool_calls - assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else []) + assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls or []) function_call = assistant_message.function_call if function_call: assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)] # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=assistant_message.content, - tool_calls=assistant_prompt_message_tool_calls + content=assistant_message.content, tool_calls=assistant_prompt_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools) - usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens + ) response = LLMResult( model=model, @@ -611,13 +615,18 @@ def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_m return response - def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[ChatCompletionChunk]) -> Generator: + def _handle_chat_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[ChatCompletionChunk], + ) -> Generator: """ - handle stream chat generate response + handle stream chat generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -625,7 +634,7 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue # check if there is a tool call in the response @@ -642,32 +651,31 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_message_tool_calls + content=delta.delta.content or "", tool_calls=assistant_message_tool_calls ) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=assistant_message_tool_calls + content=full_response, tool_calls=assistant_message_tool_calls ) prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, 
completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: @@ -683,11 +691,16 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes full_response += delta.delta.content - def _handle_completion_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Completion) -> LLMResult: + def _handle_completion_generate_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Completion, + ) -> LLMResult: """ - handle normal completion generate response + handle normal completion generate response """ if len(resp.choices) == 0: raise InvokeServerUnavailableError("Empty response") @@ -695,14 +708,9 @@ def _handle_completion_generate_response(self, model: str, credentials: dict, pr assistant_message = resp.choices[0].text # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=assistant_message, - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=assistant_message, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) completion_tokens = self._num_tokens_from_messages( messages=[assistant_prompt_message], tools=[], is_completion_model=True ) @@ -720,13 +728,18 @@ def _handle_completion_generate_response(self, model: str, credentials: dict, pr return response - def _handle_completion_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - resp: Iterator[Completion]) -> Generator: + def _handle_completion_stream_response( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + resp: Iterator[Completion], + ) -> Generator: """ - handle stream completion generate response + handle stream completion generate response """ - full_response = '' + full_response = "" for chunk in resp: if len(chunk.choices) == 0: @@ -735,40 +748,33 @@ def _handle_completion_stream_response(self, model: str, credentials: dict, prom delta = chunk.choices[0] # transform assistant message to prompt message - assistant_prompt_message = AssistantPromptMessage( - content=delta.text if delta.text else '', - tool_calls=[] - ) + assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[]) if delta.finish_reason is not None: # temp_assistant_prompt_message is used to calculate usage - temp_assistant_prompt_message = AssistantPromptMessage( - content=full_response, - tool_calls=[] - ) + temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) - prompt_tokens = self._get_num_tokens_by_gpt2( - self._convert_prompt_message_to_text(prompt_messages) - ) + prompt_tokens = self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages)) 
completion_tokens = self._num_tokens_from_messages( messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True ) - usage = self._calc_response_usage(model=model, credentials=credentials, - prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) + usage = self._calc_response_usage( + model=model, + credentials=credentials, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) yield LLMResultChunk( model=model, prompt_messages=prompt_messages, system_fingerprint=chunk.system_fingerprint, delta=LLMResultChunkDelta( - index=0, - message=assistant_prompt_message, - finish_reason=delta.finish_reason, - usage=usage + index=0, message=assistant_prompt_message, finish_reason=delta.finish_reason, usage=usage ), ) else: - if delta.text is None or delta.text == '': + if delta.text is None or delta.text == "": continue yield LLMResultChunk( @@ -803,15 +809,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] ConflictError, NotFoundError, UnprocessableEntityError, - PermissionDeniedError - ], - InvokeRateLimitError: [ - RateLimitError - ], - InvokeAuthorizationError: [ - AuthenticationError + PermissionDeniedError, ], - InvokeBadRequestError: [ - ValueError - ] + InvokeRateLimitError: [RateLimitError], + InvokeAuthorizationError: [AuthenticationError], + InvokeBadRequestError: [ValueError], } diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 4e7543fd996fd7..efaf114854b5c1 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -15,6 +15,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid class XinferenceRerankModel(RerankModel): @@ -22,10 +23,16 @@ class XinferenceRerankModel(RerankModel): Model class for Xinference rerank model. 
""" - def _invoke(self, model: str, credentials: dict, - query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, - user: Optional[str] = None) \ - -> RerankResult: + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: """ Invoke rerank model @@ -39,24 +46,15 @@ def _invoke(self, model: str, credentials: dict, :return: rerank result """ if len(docs) == 0: - return RerankResult( - model=model, - docs=[] - ) + return RerankResult(model=model, docs=[]) - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): - server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} - - params = { - 'documents': docs, - 'query': query, - 'top_n': top_n, - 'return_documents': True - } + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + server_url = server_url.removesuffix("/") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + + params = {"documents": docs, "query": query, "top_n": top_n, "return_documents": True} try: handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers) response = handle.rerank(**params) @@ -69,27 +67,21 @@ def _invoke(self, model: str, credentials: dict, response = handle.rerank(**params) rerank_documents = [] - for idx, result in enumerate(response['results']): + for idx, result in enumerate(response["results"]): # format document - index = result['index'] - page_content = result['document'] if isinstance(result['document'], str) else result['document']['text'] + index = result["index"] + page_content = result["document"] if isinstance(result["document"], str) else result["document"]["text"] rerank_document = RerankDocument( index=index, text=page_content, - score=result['relevance_score'], + score=result["relevance_score"], ) # score threshold check - if score_threshold is not None: - if result['relevance_score'] >= score_threshold: - rerank_documents.append(rerank_document) - else: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) - return RerankResult( - model=model, - docs=rerank_documents - ) + return RerankResult(model=model, docs=rerank_documents) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -100,33 +92,34 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulRerankModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a rerank model') + "please check model type, the model you want to invoke is not a rerank model" + ) self.invoke( model=model, credentials=credentials, query="Whose kasumi", docs=[ - "Kasumi is a girl's name of Japanese origin meaning \"mist\".", + 'Kasumi is a girl\'s name of Japanese origin meaning "mist".', "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ", - "and she leads a team named PopiParty." + "and she leads a team named PopiParty.", ], - score_threshold=0.8 + score_threshold=0.8, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -142,53 +135,38 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity class RESTfulRerankModelHandleWithoutExtraParameter(RESTfulRerankModelHandle): - def rerank( - self, - documents: list[str], - query: str, - top_n: Optional[int] = None, - max_chunks_per_doc: Optional[int] = None, - return_documents: Optional[bool] = None, - **kwargs + self, + documents: list[str], + query: str, + top_n: Optional[int] = None, + max_chunks_per_doc: Optional[int] = None, + return_documents: Optional[bool] = None, + **kwargs, ): url = f"{self._base_url}/v1/rerank" request_body = { @@ -204,8 +182,6 @@ def rerank( response = requests.post(url, json=request_body, headers=self.auth_headers) if response.status_code != 200: - raise InvokeServerUnavailableError( - f"Failed to rerank documents, detail: {response.json()['detail']}" - ) + raise InvokeServerUnavailableError(f"Failed to rerank documents, 
detail: {response.json()['detail']}") response_data = response.json() return response_data diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 9ee36213176ef7..3d7aefeb6dd89a 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -14,6 +14,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid class XinferenceSpeech2TextModel(Speech2TextModel): @@ -21,9 +22,7 @@ class XinferenceSpeech2TextModel(Speech2TextModel): Model class for Xinference speech to text model. """ - def _invoke(self, model: str, credentials: dict, - file: IO[bytes], user: Optional[str] = None) \ - -> str: + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: """ Invoke speech2text model @@ -44,26 +43,27 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials["server_url"], + api_key=credentials.get("api_key"), ) - xinference_client = client.get_model(model_uid=credentials['model_uid']) + xinference_client = client.get_model(model_uid=credentials["model_uid"]) if not isinstance(xinference_client, RESTfulAudioModelHandle): raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a audio model') + "please check model type, the model you want to invoke is not an audio model" + ) audio_file_path = self._get_demo_file_path() - with open(audio_file_path, 'rb') as audio_file: + with open(audio_file_path, "rb") as audio_file: self.invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -79,23 +79,11 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def _speech2text_invoke( @@ -113,48 +101,45 @@ def _speech2text_invoke( :param model: model name :param credentials: model credentials - :param file: The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpe g,mpga, m4a, ogg, wav, or webm.
+ :param file: The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, + mpga, m4a, ogg, wav, or webm. :param language: The language of the input audio. Supplying the input language in ISO-639-1 :param prompt: An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language. - :param response_format: The format of the transcript output, in one of these options: json, text, srt, verbose _json, or vtt. - :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit. + :param response_format: The format of the transcript output, in one of these options: json, text, srt, + verbose_json, or vtt. + :param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more + random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model will use + log probability to automatically increase the temperature until certain thresholds are hit. :return: text for given audio file """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): - server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + server_url = server_url.removesuffix("/") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulAudioModelHandle(model_uid, server_url, auth_headers) response = handle.transcriptions( - audio=file, - language=language, - prompt=prompt, - response_format=response_format, - temperature=temperature + audio=file, language=language, prompt=prompt, response_format=response_format, temperature=temperature ) except RuntimeError as e: raise InvokeServerUnavailableError(str(e)) return response["text"] - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.SPEECH2TEXT, - model_properties={ }, - parameter_rules=[] + model_properties={}, + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 11f1e29cb39f81..e51e6a941c5413 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -3,6 +3,7 @@ from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, 
ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -16,16 +17,22 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ Model class for Xinference text embedding model. """ - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -39,14 +46,14 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - api_key = credentials.get('api_key') - if server_url.endswith('/'): - server_url = server_url[:-1] - auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + server_url = server_url.removesuffix("/") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers) @@ -70,13 +77,11 @@ class EmbeddingData(TypedDict): embedding: List[float] """ - usage = embeddings['usage'] - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens']) + usage = embeddings["usage"] + usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"]) result = TextEmbeddingResult( - model=model, - embeddings=[embedding['embedding'] for embedding in embeddings['data']], - usage=usage + model=model, embeddings=[embedding["embedding"] for embedding in embeddings["data"]], usage=usage ) return result @@ -105,19 +110,26 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - server_url = credentials['server_url'] - model_uid = credentials['model_uid'] - extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + server_url = credentials["server_url"] + model_uid = credentials["model_uid"] + api_key = credentials.get("api_key") + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=server_url, + model_uid=model_uid, + api_key=api_key, + ) if extra_args.max_tokens: - credentials['max_tokens'] = extra_args.max_tokens - if server_url.endswith('/'): - server_url = server_url[:-1] + credentials["max_tokens"] = extra_args.max_tokens + server_url = server_url.removesuffix("/") - client = Client(base_url=server_url) + client = Client( + base_url=server_url, + api_key=api_key, + ) try: handle = client.get_model(model_uid=model_uid) @@ -125,32 +137,24 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise InvokeAuthorizationError(e) if not isinstance(handle, RESTfulEmbeddingModelHandle): - raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model') + raise InvokeBadRequestError( + "please check model type, the model you want to invoke is not a text embedding model" + ) - self._invoke(model=model, credentials=credentials, texts=['ping']) + self._invoke(model=model, credentials=credentials, texts=["ping"]) except InvokeAuthorizationError as e: - raise CredentialsValidateFailedError(f'Failed to validate credentials for model {model}: {e}') + raise CredentialsValidateFailedError(f"Failed to validate credentials for model {model}: {e}") except RuntimeError as e: raise CredentialsValidateFailedError(e) @property def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - KeyError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [KeyError], } def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: @@ -164,10 +168,7 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -178,28 +179,26 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to 
define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, model_properties={ ModelPropertyKey.MAX_CHUNKS: 1, - ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512, + ModelPropertyKey.CONTEXT_SIZE: "max_tokens" in credentials and credentials["max_tokens"] or 512, }, - parameter_rules=[] + parameter_rules=[], ) return entity diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index a564a021b19615..ad7b64efb5d2e7 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -1,5 +1,5 @@ import concurrent.futures -from typing import Optional +from typing import Any, Optional from xinference_client.client.restful.restful_client import RESTfulAudioModelHandle @@ -15,95 +15,94 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel -from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid class XinferenceText2SpeechModel(TTSModel): - def __init__(self): # preset voices, need support custom voice self.model_voices = { - '__default': { - 'all': [ - {'name': 'Default', 'value': 'default'}, + "__default": { + "all": [ + {"name": "Default", "value": "default"}, ] }, - 'ChatTTS': { - 'all': [ - {'name': 'Alloy', 'value': 'alloy'}, - {'name': 'Echo', 'value': 'echo'}, - {'name': 'Fable', 'value': 'fable'}, - {'name': 'Onyx', 'value': 'onyx'}, - {'name': 'Nova', 'value': 'nova'}, - {'name': 'Shimmer', 'value': 'shimmer'}, + "ChatTTS": { + "all": [ + {"name": "Alloy", "value": "alloy"}, + {"name": "Echo", "value": "echo"}, + {"name": "Fable", "value": "fable"}, + {"name": "Onyx", "value": "onyx"}, + {"name": "Nova", "value": "nova"}, + {"name": "Shimmer", "value": "shimmer"}, ] }, - 'CosyVoice': { - 'zh-Hans': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "CosyVoice": { + "zh-Hans": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'zh-Hant': [ - {'name': '中文男', 'value': '中文男'}, - {'name': '中文女', 'value': '中文女'}, - {'name': '粤语女', 'value': '粤语女'}, + "zh-Hant": [ + {"name": "中文男", "value": "中文男"}, + {"name": "中文女", "value": "中文女"}, + {"name": "粤语女", "value": "粤语女"}, ], - 'en-US': [ - {'name': '英文男', 'value': '英文男'}, - {'name': '英文女', 'value': '英文女'}, + "en-US": [ + {"name": "英文男", "value": "英文男"}, + {"name": "英文女", "value": "英文女"}, ], - 'ja-JP': [ - {'name': '日语男', 'value': '日语男'}, + "ja-JP": [ + {"name": "日语男", "value": "日语男"}, ], - 'ko-KR': [ - {'name': '韩语女', 'value': '韩语女'}, - ] - } + "ko-KR": [ + {"name": "韩语女", "value": "韩语女"}, + ], + }, } def validate_credentials(self, model: str, credentials: dict) -> None: """ - Validate model credentials + Validate model credentials - :param model: model name - :param credentials: model credentials - :return: - """ + :param model: model name + :param credentials: model credentials + :return: + """ try: - if ("/" in credentials['model_uid'] or - "?" 
in credentials['model_uid'] or - "#" in credentials['model_uid']): + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") extra_param = XinferenceHelper.get_xinference_extra_parameter( - server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + server_url=credentials["server_url"], + model_uid=credentials["model_uid"], + api_key=credentials.get("api_key"), ) - if 'text-to-audio' not in extra_param.model_ability: + if "text-to-audio" not in extra_param.model_ability: raise InvokeBadRequestError( - 'please check model type, the model you want to invoke is not a text-to-audio model') + "please check model type, the model you want to invoke is not a text-to-audio model" + ) if extra_param.model_family and extra_param.model_family in self.model_voices: - credentials['audio_model_name'] = extra_param.model_family + credentials["audio_model_name"] = extra_param.model_family else: - credentials['audio_model_name'] = '__default' + credentials["audio_model_name"] = "__default" self._tts_invoke_streaming( model=model, credentials=credentials, - content_text='Hello Dify!', + content_text="Hello Dify!", voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, - user: Optional[str] = None): + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ): """ _invoke text2speech model @@ -117,20 +116,18 @@ def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: s """ return self._tts_invoke_streaming(model, credentials, content_text, voice) - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ - used to define customizable model schema + used to define customizable model schema """ entity = AIModelEntity( model=model, - label=I18nObject( - en_US=model - ), + label=I18nObject(en_US=model), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TTS, model_properties={}, - parameter_rules=[] + parameter_rules=[], ) return entity @@ -146,37 +143,30 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] :return: Invoke error mapping """ return { - InvokeConnectionError: [ - InvokeConnectionError - ], - InvokeServerUnavailableError: [ - InvokeServerUnavailableError - ], - InvokeRateLimitError: [ - InvokeRateLimitError - ], - InvokeAuthorizationError: [ - InvokeAuthorizationError - ], - InvokeBadRequestError: [ - InvokeBadRequestError, - KeyError, - ValueError - ] + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError], } def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: - audio_model_name = credentials.get('audio_model_name', '__default') + audio_model_name = credentials.get("audio_model_name", "__default") for key, voices in 
self.model_voices.items(): if key in audio_model_name: if language and language in voices: return voices[language] - elif 'all' in voices: - return voices['all'] + elif "all" in voices: + return voices["all"] + else: + all_voices = [] + for lang, lang_voices in voices.items(): + all_voices.extend(lang_voices) + return all_voices - return self.model_voices['__default']['all'] + return self.model_voices["__default"]["all"] - def _get_model_default_voice(self, model: str, credentials: dict) -> any: + def _get_model_default_voice(self, model: str, credentials: dict) -> Any: return "" def _get_model_word_limit(self, model: str, credentials: dict) -> int: @@ -188,8 +178,7 @@ def _get_model_audio_type(self, model: str, credentials: dict) -> str: def _get_model_workers_limit(self, model: str, credentials: dict) -> int: return 5 - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, - voice: str) -> any: + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: """ _tts_invoke_streaming text2speech model @@ -199,44 +188,41 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str :param voice: model timbre :return: text translated to audio file """ - if credentials['server_url'].endswith('/'): - credentials['server_url'] = credentials['server_url'][:-1] + credentials["server_url"] = credentials["server_url"].removesuffix("/") try: - handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) + api_key = credentials.get("api_key") + auth_headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + handle = RESTfulAudioModelHandle( + credentials["model_uid"], credentials["server_url"], auth_headers=auth_headers + ) - model_support_voice = [x.get("value") for x in - self.get_tts_model_voices(model=model, credentials=credentials)] + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] if not voice or voice not in model_support_voice: voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) if len(content_text) > word_limit: sentences = self._split_text_into_sentences(content_text, max_length=word_limit) executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) - futures = [executor.submit( - handle.speech, - input=sentences[i], - voice=voice, - response_format="mp3", - speed=1.0, - stream=False - ) - for i in range(len(sentences))] + futures = [ + executor.submit( + handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=True + ) + for i in range(len(sentences)) + ] - for index, future in enumerate(futures): + for future in futures: response = future.result() - for i in range(0, len(response), 1024): - yield response[i:i + 1024] + for chunk in response: + yield chunk else: response = handle.speech( - input=content_text.strip(), - voice=voice, - response_format="mp3", - speed=1.0, - stream=False + input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=True ) - for i in range(0, len(response), 1024): - yield response[i:i + 1024] + for chunk in response: + yield chunk except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 7db483a485ee1c..baa3ccbe8adbc0 100644 --- 
a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -18,9 +18,17 @@ class XinferenceModelExtraParameter: support_vision: bool = False model_family: Optional[str] - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], - support_function_call: bool, support_vision: bool, max_tokens: int, context_length: int, - model_family: Optional[str]) -> None: + def __init__( + self, + model_format: str, + model_handle_type: str, + model_ability: list[str], + support_function_call: bool, + support_vision: bool, + max_tokens: int, + context_length: int, + model_family: Optional[str], + ) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability @@ -30,82 +38,89 @@ def __init__(self, model_format: str, model_handle_type: str, model_ability: lis self.context_length = context_length self.model_family = model_family + cache = {} cache_lock = Lock() + class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: cache[model_uid] = { - 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) + "expires": time() + 300, + "value": XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key), } - return cache[model_uid]['value'] + return cache[model_uid]["value"] @staticmethod def _clean_cache() -> None: try: with cache_lock: - expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()] + expired_keys = [model_uid for model_uid, model in cache.items() if model["expires"] < time()] for model_uid in expired_keys: del cache[model_uid] except RuntimeError as e: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ - get xinference model extra parameter like model_format and model_handle_type + get xinference model extra parameter like model_format and model_handle_type """ if not model_uid or not model_uid.strip() or not server_url or not server_url.strip(): - raise RuntimeError('model_uid is empty') + raise RuntimeError("model_uid is empty") - url = str(URL(server_url) / 'v1' / 'models' / model_uid) + url = str(URL(server_url) / "v1" / "models" / model_uid) - # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 + # this method is surrounded by a lock, and default requests may hang forever, + # so we just set a Adapter with max_retries=3 session = Session() - session.mount('http://', HTTPAdapter(max_retries=3)) - session.mount('https://', HTTPAdapter(max_retries=3)) + session.mount("http://", HTTPAdapter(max_retries=3)) + session.mount("https://", HTTPAdapter(max_retries=3)) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: - response = session.get(url, timeout=10) + response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: - raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') + 
raise RuntimeError(f"get xinference model extra parameter failed, url: {url}, error: {e}") if response.status_code != 200: - raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}') + raise RuntimeError( + f"get xinference model extra parameter failed, status code: {response.status_code}," + f" response: {response.text}" + ) response_json = response.json() - model_format = response_json.get('model_format', 'ggmlv3') - model_ability = response_json.get('model_ability', []) - model_family = response_json.get('model_family', None) + model_format = response_json.get("model_format", "ggmlv3") + model_ability = response_json.get("model_ability", []) + model_family = response_json.get("model_family", None) - if response_json.get('model_type') == 'embedding': - model_handle_type = 'embedding' - elif response_json.get('model_type') == 'audio': - model_handle_type = 'audio' - if model_family and model_family in ['ChatTTS', 'CosyVoice']: - model_ability.append('text-to-audio') + if response_json.get("model_type") == "embedding": + model_handle_type = "embedding" + elif response_json.get("model_type") == "audio": + model_handle_type = "audio" + if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}: + model_ability.append("text-to-audio") else: - model_ability.append('audio-to-text') - elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: - model_handle_type = 'chatglm' - elif 'generate' in model_ability: - model_handle_type = 'generate' - elif 'chat' in model_ability: - model_handle_type = 'chat' + model_ability.append("audio-to-text") + elif model_format == "ggmlv3" and "chatglm" in response_json["model_name"]: + model_handle_type = "chatglm" + elif "generate" in model_ability: + model_handle_type = "generate" + elif "chat" in model_ability: + model_handle_type = "chat" else: - raise NotImplementedError('xinference model handle type is not supported') + raise NotImplementedError("xinference model handle type is not supported") - support_function_call = 'tools' in model_ability - support_vision = 'vision' in model_ability - max_tokens = response_json.get('max_tokens', 512) + support_function_call = "tools" in model_ability + support_vision = "vision" in model_ability + max_tokens = response_json.get("max_tokens", 512) - context_length = response_json.get('context_length', 2048) + context_length = response_json.get("context_length", 2048) return XinferenceModelExtraParameter( model_format=model_format, @@ -115,5 +130,18 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen support_vision=support_vision, max_tokens=max_tokens, context_length=context_length, - model_family=model_family + model_family=model_family, ) + + +def validate_model_uid(credentials: dict) -> bool: + """ + Validate the model_uid within the credentials dictionary to ensure it does not + contain forbidden characters ("/", "?", "#"). + + param credentials: model credentials + :return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False. 
+ """ + forbidden_characters = ["/", "?", "#"] + model_uid = credentials.get("model_uid", "") + return not any(char in forbidden_characters for char in model_uid) diff --git a/api/core/model_runtime/model_providers/yi/llm/_position.yaml b/api/core/model_runtime/model_providers/yi/llm/_position.yaml index e876893b414985..5fa098beda1c0d 100644 --- a/api/core/model_runtime/model_providers/yi/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/yi/llm/_position.yaml @@ -7,3 +7,4 @@ - yi-medium-200k - yi-spark - yi-large-turbo +- yi-lightning diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index d33f38333be9e7..0642e72ed500e1 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -4,21 +4,37 @@ import tiktoken -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult from core.model_runtime.entities.message_entities import ( PromptMessage, PromptMessageTool, SystemPromptMessage, ) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel class YiLargeLanguageModel(OpenAILargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) # yi-vl-plus not support system prompt yet. @@ -27,7 +43,9 @@ def _invoke(self, model: str, credentials: dict, for message in prompt_messages: if not isinstance(message, SystemPromptMessage): prompt_message_except_system.append(message) - return super()._invoke(model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream) + return super()._invoke( + model, credentials, prompt_message_except_system, model_parameters, tools, stop, stream + ) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -36,8 +54,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: super().validate_credentials(model, credentials) # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: """ Calculate num tokens for text completion model with tiktoken package. 
@@ -55,8 +72,9 @@ def _num_tokens_from_string(self, model: str, text: str, return num_tokens # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def _num_tokens_from_messages( + self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ @@ -76,10 +94,10 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], # which need to download the image and then get the resolution for calculation, # and will increase the request delay if isinstance(value, list): - text = '' + text = "" for item in value: - if isinstance(item, dict) and item['type'] == 'text': - text += item['text'] + if isinstance(item, dict) and item["type"] == "text": + text += item["text"] value = text @@ -110,10 +128,65 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], @staticmethod def _add_custom_parameters(credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] - if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.lingyiwanwu.com' + credentials["mode"] = "chat" + credentials["openai_api_key"] = credentials["api_key"] + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["openai_api_base"] = "https://api.lingyiwanwu.com" else: - parsed_url = urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + parsed_url = urlparse(credentials["endpoint_url"]) + credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model, zh_Hans=model), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + default=512, + min=1, + max=int(credentials.get("max_tokens", 8192)), + label=I18nObject( + en_US="Max Tokens", zh_Hans="指定生成结果长度的上限。如果生成结果截断,可以调大该参数" + ), + type=ParameterType.INT, + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject( + en_US="Top P", + zh_Hans="控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。", + ), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="top_k", + use_template="top_k", + label=I18nObject(en_US="Top K", zh_Hans="取样数量"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="frequency_penalty", + use_template="frequency_penalty", + label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"), + type=ParameterType.FLOAT, + ), + ], + ) diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml 
b/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml new file mode 100644 index 00000000000000..fccf1b3a264018 --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml @@ -0,0 +1,43 @@ +model: yi-lightning +label: + zh_Hans: yi-lightning + en_US: yi-lightning +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 4000 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. +pricing: + input: '0.99' + output: '0.99' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/yi/yi.py b/api/core/model_runtime/model_providers/yi/yi.py index 691c7aa3711beb..9599acb22b505a 100644 --- a/api/core/model_runtime/model_providers/yi/yi.py +++ b/api/core/model_runtime/model_providers/yi/yi.py @@ -8,7 +8,6 @@ class YiProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `yi-34b-chat-0205` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='yi-34b-chat-0205', - credentials=credentials - ) + model_instance.validate_credentials(model="yi-34b-chat-0205", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/yi/yi.yaml b/api/core/model_runtime/model_providers/yi/yi.yaml index de741afb10990c..393526c31e3168 100644 --- a/api/core/model_runtime/model_providers/yi/yi.yaml +++ b/api/core/model_runtime/model_providers/yi/yi.yaml @@ -20,6 +20,7 @@ supported_model_types: - llm configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: api_key @@ -39,3 +40,57 @@ provider_credential_schema: placeholder: zh_Hans: Base URL, e.g. https://api.lingyiwanwu.com/v1 en_US: Base URL, e.g. 
https://api.lingyiwanwu.com/v1 +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + show_on: + - variable: __model_type + value: llm + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - value: function_call + label: + en_US: Support + zh_Hans: 支持 + show_on: + - variable: __model_type + value: llm diff --git a/api/core/model_runtime/model_providers/zhinao/llm/llm.py b/api/core/model_runtime/model_providers/zhinao/llm/llm.py index 6930a5ed0134b0..befc3de021e1f2 100644 --- a/api/core/model_runtime/model_providers/zhinao/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhinao/llm/llm.py @@ -7,11 +7,17 @@ class ZhinaoLargeLanguageModel(OAIAPICompatLargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) @@ -21,5 +27,5 @@ def validate_credentials(self, model: str, credentials: dict) -> None: @classmethod def _add_custom_parameters(cls, credentials: dict) -> None: - credentials['mode'] = 'chat' - credentials['endpoint_url'] = 'https://api.360.cn/v1' + credentials["mode"] = "chat" + credentials["endpoint_url"] = "https://api.360.cn/v1" diff --git a/api/core/model_runtime/model_providers/zhinao/zhinao.py b/api/core/model_runtime/model_providers/zhinao/zhinao.py index 44b36c9f51edd7..2a263292f98f14 100644 --- a/api/core/model_runtime/model_providers/zhinao/zhinao.py +++ b/api/core/model_runtime/model_providers/zhinao/zhinao.py @@ -8,7 +8,6 @@ class ZhinaoProvider(ModelProvider): - def validate_provider_credentials(self, credentials: dict) -> None: """ Validate provider credentials @@ -21,12 +20,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: # Use `360gpt-turbo` model for validate, # no matter what model you pass in, text completion model or chat model - model_instance.validate_credentials( - model='360gpt-turbo', - credentials=credentials - ) + model_instance.validate_credentials(model="360gpt-turbo", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + 
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 3412d8100f8e4c..fa95232f717d78 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -17,8 +17,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['api_key'] if 'api_key' in credentials else - credentials.get("zhipuai_api_key"), + "api_key": credentials["api_key"] if "api_key" in credentials else credentials.get("zhipuai_api_key"), } return credentials_kwargs @@ -38,5 +37,5 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]] InvokeServerUnavailableError: [], InvokeRateLimitError: [], InvokeAuthorizationError: [], - InvokeBadRequestError: [] + InvokeBadRequestError: [], } diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/chatglm_turbo.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/chatglm_turbo.yaml index 8f51f80967748f..049fae6c1676b9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/chatglm_turbo.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/chatglm_turbo.yaml @@ -19,15 +19,15 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. 
+ default: true - name: return_type label: zh_Hans: 回复类型 @@ -40,3 +40,4 @@ parameter_rules: options: - text - json_string +deprecated: true diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml index 8391278e4f1ea3..7c8da51d1b82ff 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml @@ -23,20 +23,29 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. pricing: input: '0.1' output: '0.1' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml index 7caebd3e4b6aa8..7a7b4b0892785e 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml @@ -23,20 +23,29 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. 
The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. pricing: input: '0.001' output: '0.001' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml index dc123913deb8b5..09ad842801eb9d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml @@ -23,20 +23,29 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. 
- required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. pricing: input: '0.01' output: '0.01' diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml index 1b1d499ba7383c..aee82a0602a995 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml @@ -23,22 +23,31 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. 
When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. pricing: - input: '0.0001' - output: '0.0001' + input: '0' + output: '0' unit: '0.001' currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml new file mode 100644 index 00000000000000..40ff7609c7a2e2 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml @@ -0,0 +1,53 @@ +model: glm-4-flashx +label: + en_US: glm-4-flashx +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.7 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: do_sample + label: + zh_Hans: 采样策略 + en_US: Sampling strategy + type: boolean + help: + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 4095 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
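[Reviewer note] The `web_search` switch added across these model YAMLs is not sent to ZhipuAI as a top-level parameter. As the `llm.py` hunk later in this diff shows ("zhipuai moved web_search param to tools"), it is popped from `model_parameters` and re-expressed as a `tools` entry. A standalone sketch of that transformation, equivalent to the diff's if/else via `setdefault`:

```python
# Mirrors the web_search -> tools rewrite done in llm.py further down.
model_parameters = {"temperature": 0.95, "web_search": True}

enable_web_search = model_parameters.pop("web_search")
web_search_params = {"type": "web_search", "web_search": {"enable": enable_web_search}}
model_parameters.setdefault("tools", []).append(web_search_params)

print(model_parameters)
# {'temperature': 0.95, 'tools': [{'type': 'web_search', 'web_search': {'enable': True}}]}
```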
+pricing: + input: '0' + output: '0' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml index 5bdb4428403908..791a77ba157cf7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml @@ -23,17 +23,31 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true - name: max_tokens use_template: max_tokens default: 1024 min: 1 max: 8192 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml index 6b5bcc5bcf4468..13ed1e49c99a2a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml @@ -23,17 +23,31 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. 
For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 4095 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. +pricing: + input: '0.1' + output: '0.1' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml index 9d92e58f6cdff1..badcee22db77b1 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml @@ -26,8 +26,31 @@ parameter_rules: help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: do_sample + label: + zh_Hans: 采样策略 + en_US: Sampling strategy + type: boolean + help: + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. 
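[Reviewer note] A hedged reading of the `pricing` blocks added in these YAMLs, consistent with how this model runtime computes usage elsewhere: `input`/`output` are prices per `unit` of tokens (so `unit: '0.001'` means the price applies per 1,000 tokens), giving cost = tokens × unit × price in `currency`. Worked example for glm-4 (`input: '0.1'`):

```python
# Assumption: unit is a per-token multiplier, as in the runtime's price calc.
from decimal import Decimal

def estimate_cost(tokens: int, price: str, unit: str) -> Decimal:
    return Decimal(tokens) * Decimal(unit) * Decimal(price)

print(estimate_cost(2000, "0.1", "0.001"))  # 0.2000 RMB for 2,000 input tokens
```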
+ default: true - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 4096 + max: 4095 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. +pricing: + input: '0.001' + output: '0.001' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml new file mode 100644 index 00000000000000..e2f785e1bc4383 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml @@ -0,0 +1,53 @@ +model: glm-4-plus +label: + en_US: glm-4-plus +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call +model_properties: + mode: chat +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.7 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: do_sample + label: + zh_Hans: 采样策略 + en_US: Sampling strategy + type: boolean + help: + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 4095 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. 
This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. +pricing: + input: '0.05' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml index ddea331c8e46f3..3baa298300a8e1 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml @@ -17,21 +17,35 @@ parameter_rules: en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - name: top_p use_template: top_p - default: 0.7 + default: 0.6 help: zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. - - name: incremental + - name: do_sample label: - zh_Hans: 增量返回 - en_US: Incremental + zh_Hans: 采样策略 + en_US: Sampling strategy type: boolean help: - zh_Hans: SSE接口调用时,用于控制每次返回内容方式是增量还是全量,不提供此参数时默认为增量返回,true 为增量返回,false 为全量返回。 - en_US: When the SSE interface is called, it is used to control whether the content is returned incrementally or in full. If this parameter is not provided, the default is incremental return. true means incremental return, false means full return. - required: false + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true - name: max_tokens use_template: max_tokens default: 1024 min: 1 - max: 8192 + max: 1024 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
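[Reviewer note] The vision models in this diff (glm-4v and, below, glm-4v-plus) accept image and video content as data URLs; the `_remove_base64_header` helper added in `llm.py` further down strips the `data:...;base64,` prefix before sending, replacing the image-only `_remove_image_header`. An equivalent standalone sketch:

```python
# Same stripping as _remove_base64_header in llm.py below.
def remove_base64_header(file_content: str) -> str:
    if file_content.startswith("data:"):
        return file_content.split(";base64,", 1)[1]
    return file_content

print(remove_base64_header("data:image/png;base64,iVBORw0KGgo="))  # iVBORw0KGgo=
```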
+pricing: + input: '0.05' + output: '0.05' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml new file mode 100644 index 00000000000000..91550ceee8e6fb --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml @@ -0,0 +1,51 @@ +model: glm-4v-plus +label: + en_US: glm-4v-plus +model_type: llm +model_properties: + mode: chat +features: + - vision +parameter_rules: + - name: temperature + use_template: temperature + default: 0.95 + min: 0.0 + max: 1.0 + help: + zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: top_p + use_template: top_p + default: 0.6 + help: + zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。 + en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time. + - name: do_sample + label: + zh_Hans: 采样策略 + en_US: Sampling strategy + type: boolean + help: + zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。 + en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true. + default: true + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 1024 + - name: web_search + type: boolean + label: + zh_Hans: 联网搜索 + en_US: Web Search + default: false + help: + zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。 + en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic. 
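[Reviewer note] One behavioral subtlety for glm-4v-plus: the `_generate` hunk in `llm.py` below drops stop sequences for this model because, per the inline comment, requests with stop words always come back with `"finish_reason": "network_error"`. A sketch mirroring that guard:

```python
# Mirrors the stop-word guard added in llm.py's _generate.
from typing import Optional

def build_extra_kwargs(model: str, stop: Optional[list[str]]) -> dict:
    extra: dict = {}
    if stop and model != "glm-4v-plus":  # glm-4v-plus rejects stop words
        extra["stop"] = stop
    return extra

print(build_extra_kwargs("glm-4", ["Observation:"]))         # {'stop': ['Observation:']}
print(build_extra_kwargs("glm-4v-plus", ["Observation:"]))   # {}
```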
+pricing: + input: '0.01' + output: '0.01' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index ff971964a8603e..eddb94aba35a93 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,6 +1,10 @@ from collections.abc import Generator from typing import Optional, Union +from zhipuai import ZhipuAI +from zhipuai.types.chat.chat_completion import Completion +from zhipuai.types.chat.chat_completion_chunk import ChatCompletionChunk + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -16,9 +20,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk from core.model_runtime.utils import helper GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object. @@ -31,16 +32,21 @@ {{instructions}} -```JSON""" +```JSON""" # noqa: E501 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): - - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, - stream: bool = True, user: Optional[str] = None) \ - -> Union[LLMResult, Generator]: + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -62,9 +68,9 @@ def _invoke(self, model: str, credentials: dict, # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) - # def _transform_json_prompts(self, model: str, credentials: dict, - # prompt_messages: list[PromptMessage], model_parameters: dict, - # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, # stream: bool = True, user: str | None = None) \ # -> None: # """ @@ -94,8 +100,13 @@ def _invoke(self, model: str, credentials: dict, # content="```JSON\n" # )) - def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> int: + def get_num_tokens( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None, + ) -> int: """ Get number of tokens for given 
prompt messages @@ -130,16 +141,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: "temperature": 0.5, }, tools=[], - stream=False + stream=False, ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _generate(self, model: str, credentials_kwargs: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, stream: bool = True, - user: Optional[str] = None) -> Union[LLMResult, Generator]: + def _generate( + self, + model: str, + credentials_kwargs: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -153,15 +170,14 @@ def _generate(self, model: str, credentials_kwargs: dict, :return: full response or stream response chunk generator result """ extra_model_kwargs = {} - if stop: - extra_model_kwargs['stop'] = stop + # request to glm-4v-plus with stop words will always response "finish_reason":"network_error" + if stop and model != "glm-4v-plus": + extra_model_kwargs["stop"] = stop - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) if len(prompt_messages) == 0: - raise ValueError('At least one message is required') + raise ValueError("At least one message is required") if prompt_messages[0].role == PromptMessageRole.SYSTEM: if not prompt_messages[0].content: @@ -171,13 +187,13 @@ def _generate(self, model: str, credentials_kwargs: dict, new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: copy_prompt_message = prompt_message.copy() - if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model != 'glm-4v': + if model not in {"glm-4v", "glm-4v-plus"}: # not support list message continue - # get image and + # get image and if not isinstance(copy_prompt_message, UserPromptMessage): # not support system message continue @@ -187,13 +203,14 @@ def _generate(self, model: str, credentials_kwargs: dict, # not support image message continue - if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ - copy_prompt_message.role == PromptMessageRole.USER: + if ( + new_prompt_messages + and new_prompt_messages[-1].role == PromptMessageRole.USER + and copy_prompt_message.role == PromptMessageRole.USER + ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: - if copy_prompt_message.role == PromptMessageRole.USER: - new_prompt_messages.append(copy_prompt_message) - elif copy_prompt_message.role == PromptMessageRole.TOOL: + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}: new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.SYSTEM: new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) @@ -207,77 +224,76 @@ def _generate(self, model: str, credentials_kwargs: dict, else: new_prompt_messages.append(copy_prompt_message) - if model == 'glm-4v': + # zhipuai moved web_search param to tools + if "web_search" in model_parameters: + enable_web_search = 
model_parameters.get("web_search") + model_parameters.pop("web_search") + web_search_params = {"type": "web_search", "web_search": {"enable": enable_web_search}} + if "tools" in model_parameters: + model_parameters["tools"].append(web_search_params) + else: + model_parameters["tools"] = [web_search_params] + + if model in {"glm-4v", "glm-4v-plus"}: params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: - params = { - 'model': model, - 'messages': [], - **model_parameters - } + params = {"model": model, "messages": [], **model_parameters} # glm model - if not model.startswith('chatglm'): - + if not model.startswith("chatglm"): for prompt_message in new_prompt_messages: if prompt_message.role == PromptMessageRole.TOOL: - params['messages'].append({ - 'role': 'tool', - 'content': prompt_message.content, - 'tool_call_id': prompt_message.tool_call_id - }) + params["messages"].append( + { + "role": "tool", + "content": prompt_message.content, + "tool_call_id": prompt_message.tool_call_id, + } + ) elif isinstance(prompt_message, AssistantPromptMessage): if prompt_message.tool_calls: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content, - 'tool_calls': [ - { - 'id': tool_call.id, - 'type': tool_call.type, - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments + params["messages"].append( + { + "role": "assistant", + "content": prompt_message.content, + "tool_calls": [ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls - ] - }) + for tool_call in prompt_message.tool_calls + ], + } + ) else: - params['messages'].append({ - 'role': 'assistant', - 'content': prompt_message.content - }) + params["messages"].append({"role": "assistant", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) else: # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if prompt_message.role == PromptMessageRole.SYSTEM or \ - prompt_message.role == PromptMessageRole.TOOL or \ - prompt_message.role == PromptMessageRole.USER: - if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': - params['messages'][-1]['content'] += "\n\n" + prompt_message.content + if prompt_message.role in { + PromptMessageRole.SYSTEM, + PromptMessageRole.TOOL, + PromptMessageRole.USER, + }: + if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": + params["messages"][-1]["content"] += "\n\n" + prompt_message.content else: - params['messages'].append({ - 'role': 'user', - 'content': prompt_message.content - }) + params["messages"].append({"role": "user", "content": prompt_message.content}) else: - params['messages'].append({ - 'role': prompt_message.role.value, - 'content': prompt_message.content - }) + params["messages"].append( + {"role": prompt_message.role.value, "content": prompt_message.content} + ) if tools and len(tools) > 0: - params['tools'] = [ - { - 'type': 'function', - 'function': helper.dump_model(tool) - } for tool in tools - ] + params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] if stream: response = client.chat.completions.create(stream=stream, 
**params, **extra_model_kwargs) @@ -286,47 +302,55 @@ def _generate(self, model: str, credentials_kwargs: dict, response = client.chat.completions.create(**params, **extra_model_kwargs) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) - def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], - model_parameters: dict): + def _construct_glm_4v_parameter(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict): messages = [ - { - 'role': message.role.value, - 'content': self._construct_glm_4v_messages(message.content) - } + {"role": message.role.value, "content": self._construct_glm_4v_messages(message.content)} for message in prompt_messages ] - params = { - 'model': model, - 'messages': messages, - **model_parameters - } + params = {"model": model, "messages": messages, **model_parameters} return params - def _construct_glm_4v_messages(self, prompt_message: Union[str | list[PromptMessageContent]]) -> list[dict]: - if isinstance(prompt_message, str): - return [{'type': 'text', 'text': prompt_message}] - - return [ - {'type': 'image_url', 'image_url': {'url': self._remove_image_header(item.data)}} - if item.type == PromptMessageContentType.IMAGE else - {'type': 'text', 'text': item.data} - - for item in prompt_message - ] - - def _remove_image_header(self, image: str) -> str: - if image.startswith('data:image'): - return image.split(',')[1] - - return image - - def _handle_generate_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - response: Completion, - prompt_messages: list[PromptMessage]) -> LLMResult: + def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]: + if isinstance(prompt_message, list): + sub_messages = [] + for item in prompt_message: + if item.type == PromptMessageContentType.IMAGE: + sub_messages.append( + { + "type": "image_url", + "image_url": {"url": self._remove_base64_header(item.data)}, + } + ) + elif item.type == PromptMessageContentType.VIDEO: + sub_messages.append( + { + "type": "video_url", + "video_url": {"url": self._remove_base64_header(item.data)}, + } + ) + else: + sub_messages.append({"type": "text", "text": item.data}) + return sub_messages + else: + return [{"type": "text", "text": prompt_message}] + + def _remove_base64_header(self, file_content: str) -> str: + if file_content.startswith("data:"): + data_split = file_content.split(";base64,") + return data_split[1] + + return file_content + + def _handle_generate_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + response: Completion, + prompt_messages: list[PromptMessage], + ) -> LLMResult: """ Handle llm response @@ -335,12 +359,12 @@ def _handle_generate_response(self, model: str, :param prompt_messages: prompt messages :return: llm response """ - text = '' + text = "" assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for choice in response.choices: if choice.message.tool_calls: for tool_call in choice.message.tool_calls: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -348,11 +372,11 @@ def _handle_generate_response(self, model: str, function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) - text += choice.message.content or '' + text += 
choice.message.content or "" prompt_usage = response.usage.prompt_tokens completion_usage = response.usage.completion_tokens @@ -364,20 +388,20 @@ def _handle_generate_response(self, model: str, result = LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=text, - tool_calls=assistant_tool_calls - ), + message=AssistantPromptMessage(content=text, tool_calls=assistant_tool_calls), usage=usage, ) return result - def _handle_generate_stream_response(self, model: str, - credentials: dict, - tools: Optional[list[PromptMessageTool]], - responses: Generator[ChatCompletionChunk, None, None], - prompt_messages: list[PromptMessage]) -> Generator: + def _handle_generate_stream_response( + self, + model: str, + credentials: dict, + tools: Optional[list[PromptMessageTool]], + responses: Generator[ChatCompletionChunk, None, None], + prompt_messages: list[PromptMessage], + ) -> Generator: """ Handle llm stream response @@ -386,19 +410,19 @@ def _handle_generate_stream_response(self, model: str, :param prompt_messages: prompt messages :return: llm response chunk generator result """ - full_assistant_content = '' + full_assistant_content = "" for chunk in responses: if len(chunk.choices) == 0: continue delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""): continue assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for tool_call in delta.delta.tool_calls or []: - if tool_call.type == 'function': + if tool_call.type == "function": assistant_tool_calls.append( AssistantPromptMessage.ToolCall( id=tool_call.id, @@ -406,17 +430,16 @@ def _handle_generate_stream_response(self, model: str, function=AssistantPromptMessage.ToolCall.ToolCallFunction( name=tool_call.function.name, arguments=tool_call.function.arguments, - ) + ), ) ) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( - content=delta.delta.content if delta.delta.content else '', - tool_calls=assistant_tool_calls + content=delta.delta.content or "", tool_calls=assistant_tool_calls ) - full_assistant_content += delta.delta.content if delta.delta.content else '' + full_assistant_content += delta.delta.content or "" if delta.finish_reason is not None and chunk.usage is not None: completion_tokens = chunk.usage.completion_tokens @@ -428,23 +451,22 @@ def _handle_generate_stream_response(self, model: str, yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason, - usage=usage - ) + usage=usage, + ), ) else: yield LLMResultChunk( model=chunk.model, prompt_messages=prompt_messages, - system_fingerprint='', + system_fingerprint="", delta=LLMResultChunkDelta( - index=delta.index, - message=assistant_prompt_message, - ) + index=delta.index, message=assistant_prompt_message, finish_reason=delta.finish_reason + ), ) def _convert_one_message_to_text(self, message: PromptMessage) -> str: @@ -462,27 +484,23 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str: message_text = f"{human_prompt} {content}" elif isinstance(message, AssistantPromptMessage): message_text = f"{ai_prompt} {content}" - elif isinstance(message, SystemPromptMessage): - message_text = content - elif isinstance(message, 
ToolPromptMessage): + elif isinstance(message, SystemPromptMessage | ToolPromptMessage): message_text = content else: raise ValueError(f"Got unknown type {message}") return message_text - def _convert_messages_to_prompt(self, messages: list[PromptMessage], - tools: Optional[list[PromptMessageTool]] = None) -> str: + def _convert_messages_to_prompt( + self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None + ) -> str: """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. """ messages = messages.copy() # don't mutate the original list - text = "".join( - self._convert_one_message_to_text(message) - for message in messages - ) + text = "".join(self._convert_one_message_to_text(message) for message in messages) if tools and len(tools) > 0: text += "\n\nTools:" diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 0f9fecfc72e69c..f629b62fd5385b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,12 +1,14 @@ import time from typing import Optional +from zhipuai import ZhipuAI + +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): @@ -14,9 +16,14 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Model class for ZhipuAI text embedding model. 
""" - def _invoke(self, model: str, credentials: dict, - texts: list[str], user: Optional[str] = None) \ - -> TextEmbeddingResult: + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: """ Invoke text embedding model @@ -24,19 +31,18 @@ def _invoke(self, model: str, credentials: dict, :param credentials: model credentials :param texts: texts to embed :param user: unique user id + :param input_type: input type :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) embeddings, embedding_used_tokens = self.embed_documents(model, client, texts) return TextEmbeddingResult( embeddings=embeddings, usage=self._calc_response_usage(model, credentials_kwargs, embedding_used_tokens), - model=model + model=model, ) def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: @@ -50,7 +56,7 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int """ if len(texts) == 0: return 0 - + total_num_tokens = 0 for text in texts: total_num_tokens += self._get_num_tokens_by_gpt2(text) @@ -68,15 +74,13 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuAI( - api_key=credentials_kwargs['api_key'] - ) + client = ZhipuAI(api_key=credentials_kwargs["api_key"]) # call embedding model self.embed_documents( model=model, client=client, - texts=['ping'], + texts=["ping"], ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) @@ -100,7 +104,7 @@ def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tupl embedding_used_tokens += response.usage.total_tokens return [list(map(float, e)) for e in embeddings], embedding_used_tokens - + def embed_query(self, text: str) -> list[float]: """Call out to ZhipuAI's embedding endpoint. @@ -111,8 +115,8 @@ def embed_query(self, text: str) -> list[float]: Embeddings for the text. 
""" return self.embed_documents([text])[0] - - def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> EmbeddingUsage: + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage @@ -122,10 +126,7 @@ def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> Emb """ # get input price info input_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=tokens + model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens ) # transform usage @@ -136,7 +137,7 @@ def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> Emb price_unit=input_price_info.unit, total_price=input_price_info.total_amount, currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at + latency=time.perf_counter() - self.started_at, ) return usage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py index c517d2dba5a2d1..e75aad6eb0eb53 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai.py @@ -19,12 +19,9 @@ def validate_provider_credentials(self, credentials: dict) -> None: try: model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials( - model='glm-4', - credentials=credentials - ) + model_instance.validate_credentials(model="glm-4", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: - logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") raise ex diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py deleted file mode 100644 index 4dcd03f5511b6f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ - -from .__version__ import __version__ -from ._client import ZhipuAI -from .core._errors import ( - APIAuthenticationError, - APIInternalError, - APIReachLimitError, - APIRequestFailedError, - APIResponseError, - APIResponseValidationError, - APIServerFlowExceedError, - APIStatusError, - APITimeoutError, - ZhipuAIError, -) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py deleted file mode 100644 index eb0ad332ca80af..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ /dev/null @@ -1,2 +0,0 @@ - -__version__ = 'v2.0.1' \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py deleted file mode 100644 index 6588d1dd684900..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -import os -from collections.abc import Mapping -from typing import Union - -import httpx -from httpx import Timeout -from typing_extensions import override - -from . 
import api_resource -from .core import _jwt_token -from .core._base_type import NOT_GIVEN, NotGiven -from .core._errors import ZhipuAIError -from .core._http_client import ZHIPUAI_DEFAULT_MAX_RETRIES, HttpClient - - -class ZhipuAI(HttpClient): - chat: api_resource.chat - api_key: str - - def __init__( - self, - *, - api_key: str | None = None, - base_url: str | httpx.URL | None = None, - timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, - max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES, - http_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None - ) -> None: - if api_key is None: - raise ZhipuAIError("No api_key provided, please provide it through parameters or environment variables") - self.api_key = api_key - - if base_url is None: - base_url = os.environ.get("ZHIPUAI_BASE_URL") - if base_url is None: - base_url = "https://open.bigmodel.cn/api/paas/v4" - from .__version__ import __version__ - super().__init__( - version=__version__, - base_url=base_url, - timeout=timeout, - custom_httpx_client=http_client, - custom_headers=custom_headers, - ) - self.chat = api_resource.chat.Chat(self) - self.images = api_resource.images.Images(self) - self.embeddings = api_resource.embeddings.Embeddings(self) - self.files = api_resource.files.Files(self) - self.fine_tuning = api_resource.fine_tuning.FineTuning(self) - - @property - @override - def _auth_headers(self) -> dict[str, str]: - api_key = self.api_key - return {"Authorization": f"{_jwt_token.generate_token(api_key)}"} - - def __del__(self) -> None: - if (not hasattr(self, "_has_custom_http_client") - or not hasattr(self, "close") - or not hasattr(self, "_client")): - # if the '__init__' method raised an error, self would not have client attr - return - - if self._has_custom_http_client: - return - - self.close() diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py deleted file mode 100644 index 0a90e21e48bcca..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .chat import chat -from .embeddings import Embeddings -from .files import Files -from .fine_tuning import fine_tuning -from .images import Images diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py deleted file mode 100644 index dab6dac5fe979c..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Literal, Optional, Union - -import httpx - -from ...core._base_api import BaseAPI -from ...core._base_type import NOT_GIVEN, Headers, NotGiven -from ...core._http_client import make_user_request_input -from ...types.chat.async_chat_completion import AsyncCompletion, AsyncTaskStatus - -if TYPE_CHECKING: - from ..._client import ZhipuAI - - -class AsyncCompletions(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - - def create( - self, - *, - model: str, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, 
- max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], list[list[int]], None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> AsyncTaskStatus: - _cast_type = AsyncTaskStatus - - if disable_strict_validation: - _cast_type = object - return self._post( - "/async/chat/completions", - body={ - "model": model, - "request_id": request_id, - "temperature": temperature, - "top_p": top_p, - "do_sample": do_sample, - "max_tokens": max_tokens, - "seed": seed, - "messages": messages, - "stop": stop, - "sensitive_word_check": sensitive_word_check, - "tools": tools, - "tool_choice": tool_choice, - }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), - cast_type=_cast_type, - enable_stream=False, - ) - - def retrieve_completion_result( - self, - id: str, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Union[AsyncCompletion, AsyncTaskStatus]: - _cast_type = Union[AsyncCompletion,AsyncTaskStatus] - if disable_strict_validation: - _cast_type = object - return self._get( - path=f"/async-result/{id}", - cast_type=_cast_type, - options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout - ) - ) - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py deleted file mode 100644 index 92362fc50a7252..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import TYPE_CHECKING - -from ...core._base_api import BaseAPI -from .async_completions import AsyncCompletions -from .completions import Completions - -if TYPE_CHECKING: - from ..._client import ZhipuAI - - -class Chat(BaseAPI): - completions: Completions - - def __init__(self, client: "ZhipuAI") -> None: - super().__init__(client) - self.completions = Completions(client) - self.asyncCompletions = AsyncCompletions(client) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py deleted file mode 100644 index 5c4ed4d1ba5922..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Literal, Optional, Union - -import httpx - -from ...core._base_api import BaseAPI -from ...core._base_type import NOT_GIVEN, Headers, NotGiven -from ...core._http_client import make_user_request_input -from ...core._sse_client import StreamResponse -from ...types.chat.chat_completion import Completion -from ...types.chat.chat_completion_chunk import ChatCompletionChunk - -if TYPE_CHECKING: - from ..._client import ZhipuAI - - -class Completions(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - model: str, - 
request_id: Optional[str] | NotGiven = NOT_GIVEN, - do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, - max_tokens: int | NotGiven = NOT_GIVEN, - seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, list[str], list[int], object, None], - stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - tools: Optional[object] | NotGiven = NOT_GIVEN, - tool_choice: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Completion | StreamResponse[ChatCompletionChunk]: - _cast_type = Completion - _stream_cls = StreamResponse[ChatCompletionChunk] - if disable_strict_validation: - _cast_type = object - _stream_cls = StreamResponse[object] - return self._post( - "/chat/completions", - body={ - "model": model, - "request_id": request_id, - "temperature": temperature, - "top_p": top_p, - "do_sample": do_sample, - "max_tokens": max_tokens, - "seed": seed, - "messages": messages, - "stop": stop, - "sensitive_word_check": sensitive_word_check, - "stream": stream, - "tools": tools, - "tool_choice": tool_choice, - }, - options=make_user_request_input( - extra_headers=extra_headers, - ), - cast_type=_cast_type, - enable_stream=stream or False, - stream_cls=_stream_cls, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py deleted file mode 100644 index 35d54592fd55c0..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ /dev/null @@ -1,49 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional, Union - -import httpx - -from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, Headers, NotGiven -from ..core._http_client import make_user_request_input -from ..types.embeddings import EmbeddingsResponded - -if TYPE_CHECKING: - from .._client import ZhipuAI - - -class Embeddings(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - input: Union[str, list[str], list[int], list[list[int]]], - model: Union[str], - encoding_format: str | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> EmbeddingsResponded: - _cast_type = EmbeddingsResponded - if disable_strict_validation: - _cast_type = object - return self._post( - "/embeddings", - body={ - "input": input, - "model": model, - "encoding_format": encoding_format, - "user": user, - "sensitive_word_check": sensitive_word_check, - }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), - cast_type=_cast_type, - enable_stream=False, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py deleted file mode 100644 index 5deb8d08f3405b..00000000000000 --- 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import httpx - -from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, FileTypes, Headers, NotGiven -from ..core._files import is_file_content -from ..core._http_client import make_user_request_input -from ..types.file_object import FileObject, ListOfFileObject - -if TYPE_CHECKING: - from .._client import ZhipuAI - -__all__ = ["Files"] - - -class Files(BaseAPI): - - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def create( - self, - *, - file: FileTypes, - purpose: str, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FileObject: - if not is_file_content(file): - prefix = f"Expected file input `{file!r}`" - raise RuntimeError( - f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead." - ) from None - files = [("file", file)] - - extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} - - return self._post( - "/files", - body={ - "purpose": purpose, - }, - files=files, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), - cast_type=FileObject, - ) - - def list( - self, - *, - purpose: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - after: str | NotGiven = NOT_GIVEN, - order: str | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ListOfFileObject: - return self._get( - "/files", - cast_type=ListOfFileObject, - options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout, - query={ - "purpose": purpose, - "limit": limit, - "after": after, - "order": order, - }, - ), - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py deleted file mode 100644 index dc54a9ca4567e3..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import TYPE_CHECKING - -from ...core._base_api import BaseAPI -from .jobs import Jobs - -if TYPE_CHECKING: - from ..._client import ZhipuAI - - -class FineTuning(BaseAPI): - jobs: Jobs - - def __init__(self, client: "ZhipuAI") -> None: - super().__init__(client) - self.jobs = Jobs(client) - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py deleted file mode 100644 index b860de192a612f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -import httpx - -from ...core._base_api import BaseAPI -from ...core._base_type import NOT_GIVEN, Headers, NotGiven -from ...core._http_client import make_user_request_input -from ...types.fine_tuning import FineTuningJob, FineTuningJobEvent, ListOfFineTuningJob, job_create_params - -if TYPE_CHECKING: - from ..._client import ZhipuAI - -__all__ = ["Jobs"] - - -class Jobs(BaseAPI): - - def __init__(self, client: ZhipuAI) -> None: - 
super().__init__(client) - - def create( - self, - *, - model: str, - training_file: str, - hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, - suffix: Optional[str] | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - validation_file: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTuningJob: - return self._post( - "/fine_tuning/jobs", - body={ - "model": model, - "training_file": training_file, - "hyperparameters": hyperparameters, - "suffix": suffix, - "validation_file": validation_file, - "request_id": request_id, - }, - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), - cast_type=FineTuningJob, - ) - - def retrieve( - self, - fine_tuning_job_id: str, - *, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTuningJob: - return self._get( - f"/fine_tuning/jobs/{fine_tuning_job_id}", - options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout - ), - cast_type=FineTuningJob, - ) - - def list( - self, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> ListOfFineTuningJob: - return self._get( - "/fine_tuning/jobs", - cast_type=ListOfFineTuningJob, - options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout, - query={ - "after": after, - "limit": limit, - }, - ), - ) - - def list_events( - self, - fine_tuning_job_id: str, - *, - after: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> FineTuningJobEvent: - - return self._get( - f"/fine_tuning/jobs/{fine_tuning_job_id}/events", - cast_type=FineTuningJobEvent, - options=make_user_request_input( - extra_headers=extra_headers, - timeout=timeout, - query={ - "after": after, - "limit": limit, - }, - ), - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py deleted file mode 100644 index 8eae1216d09e89..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -import httpx - -from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven -from ..core._http_client import make_user_request_input -from ..types.image import ImagesResponded - -if TYPE_CHECKING: - from .._client import ZhipuAI - - -class Images(BaseAPI): - def __init__(self, client: ZhipuAI) -> None: - super().__init__(client) - - def generations( - self, - *, - prompt: str, - model: str | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - quality: Optional[str] | NotGiven = NOT_GIVEN, - response_format: Optional[str] | NotGiven = NOT_GIVEN, - size: Optional[str] | NotGiven = NOT_GIVEN, - style: Optional[str] | NotGiven = NOT_GIVEN, - user: str | NotGiven = NOT_GIVEN, - request_id: Optional[str] | NotGiven = NOT_GIVEN, - extra_headers: Headers | None = None, - extra_body: Body | None = None, - disable_strict_validation: Optional[bool] | None = None, - timeout: float | httpx.Timeout | None | 
NotGiven = NOT_GIVEN, - ) -> ImagesResponded: - _cast_type = ImagesResponded - if disable_strict_validation: - _cast_type = object - return self._post( - "/images/generations", - body={ - "prompt": prompt, - "model": model, - "n": n, - "quality": quality, - "response_format": response_format, - "size": size, - "style": style, - "user": user, - "request_id": request_id, - }, - options=make_user_request_input( - extra_headers=extra_headers, - extra_body=extra_body, - timeout=timeout - ), - cast_type=_cast_type, - enable_stream=False, - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py deleted file mode 100644 index 10b46ff8e381a3..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .._client import ZhipuAI - - -class BaseAPI: - _client: ZhipuAI - - def __init__(self, client: ZhipuAI) -> None: - self._client = client - self._delete = client.delete - self._get = client.get - self._post = client.post - self._put = client.put - self._patch = client.patch diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py deleted file mode 100644 index b7cf6bb7fd06c4..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py +++ /dev/null @@ -1,106 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from os import PathLike -from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, Union - -import pydantic -from typing_extensions import override - -Query = Mapping[str, object] -Body = object -AnyMapping = Mapping[str, object] -PrimitiveData = Union[str, int, float, bool, None] -Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"] -ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) -_T = TypeVar("_T") - -if TYPE_CHECKING: - NoneType: type[None] -else: - NoneType = type(None) - - -# Sentinel class used until PEP 0661 is accepted -class NotGiven(pydantic.BaseModel): - """ - A sentinel singleton class used to distinguish omitted keyword arguments - from those passed in with the value None (which may have different behavior). - - For example: - - ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... - - get(timeout=1) # 1s timeout - get(timeout=None) # No timeout - get() # Default timeout behavior, which may not be statically known at the method definition. 
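The `NotGiven` sentinel above exists so the client can tell an omitted keyword apart from an explicit `None`, which disables a timeout rather than deferring to the default. A self-contained sketch of the same pattern, without the SDK's pydantic base class:

# Minimal sentinel: distinguishes an omitted kwarg from an explicit None.
from typing import Literal, Union


class NotGiven:
    def __bool__(self) -> Literal[False]:
        return False

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = NotGiven()


def get(timeout: Union[float, None, NotGiven] = NOT_GIVEN) -> str:
    if isinstance(timeout, NotGiven):
        return "default timeout behavior"
    if timeout is None:
        return "no timeout"
    return f"{timeout}s timeout"


print(get())      # default timeout behavior
print(get(None))  # no timeout
print(get(1))     # 1s timeout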
- ``` - """ - - def __bool__(self) -> Literal[False]: - return False - - @override - def __repr__(self) -> str: - return "NOT_GIVEN" - - -NotGivenOr = Union[_T, NotGiven] -NOT_GIVEN = NotGiven() - - -class Omit(pydantic.BaseModel): - """In certain situations you need to be able to represent a case where a default value has - to be explicitly removed and `None` is not an appropriate substitute, for example: - - ```py - # as the default `Content-Type` header is `application/json` that will be sent - client.post('/upload/files', files={'file': b'my raw file content'}) - - # you can't explicitly override the header as it has to be dynamically generated - # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' - client.post(..., headers={'Content-Type': 'multipart/form-data'}) - - # instead you can remove the default `application/json` header by passing Omit - client.post(..., headers={'Content-Type': Omit()}) - ``` - """ - - def __bool__(self) -> Literal[False]: - return False - - -Headers = Mapping[str, Union[str, Omit]] - -ResponseT = TypeVar( - "ResponseT", - bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", -) - -# for user input files -if TYPE_CHECKING: - FileContent = Union[IO[bytes], bytes, PathLike[str]] -else: - FileContent = Union[IO[bytes], bytes, PathLike] - -FileTypes = Union[ - FileContent, # file content - tuple[str, FileContent], # (filename, file) - tuple[str, FileContent, str], # (filename, file , content_type) - tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) -] - -RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]] - -# for httpx client supported files - -HttpxFileContent = Union[bytes, IO[bytes]] -HttpxFileTypes = Union[ - FileContent, # file content - tuple[str, HttpxFileContent], # (filename, file) - tuple[str, HttpxFileContent, str], # (filename, file , content_type) - tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) -] - -HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py deleted file mode 100644 index a2a438b8f3d355..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py +++ /dev/null @@ -1,90 +0,0 @@ -from __future__ import annotations - -import httpx - -__all__ = [ - "ZhipuAIError", - "APIStatusError", - "APIRequestFailedError", - "APIAuthenticationError", - "APIReachLimitError", - "APIInternalError", - "APIServerFlowExceedError", - "APIResponseError", - "APIResponseValidationError", - "APITimeoutError", -] - - -class ZhipuAIError(Exception): - def __init__(self, message: str, ) -> None: - super().__init__(message) - - -class APIStatusError(Exception): - response: httpx.Response - status_code: int - - def __init__(self, message: str, *, response: httpx.Response) -> None: - super().__init__(message) - self.response = response - self.status_code = response.status_code - - -class APIRequestFailedError(APIStatusError): - ... - - -class APIAuthenticationError(APIStatusError): - ... - - -class APIReachLimitError(APIStatusError): - ... - - -class APIInternalError(APIStatusError): - ... - - -class APIServerFlowExceedError(APIStatusError): - ... 
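These typed errors pair with the `_make_status_error` dispatch in `_http_client.py` further down, which maps HTTP 400/401/429/500/503 to specific subclasses. A compact sketch of that mapping, with plain `Exception` subclasses standing in for the httpx-backed classes above:

# Status-code -> typed-error dispatch, mirroring _make_status_error below.
class APIStatusError(Exception):
    pass


class APIRequestFailedError(APIStatusError): ...
class APIAuthenticationError(APIStatusError): ...
class APIReachLimitError(APIStatusError): ...
class APIInternalError(APIStatusError): ...
class APIServerFlowExceedError(APIStatusError): ...


_STATUS_TO_ERROR = {
    400: APIRequestFailedError,
    401: APIAuthenticationError,
    429: APIReachLimitError,
    500: APIInternalError,
    503: APIServerFlowExceedError,
}


def make_status_error(status_code: int, response_text: str) -> APIStatusError:
    message = f"Error code: {status_code}, with error text {response_text}"
    # unmapped codes fall back to the generic status error
    return _STATUS_TO_ERROR.get(status_code, APIStatusError)(message)


print(type(make_status_error(429, "rate limited")).__name__)  # APIReachLimitError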
- - -class APIResponseError(Exception): - message: str - request: httpx.Request - json_data: object - - def __init__(self, message: str, request: httpx.Request, json_data: object): - self.message = message - self.request = request - self.json_data = json_data - super().__init__(message) - - -class APIResponseValidationError(APIResponseError): - status_code: int - response: httpx.Response - - def __init__( - self, - response: httpx.Response, - json_data: object | None, *, - message: str | None = None - ) -> None: - super().__init__( - message=message or "Data returned by API invalid for expected schema.", - request=response.request, - json_data=json_data - ) - self.response = response - self.status_code = response.status_code - - -class APITimeoutError(Exception): - request: httpx.Request - - def __init__(self, request: httpx.Request): - self.request = request - super().__init__("Request Timeout") diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py deleted file mode 100644 index 0796bfe11cc658..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -import io -import os -from collections.abc import Mapping, Sequence -from pathlib import Path - -from ._base_type import FileTypes, HttpxFileTypes, HttpxRequestFiles, RequestFiles - - -def is_file_content(obj: object) -> bool: - return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike) - - -def _transform_file(file: FileTypes) -> HttpxFileTypes: - if is_file_content(file): - if isinstance(file, os.PathLike): - path = Path(file) - return path.name, path.read_bytes() - else: - return file - if isinstance(file, tuple): - if isinstance(file[1], os.PathLike): - return (file[0], Path(file[1]).read_bytes(), *file[2:]) - else: - return (file[0], file[1], *file[2:]) - else: - raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type") - - -def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: - if files is None: - return None - - if isinstance(files, Mapping): - files = {key: _transform_file(file) for key, file in files.items()} - elif isinstance(files, Sequence): - files = [(key, _transform_file(file)) for key, file in files] - else: - raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence") - return files diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py deleted file mode 100644 index 263fe829901c83..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ /dev/null @@ -1,377 +0,0 @@ -from __future__ import annotations - -import inspect -from collections.abc import Mapping -from typing import Any, Union, cast - -import httpx -import pydantic -from httpx import URL, Timeout -from tenacity import retry -from tenacity.stop import stop_after_attempt - -from . 
import _errors -from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT -from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError -from ._files import make_httpx_files -from ._request_opt import ClientRequestParam, UserRequestInput -from ._response import HttpResponse -from ._sse_client import StreamResponse -from ._utils import flatten - -headers = { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", -} - - -def _merge_map(map1: Mapping, map2: Mapping) -> Mapping: - merged = {**map1, **map2} - return {key: val for key, val in merged.items() if val is not None} - - -from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT - -ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0) -ZHIPUAI_DEFAULT_MAX_RETRIES = 3 -ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=5, max_keepalive_connections=5) - - -class HttpClient: - _client: httpx.Client - _version: str - _base_url: URL - - timeout: Union[float, Timeout, None] - _limits: httpx.Limits - _has_custom_http_client: bool - _default_stream_cls: type[StreamResponse[Any]] | None = None - - def __init__( - self, - *, - version: str, - base_url: URL, - timeout: Union[float, Timeout, None], - custom_httpx_client: httpx.Client | None = None, - custom_headers: Mapping[str, str] | None = None, - ) -> None: - if timeout is None or isinstance(timeout, NotGiven): - if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT: - timeout = custom_httpx_client.timeout - else: - timeout = ZHIPUAI_DEFAULT_TIMEOUT - self.timeout = cast(Timeout, timeout) - self._has_custom_http_client = bool(custom_httpx_client) - self._client = custom_httpx_client or httpx.Client( - base_url=base_url, - timeout=self.timeout, - limits=ZHIPUAI_DEFAULT_LIMITS, - ) - self._version = version - url = URL(url=base_url) - if not url.raw_path.endswith(b"/"): - url = url.copy_with(raw_path=url.raw_path + b"/") - self._base_url = url - self._custom_headers = custom_headers or {} - - def _prepare_url(self, url: str) -> URL: - - sub_url = URL(url) - if sub_url.is_relative_url: - request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/") - return self._base_url.copy_with(raw_path=request_raw_url) - - return sub_url - - @property - def _default_headers(self): - return \ - { - "Accept": "application/json", - "Content-Type": "application/json; charset=UTF-8", - "ZhipuAI-SDK-Ver": self._version, - "source_type": "zhipu-sdk-python", - "x-request-sdk": "zhipu-sdk-python", - **self._auth_headers, - **self._custom_headers, - } - - @property - def _auth_headers(self): - return {} - - def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers: - custom_headers = request_param.headers or {} - headers_dict = _merge_map(self._default_headers, custom_headers) - - httpx_headers = httpx.Headers(headers_dict) - - return httpx_headers - - def _prepare_request( - self, - request_param: ClientRequestParam - ) -> httpx.Request: - kwargs: dict[str, Any] = {} - json_data = request_param.json_data - headers = self._prepare_headers(request_param) - url = self._prepare_url(request_param.url) - json_data = request_param.json_data - if headers.get("Content-Type") == "multipart/form-data": - headers.pop("Content-Type") - - if json_data: - kwargs["data"] = self._make_multipartform(json_data) - - return self._client.build_request( - headers=headers, - timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else 
request_param.timeout, - method=request_param.method, - url=url, - json=json_data, - files=request_param.files, - params=request_param.params, - **kwargs, - ) - - def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]: - items = [] - - if isinstance(value, Mapping): - for k, v in value.items(): - items.extend(self._object_to_formfata(f"{key}[{k}]", v)) - return items - if isinstance(value, list | tuple): - for v in value: - items.extend(self._object_to_formfata(key + "[]", v)) - return items - - def _primitive_value_to_str(val) -> str: - # copied from httpx - if val is True: - return "true" - elif val is False: - return "false" - elif val is None: - return "" - return str(val) - - str_data = _primitive_value_to_str(value) - - if not str_data: - return [] - return [(key, str_data)] - - def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: - - items = flatten([self._object_to_formfata(k, v) for k, v in data.items()]) - - serialized: dict[str, object] = {} - for key, value in items: - if key in serialized: - raise ValueError(f"存在重复的键: {key};") - serialized[key] = value - return serialized - - def _parse_response( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - enable_stream: bool, - request_param: ClientRequestParam, - stream_cls: type[StreamResponse[Any]] | None = None, - ) -> HttpResponse: - - http_response = HttpResponse( - raw_response=response, - cast_type=cast_type, - client=self, - enable_stream=enable_stream, - stream_cls=stream_cls - ) - return http_response.parse() - - def _process_response_data( - self, - *, - data: object, - cast_type: type[ResponseT], - response: httpx.Response, - ) -> ResponseT: - if data is None: - return cast(ResponseT, None) - - try: - if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel): - return cast(ResponseT, cast_type.validate(data)) - - return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data)) - except pydantic.ValidationError as err: - raise APIResponseValidationError(response=response, json_data=data) from err - - def is_closed(self) -> bool: - return self._client.is_closed - - def close(self): - self._client.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - @retry(stop=stop_after_attempt(ZHIPUAI_DEFAULT_MAX_RETRIES)) - def request( - self, - *, - cast_type: type[ResponseT], - params: ClientRequestParam, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, - ) -> ResponseT | StreamResponse: - request = self._prepare_request(params) - - try: - response = self._client.send( - request, - stream=enable_stream, - ) - response.raise_for_status() - except httpx.TimeoutException as err: - raise APITimeoutError(request=request) from err - except httpx.HTTPStatusError as err: - err.response.read() - # raise err - raise self._make_status_error(err.response) from None - - except Exception as err: - raise err - - return self._parse_response( - cast_type=cast_type, - request_param=params, - response=response, - enable_stream=enable_stream, - stream_cls=stream_cls, - ) - - def get( - self, - path: str, - *, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - enable_stream: bool = False, - ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="get", url=path, **options) - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream - ) - - def post( - self, - 
path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, - ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, - **options) - - return self.request( - cast_type=cast_type, params=opts, - enable_stream=enable_stream, - stream_cls=stream_cls - ) - - def patch( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - ) -> ResponseT: - opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) - - return self.request( - cast_type=cast_type, params=opts, - ) - - def put( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - files: RequestFiles | None = None, - ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), - **options) - - return self.request( - cast_type=cast_type, params=opts, - ) - - def delete( - self, - path: str, - *, - body: Body | None = None, - cast_type: type[ResponseT], - options: UserRequestInput = {}, - ) -> ResponseT | StreamResponse: - opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) - - return self.request( - cast_type=cast_type, params=opts, - ) - - def _make_status_error(self, response) -> APIStatusError: - response_text = response.text.strip() - status_code = response.status_code - error_msg = f"Error code: {status_code}, with error text {response_text}" - - if status_code == 400: - return _errors.APIRequestFailedError(message=error_msg, response=response) - elif status_code == 401: - return _errors.APIAuthenticationError(message=error_msg, response=response) - elif status_code == 429: - return _errors.APIReachLimitError(message=error_msg, response=response) - elif status_code == 500: - return _errors.APIInternalError(message=error_msg, response=response) - elif status_code == 503: - return _errors.APIServerFlowExceedError(message=error_msg, response=response) - return APIStatusError(message=error_msg, response=response) - - -def make_user_request_input( - max_retries: int | None = None, - timeout: float | Timeout | None | NotGiven = NOT_GIVEN, - extra_headers: Headers = None, - extra_body: Body | None = None, - query: Query | None = None, -) -> UserRequestInput: - options: UserRequestInput = {} - - if extra_headers is not None: - options["headers"] = extra_headers - if max_retries is not None: - options["max_retries"] = max_retries - if not isinstance(timeout, NotGiven): - options['timeout'] = timeout - if query is not None: - options["params"] = query - if extra_body is not None: - options["extra_json"] = cast(AnyMapping, extra_body) - - return options diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py deleted file mode 100644 index b0a91d04a99447..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py +++ /dev/null @@ -1,29 +0,0 @@ -import time - -import cachetools.func -import jwt - -API_TOKEN_TTL_SECONDS = 3 * 60 - -CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 - - -@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) -def generate_token(apikey: str): 
- try: - api_key, secret = apikey.split(".") - except Exception as e: - raise Exception("invalid api_key", e) - - payload = { - "api_key": api_key, - "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, - "timestamp": int(round(time.time() * 1000)), - } - ret = jwt.encode( - payload, - secret, - algorithm="HS256", - headers={"alg": "HS256", "sign_type": "SIGN"}, - ) - return ret diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py deleted file mode 100644 index a3f49ba8461e03..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from typing import Any, ClassVar, Union - -from httpx import Timeout -from pydantic import ConfigDict -from typing_extensions import TypedDict, Unpack - -from ._base_type import Body, Headers, HttpxRequestFiles, NotGiven, Query -from ._utils import remove_notgiven_indict - - -class UserRequestInput(TypedDict, total=False): - max_retries: int - timeout: float | Timeout | None - headers: Headers - params: Query | None - - -class ClientRequestParam: - method: str - url: str - max_retries: Union[int, NotGiven] = NotGiven() - timeout: Union[float, NotGiven] = NotGiven() - headers: Union[Headers, NotGiven] = NotGiven() - json_data: Union[Body, None] = None - files: Union[HttpxRequestFiles, None] = None - params: Query = {} - model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - - def get_max_retries(self, max_retries) -> int: - if isinstance(self.max_retries, NotGiven): - return max_retries - return self.max_retries - - @classmethod - def construct( # type: ignore - cls, - _fields_set: set[str] | None = None, - **values: Unpack[UserRequestInput], - ) -> ClientRequestParam : - kwargs: dict[str, Any] = { - key: remove_notgiven_indict(value) for key, value in values.items() - } - client = cls() - client.__dict__.update(kwargs) - - return client - - model_construct = construct - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py deleted file mode 100644 index 2f831b6fc9ca73..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ /dev/null @@ -1,121 +0,0 @@ -from __future__ import annotations - -import datetime -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args, get_origin - -import httpx -import pydantic -from typing_extensions import ParamSpec - -from ._base_type import NoneType -from ._sse_client import StreamResponse - -if TYPE_CHECKING: - from ._http_client import HttpClient - -P = ParamSpec("P") -R = TypeVar("R") - - -class HttpResponse(Generic[R]): - _cast_type: type[R] - _client: HttpClient - _parsed: R | None - _enable_stream: bool - _stream_cls: type[StreamResponse[Any]] - http_response: httpx.Response - - def __init__( - self, - *, - raw_response: httpx.Response, - cast_type: type[R], - client: HttpClient, - enable_stream: bool = False, - stream_cls: type[StreamResponse[Any]] | None = None, - ) -> None: - self._cast_type = cast_type - self._client = client - self._parsed = None - self._stream_cls = stream_cls - self._enable_stream = enable_stream - self.http_response = raw_response - - def parse(self) -> R: - self._parsed = self._parse() - return self._parsed - - def _parse(self) -> R: - if 
self._enable_stream: - self._parsed = cast( - R, - self._stream_cls( - cast_type=cast(type, get_args(self._stream_cls)[0]), - response=self.http_response, - client=self._client - ) - ) - return self._parsed - cast_type = self._cast_type - if cast_type is NoneType: - return cast(R, None) - http_response = self.http_response - if cast_type == str: - return cast(R, http_response.text) - - content_type, *_ = http_response.headers.get("content-type", "application/json").split(";") - origin = get_origin(cast_type) or cast_type - if content_type != "application/json": - if issubclass(origin, pydantic.BaseModel): - data = http_response.json() - return self._client._process_response_data( - data=data, - cast_type=cast_type, # type: ignore - response=http_response, - ) - - return http_response.text - - data = http_response.json() - - return self._client._process_response_data( - data=data, - cast_type=cast_type, # type: ignore - response=http_response, - ) - - @property - def headers(self) -> httpx.Headers: - return self.http_response.headers - - @property - def http_request(self) -> httpx.Request: - return self.http_response.request - - @property - def status_code(self) -> int: - return self.http_response.status_code - - @property - def url(self) -> httpx.URL: - return self.http_response.url - - @property - def method(self) -> str: - return self.http_request.method - - @property - def content(self) -> bytes: - return self.http_response.content - - @property - def text(self) -> str: - return self.http_response.text - - @property - def http_version(self) -> str: - return self.http_response.http_version - - @property - def elapsed(self) -> datetime.timedelta: - return self.http_response.elapsed diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py deleted file mode 100644 index 66afbfd10780cc..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import Iterator, Mapping -from typing import TYPE_CHECKING, Generic - -import httpx - -from ._base_type import ResponseT -from ._errors import APIResponseError - -_FIELD_SEPARATOR = ":" - -if TYPE_CHECKING: - from ._http_client import HttpClient - - -class StreamResponse(Generic[ResponseT]): - - response: httpx.Response - _cast_type: type[ResponseT] - - def __init__( - self, - *, - cast_type: type[ResponseT], - response: httpx.Response, - client: HttpClient, - ) -> None: - self.response = response - self._cast_type = cast_type - self._data_process_func = client._process_response_data - self._stream_chunks = self.__stream__() - - def __next__(self) -> ResponseT: - return self._stream_chunks.__next__() - - def __iter__(self) -> Iterator[ResponseT]: - yield from self._stream_chunks - - def __stream__(self) -> Iterator[ResponseT]: - - sse_line_parser = SSELineParser() - iterator = sse_line_parser.iter_lines(self.response.iter_lines()) - - for sse in iterator: - if sse.data.startswith("[DONE]"): - break - - if sse.event is None: - data = sse.json_data() - if isinstance(data, Mapping) and data.get("error"): - raise APIResponseError( - message="An error occurred during streaming", - request=self.response.request, - json_data=data["error"], - ) - - yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) - for sse in iterator: - pass - - -class Event: - def __init__( 
- self, - event: str | None = None, - data: str | None = None, - id: str | None = None, - retry: int | None = None - ): - self._event = event - self._data = data - self._id = id - self._retry = retry - - def __repr__(self): - data_len = len(self._data) if self._data else 0 - return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}" - - @property - def event(self): return self._event - - @property - def data(self): return self._data - - def json_data(self): return json.loads(self._data) - - @property - def id(self): return self._id - - @property - def retry(self): return self._retry - - -class SSELineParser: - _data: list[str] - _event: str | None - _retry: int | None - _id: str | None - - def __init__(self): - self._event = None - self._data = [] - self._id = None - self._retry = None - - def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: - for line in lines: - line = line.rstrip('\n') - if not line: - if self._event is None and \ - not self._data and \ - self._id is None and \ - self._retry is None: - continue - sse_event = Event( - event=self._event, - data='\n'.join(self._data), - id=self._id, - retry=self._retry - ) - self._event = None - self._data = [] - self._id = None - self._retry = None - - yield sse_event - self.decode_line(line) - - def decode_line(self, line: str): - if line.startswith(":") or not line: - return - - field, _p, value = line.partition(":") - - if value.startswith(' '): - value = value[1:] - if field == "data": - self._data.append(value) - elif field == "event": - self._event = value - elif field == "retry": - try: - self._retry = int(value) - except (TypeError, ValueError): - pass - return diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py deleted file mode 100644 index 6b610567daa099..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable, Mapping -from typing import TypeVar - -from ._base_type import NotGiven - - -def remove_notgiven_indict(obj): - if obj is None or (not isinstance(obj, Mapping)): - return obj - return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} - - -_T = TypeVar("_T") - - -def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: - return [item for sublist in t for item in sublist] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py deleted file mode 100644 index f22f32d25120f0..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from .chat_completion import CompletionChoice, CompletionUsage - -__all__ = ["AsyncTaskStatus"] - - -class AsyncTaskStatus(BaseModel): - id: Optional[str] = None - request_id: Optional[str] = None - model: Optional[str] = None - task_status: Optional[str] = None - - -class AsyncCompletion(BaseModel): - id: Optional[str] = None - request_id: Optional[str] = None - model: Optional[str] = None - task_status: str - choices: list[CompletionChoice] - usage: CompletionUsage \ No newline at end of file diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py deleted file mode 100644 index b2a847c50c357d..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -__all__ = ["Completion", "CompletionUsage"] - - -class Function(BaseModel): - arguments: str - name: str - - -class CompletionMessageToolCall(BaseModel): - id: str - function: Function - type: str - - -class CompletionMessage(BaseModel): - content: Optional[str] = None - role: str - tool_calls: Optional[list[CompletionMessageToolCall]] = None - - -class CompletionUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - - -class CompletionChoice(BaseModel): - index: int - finish_reason: str - message: CompletionMessage - - -class Completion(BaseModel): - model: Optional[str] = None - created: Optional[int] = None - choices: list[CompletionChoice] - request_id: Optional[str] = None - id: Optional[str] = None - usage: CompletionUsage - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py deleted file mode 100644 index c2506997419815..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -__all__ = [ - "ChatCompletionChunk", - "Choice", - "ChoiceDelta", - "ChoiceDeltaFunctionCall", - "ChoiceDeltaToolCall", - "ChoiceDeltaToolCallFunction", -] - - -class ChoiceDeltaFunctionCall(BaseModel): - arguments: Optional[str] = None - name: Optional[str] = None - - -class ChoiceDeltaToolCallFunction(BaseModel): - arguments: Optional[str] = None - name: Optional[str] = None - - -class ChoiceDeltaToolCall(BaseModel): - index: int - id: Optional[str] = None - function: Optional[ChoiceDeltaToolCallFunction] = None - type: Optional[str] = None - - -class ChoiceDelta(BaseModel): - content: Optional[str] = None - role: Optional[str] = None - tool_calls: Optional[list[ChoiceDeltaToolCall]] = None - - -class Choice(BaseModel): - delta: ChoiceDelta - finish_reason: Optional[str] = None - index: int - - -class CompletionUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - - -class ChatCompletionChunk(BaseModel): - id: Optional[str] = None - choices: list[Choice] - created: Optional[int] = None - model: Optional[str] = None - usage: Optional[CompletionUsage] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py deleted file mode 100644 index 6ee4dc4794b201..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Optional - -from typing_extensions import TypedDict - - -class Reference(TypedDict, total=False): - enable: Optional[bool] - search_query: Optional[str] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py 
deleted file mode 100644 index e01f2c815fb382..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel - -from .chat.chat_completion import CompletionUsage - -__all__ = ["Embedding", "EmbeddingsResponded"] - - -class Embedding(BaseModel): - object: str - index: Optional[int] = None - embedding: list[float] - - -class EmbeddingsResponded(BaseModel): - object: str - data: list[Embedding] - model: str - usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py deleted file mode 100644 index 917bda75767b9d..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -__all__ = ["FileObject"] - - -class FileObject(BaseModel): - - id: Optional[str] = None - bytes: Optional[int] = None - created_at: Optional[int] = None - filename: Optional[str] = None - object: Optional[str] = None - purpose: Optional[str] = None - status: Optional[str] = None - status_details: Optional[str] = None - - -class ListOfFileObject(BaseModel): - - object: Optional[str] = None - data: list[FileObject] - has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py deleted file mode 100644 index af0991892e084f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .fine_tuning_job import FineTuningJob as FineTuningJob -from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob -from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py deleted file mode 100644 index 71c00eaff0dd18..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Optional, Union - -from pydantic import BaseModel - -__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] - - -class Error(BaseModel): - code: str - message: str - param: Optional[str] = None - - -class Hyperparameters(BaseModel): - n_epochs: Union[str, int, None] = None - - -class FineTuningJob(BaseModel): - id: Optional[str] = None - - request_id: Optional[str] = None - - created_at: Optional[int] = None - - error: Optional[Error] = None - - fine_tuned_model: Optional[str] = None - - finished_at: Optional[int] = None - - hyperparameters: Optional[Hyperparameters] = None - - model: Optional[str] = None - - object: Optional[str] = None - - result_files: list[str] - - status: str - - trained_tokens: Optional[int] = None - - training_file: str - - validation_file: Optional[str] = None - - -class ListOfFineTuningJob(BaseModel): - object: Optional[str] = None - data: list[FineTuningJob] - has_more: Optional[bool] = None diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py deleted file mode 100644 index e26b448534246f..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Optional, Union - -from pydantic import BaseModel - -__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] - - -class Metric(BaseModel): - epoch: Optional[Union[str, int, float]] = None - current_steps: Optional[int] = None - total_steps: Optional[int] = None - elapsed_time: Optional[str] = None - remaining_time: Optional[str] = None - trained_tokens: Optional[int] = None - loss: Optional[Union[str, int, float]] = None - eval_loss: Optional[Union[str, int, float]] = None - acc: Optional[Union[str, int, float]] = None - eval_acc: Optional[Union[str, int, float]] = None - learning_rate: Optional[Union[str, int, float]] = None - - -class JobEvent(BaseModel): - object: Optional[str] = None - id: Optional[str] = None - type: Optional[str] = None - created_at: Optional[int] = None - level: Optional[str] = None - message: Optional[str] = None - data: Optional[Metric] = None - - -class FineTuningJobEvent(BaseModel): - object: Optional[str] = None - data: list[JobEvent] - has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py deleted file mode 100644 index e1ebc352bc97fd..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -from typing import Literal, Union - -from typing_extensions import TypedDict - -__all__ = ["Hyperparameters"] - - -class Hyperparameters(TypedDict, total=False): - batch_size: Union[Literal["auto"], int] - - learning_rate_multiplier: Union[Literal["auto"], float] - - n_epochs: Union[Literal["auto"], int] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py deleted file mode 100644 index b352ce0954ad55..00000000000000 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel - -__all__ = ["GeneratedImage", "ImagesResponded"] - - -class GeneratedImage(BaseModel): - b64_json: Optional[str] = None - url: Optional[str] = None - revised_prompt: Optional[str] = None - - -class ImagesResponded(BaseModel): - created: int - data: list[GeneratedImage] diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index fe705d6943a447..029ec1a581b2e9 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -4,9 +4,9 @@ class CommonValidator: - def _validate_and_filter_credential_form_schemas(self, - credential_form_schemas: list[CredentialFormSchema], - credentials: dict) -> dict: + def _validate_and_filter_credential_form_schemas( + self, credential_form_schemas: list[CredentialFormSchema], credentials: dict + ) -> dict: 
need_validate_credential_form_schema_map = {} for credential_form_schema in credential_form_schemas: if not credential_form_schema.show_on: @@ -36,8 +36,9 @@ def _validate_and_filter_credential_form_schemas(self, return validated_credentials - def _validate_credential_form_schema(self, credential_form_schema: CredentialFormSchema, credentials: dict) \ - -> Optional[str]: + def _validate_credential_form_schema( + self, credential_form_schema: CredentialFormSchema, credentials: dict + ) -> Optional[str]: """ Validate credential form schema @@ -49,7 +50,7 @@ def _validate_credential_form_schema(self, credential_form_schema: CredentialFor if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: # If required is True, an exception is thrown if credential_form_schema.required: - raise ValueError(f'Variable {credential_form_schema.variable} is required') + raise ValueError(f"Variable {credential_form_schema.variable} is required") else: # Get the value of default if credential_form_schema.default: @@ -65,23 +66,26 @@ def _validate_credential_form_schema(self, credential_form_schema: CredentialFor # If max_length=0, no validation is performed if credential_form_schema.max_length: if len(value) > credential_form_schema.max_length: - raise ValueError(f'Variable {credential_form_schema.variable} length should not greater than {credential_form_schema.max_length}') + raise ValueError( + f"Variable {credential_form_schema.variable} length should not be" + f" greater than {credential_form_schema.max_length}" + ) # check the type of value if not isinstance(value, str): - raise ValueError(f'Variable {credential_form_schema.variable} should be string') + raise ValueError(f"Variable {credential_form_schema.variable} should be string") - if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: + if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f'Variable {credential_form_schema.variable} is not in options') + raise ValueError(f"Variable {credential_form_schema.variable} is not in options") if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ['true', 'false']: - raise ValueError(f'Variable {credential_form_schema.variable} should be true or false') + if value.lower() not in {"true", "false"}: + raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - value = True if value.lower() == 'true' else False + value = True if value.lower() == "true" else False return value diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py index c4786fad5d4c08..7d1644d13481b1 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/model_credential_schema_validator.py @@ -4,7 +4,6 @@ class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): self.model_type = model_type self.model_credential_schema = model_credential_schema diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py
b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py index c945016534ed8a..6dff2428ca0c34 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -3,7 +3,6 @@ class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): self.provider_credential_schema = provider_credential_schema diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 5078f00bfa26d0..ec1bad5698f2eb 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -18,11 +18,10 @@ from pydantic_extra_types.color import Color -def _model_dump( - model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any -) -> Any: +def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: return model.model_dump(mode=mode, **kwargs) + # Taken from Pydantic v1 as is def isoformat(o: Union[datetime.date, datetime.time]) -> str: return o.isoformat() @@ -82,11 +81,9 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]] + type_encoder_map: dict[Any, Callable[[Any], Any]], ) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( - tuple - ) + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) for type_, encoder in type_encoder_map.items(): encoders_by_class_tuples[encoder] += (type_,) return encoders_by_class_tuples @@ -149,17 +146,13 @@ def jsonable_encoder( if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): - return format(obj, 'f') + return format(obj, "f") if isinstance(obj, dict): encoded_dict = {} allowed_keys = set(obj.keys()) for key, value in obj.items(): if ( - ( - not sqlalchemy_safe - or (not isinstance(key, str)) - or (not key.startswith("_sa")) - ) + (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and (value is not None or not exclude_none) and key in allowed_keys ): diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index c68a554471703f..2067092d80f582 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -3,7 +3,7 @@ def dump_model(model: BaseModel) -> dict: - if hasattr(pydantic, 'model_dump'): + if hasattr(pydantic, "model_dump"): return pydantic.model_dump(model) else: return model.model_dump() diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index f96e2a1c214f5d..094ad7863603dc 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -44,32 +44,29 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - params = ModerationInputParams( - app_id=self.app_id, - inputs=inputs, - query=query - ) + if self.config["inputs_config"]["enabled"]: + params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) return ModerationInputsResult(**result) - return ModerationInputsResult(flagged=flagged, 
action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - params = ModerationOutputParams( - app_id=self.app_id, - text=text - ) + if self.config["outputs_config"]["enabled"]: + params = ModerationOutputParams(app_id=self.app_id, text=text) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) return ModerationOutputsResult(**result) - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id")) @@ -80,9 +77,10 @@ def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, para @staticmethod def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) return extension diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 9a369a9f87742a..60898d5547ae3b 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -8,8 +8,8 @@ class ModerationAction(Enum): - DIRECT_OUTPUT = 'direct_output' - OVERRIDED = 'overrided' + DIRECT_OUTPUT = "direct_output" + OVERRIDDEN = "overridden" class ModerationInputsResult(BaseModel): @@ -31,6 +31,7 @@ class Moderation(Extensible, ABC): """ The base class of moderation. 
""" + module: ExtensionModule = ExtensionModule.MODERATION def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: @@ -75,7 +76,7 @@ def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: raise NotImplementedError @classmethod - def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: + def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: # inputs_config inputs_config = config.get("inputs_config") if not isinstance(inputs_config, dict): @@ -110,5 +111,5 @@ def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_r raise ValueError("outputs_config.preset_response must be less than 100 characters") -class ModerationException(Exception): +class ModerationError(Exception): pass diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py index 8157b300b1f6c7..46d3963bd07f5a 100644 --- a/api/core/moderation/input_moderation.py +++ b/api/core/moderation/input_moderation.py @@ -2,7 +2,7 @@ from typing import Optional from core.app.app_config.entities import AppConfig -from core.moderation.base import ModerationAction, ModerationException +from core.moderation.base import ModerationAction, ModerationError from core.moderation.factory import ModerationFactory from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask @@ -13,13 +13,14 @@ class InputModeration: def check( - self, app_id: str, + self, + app_id: str, tenant_id: str, app_config: AppConfig, inputs: dict, query: str, message_id: str, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. 
@@ -39,10 +40,7 @@ def check( moderation_type = sensitive_word_avoidance_config.type moderation_factory = ModerationFactory( - name=moderation_type, - app_id=app_id, - tenant_id=tenant_id, - config=sensitive_word_avoidance_config.config + name=moderation_type, app_id=app_id, tenant_id=tenant_id, config=sensitive_word_avoidance_config.config ) with measure_time() as timer: @@ -55,7 +53,7 @@ def check( message_id=message_id, moderation_result=moderation_result, inputs=inputs, - timer=timer + timer=timer, ) ) @@ -63,8 +61,8 @@ def check( return False, inputs, query if moderation_result.action == ModerationAction.DIRECT_OUTPUT: - raise ModerationException(moderation_result.preset_response) - elif moderation_result.action == ModerationAction.OVERRIDED: + raise ModerationError(moderation_result.preset_response) + elif moderation_result.action == ModerationAction.OVERRIDDEN: inputs = moderation_result.inputs query = moderation_result.query diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index ca562ad987cada..4846da8f93076e 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -18,48 +18,49 @@ def validate_config(cls, tenant_id: str, config: dict) -> None: if not config.get("keywords"): raise ValueError("keywords is required") - if len(config.get("keywords")) > 1000: - raise ValueError("keywords length must be less than 1000") + if len(config.get("keywords")) > 10000: + raise ValueError("keywords length must be less than 10000") + + keywords_row_len = config["keywords"].split("\n") + if len(keywords_row_len) > 100: + raise ValueError("the number of rows for the keywords must be less than 100") def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] flagged = self._is_violated(inputs, keywords_list) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: + if self.config["outputs_config"]["enabled"]: # Filter out empty values - keywords_list = [keyword for keyword in self.config['keywords'].split('\n') if keyword] + keywords_list = [keyword for keyword in self.config["keywords"].split("\n") if keyword] - flagged = self._is_violated({'text': text}, keywords_list) - preset_response = self.config['outputs_config']['preset_response'] + flagged = self._is_violated({"text": text}, keywords_list) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + 
) def _is_violated(self, inputs: dict, keywords_list: list) -> bool: - for value in inputs.values(): - if self._check_keywords_in_value(keywords_list, value): - return True - - return False + return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values()) - def _check_keywords_in_value(self, keywords_list, value): - for keyword in keywords_list: - if keyword.lower() in value.lower(): - return True - return False + def _check_keywords_in_value(self, keywords_list, value) -> bool: + return any(keyword.lower() in value.lower() for keyword in keywords_list) diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index fee51007ebeed7..6465de23b9a2de 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -21,37 +21,36 @@ def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInpu flagged = False preset_response = "" - if self.config['inputs_config']['enabled']: - preset_response = self.config['inputs_config']['preset_response'] + if self.config["inputs_config"]["enabled"]: + preset_response = self.config["inputs_config"]["preset_response"] if query: - inputs['query__'] = query + inputs["query__"] = query flagged = self._is_violated(inputs) - return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationInputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" - if self.config['outputs_config']['enabled']: - flagged = self._is_violated({'text': text}) - preset_response = self.config['outputs_config']['preset_response'] + if self.config["outputs_config"]["enabled"]: + flagged = self._is_violated({"text": text}) + preset_response = self.config["outputs_config"]["preset_response"] - return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response) + return ModerationOutputsResult( + flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response + ) def _is_violated(self, inputs: dict): - text = '\n'.join(str(inputs.values())) + text = "\n".join(str(inputs.values())) model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - provider="openai", - model_type=ModelType.MODERATION, - model="text-moderation-stable" + tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="text-moderation-stable" ) - openai_moderation = model_instance.invoke_moderation( - text=text - ) + openai_moderation = model_instance.invoke_moderation(text=text) return openai_moderation diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index 9a4d8db4e2f39d..83f4d2d57d128e 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -29,18 +29,18 @@ class OutputModeration(BaseModel): thread: Optional[threading.Thread] = None thread_running: bool = True - buffer: str = '' + buffer: str = "" is_final_chunk: bool = False final_output: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) - def should_direct_output(self): + def should_direct_output(self) -> bool: return self.final_output is not None - def get_final_output(self): - 
return self.final_output + def get_final_output(self) -> str: + return self.final_output or "" - def append_new_token(self, token: str): + def append_new_token(self, token: str) -> None: self.buffer += token if not self.thread: @@ -50,11 +50,7 @@ def moderation_completion(self, completion: str, public_event: bool = False) -> self.buffer = completion self.is_final_chunk = True - result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=completion - ) + result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion) if not result or not result.flagged: return completion @@ -65,21 +61,19 @@ def moderation_completion(self, completion: str, public_event: bool = False) -> final_output = result.text if public_event: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) return final_output def start_thread(self) -> threading.Thread: buffer_size = dify_config.MODERATION_BUFFER_SIZE - thread = threading.Thread(target=self.worker, kwargs={ - 'flask_app': current_app._get_current_object(), - 'buffer_size': buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE - }) + thread = threading.Thread( + target=self.worker, + kwargs={ + "flask_app": current_app._get_current_object(), + "buffer_size": buffer_size if buffer_size > 0 else dify_config.MODERATION_BUFFER_SIZE, + }, + ) thread.start() @@ -104,9 +98,7 @@ def worker(self, flask_app: Flask, buffer_size: int): current_length = buffer_length result = self.moderation( - tenant_id=self.tenant_id, - app_id=self.app_id, - moderation_buffer=moderation_buffer + tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=moderation_buffer ) if not result or not result.flagged: @@ -116,16 +108,11 @@ def worker(self, flask_app: Flask, buffer_size: int): final_output = result.preset_response self.final_output = final_output else: - final_output = result.text + self.buffer[len(moderation_buffer):] + final_output = result.text + self.buffer[len(moderation_buffer) :] # trigger replace event if self.thread_running: - self.queue_manager.publish( - QueueMessageReplaceEvent( - text=final_output - ), - PublishFrom.TASK_PIPELINE - ) + self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) if result.action == ModerationAction.DIRECT_OUTPUT: break @@ -133,15 +120,12 @@ def worker(self, flask_app: Flask, buffer_size: int): def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: try: moderation_factory = ModerationFactory( - name=self.rule.type, - app_id=app_id, - tenant_id=tenant_id, - config=self.rule.config + name=self.rule.type, app_id=app_id, tenant_id=tenant_id, config=self.rule.config ) result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) return result except Exception as e: - logger.error("Moderation Output error: %s", e) + logger.exception("Moderation Output error: %s", e) return None diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index c7af8e296339c8..f7b882fc71d48e 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -23,4 +23,4 @@ def trace(self, trace_info: BaseTraceInfo): Abstract method to trace activities. Subclasses must implement specific tracing logic for activities. """ - ... 
\ No newline at end of file + ... diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 221e6239ab9302..ef0f9c708f14b0 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -4,14 +4,15 @@ class TracingProviderEnum(Enum): - LANGFUSE = 'langfuse' - LANGSMITH = 'langsmith' + LANGFUSE = "langfuse" + LANGSMITH = "langsmith" class BaseTracingConfig(BaseModel): """ Base model class for tracing """ + ... @@ -19,16 +20,18 @@ class LangfuseConfig(BaseTracingConfig): """ Model class for Langfuse tracing config. """ + public_key: str secret_key: str - host: str = 'https://api.langfuse.com' + host: str = "https://api.langfuse.com" @field_validator("host") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.langfuse.com' - if not v.startswith('https://') and not v.startswith('http://'): - raise ValueError('host must start with https:// or http://') + v = "https://api.langfuse.com" + if not v.startswith("https://") and not v.startswith("http://"): + raise ValueError("host must start with https:// or http://") return v @@ -37,15 +40,21 @@ class LangSmithConfig(BaseTracingConfig): """ Model class for Langsmith tracing config. """ + api_key: str project: str - endpoint: str = 'https://api.smith.langchain.com' + endpoint: str = "https://api.smith.langchain.com" @field_validator("endpoint") + @classmethod def set_value(cls, v, info: ValidationInfo): if v is None or v == "": - v = 'https://api.smith.langchain.com' - if not v.startswith('https://'): - raise ValueError('endpoint must start with https://') + v = "https://api.smith.langchain.com" + if not v.startswith("https://"): + raise ValueError("endpoint must start with https://") return v + + +OPS_FILE_PATH = "ops_trace/" +OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index a1443f0691233b..256595286f324e 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -15,13 +15,19 @@ class BaseTraceInfo(BaseModel): metadata: dict[str, Any] @field_validator("inputs", "outputs") + @classmethod def ensure_type(cls, v): if v is None: return None if isinstance(v, str | dict | list): return v - else: - return "" + return "" + + class Config: + json_encoders = { + datetime: lambda v: v.isoformat(), + } + class WorkflowTraceInfo(BaseTraceInfo): workflow_data: Any @@ -98,23 +104,30 @@ class GenerateNameTraceInfo(BaseTraceInfo): conversation_id: Optional[str] = None tenant_id: str + +class TaskData(BaseModel): + app_id: str + trace_info_type: str + trace_info: Any + + trace_info_info_map = { - 'WorkflowTraceInfo': WorkflowTraceInfo, - 'MessageTraceInfo': MessageTraceInfo, - 'ModerationTraceInfo': ModerationTraceInfo, - 'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo, - 'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo, - 'ToolTraceInfo': ToolTraceInfo, - 'GenerateNameTraceInfo': GenerateNameTraceInfo, + "WorkflowTraceInfo": WorkflowTraceInfo, + "MessageTraceInfo": MessageTraceInfo, + "ModerationTraceInfo": ModerationTraceInfo, + "SuggestedQuestionTraceInfo": SuggestedQuestionTraceInfo, + "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, + "ToolTraceInfo": ToolTraceInfo, + "GenerateNameTraceInfo": GenerateNameTraceInfo, } class TraceTaskName(str, Enum): - CONVERSATION_TRACE = 'conversation' - WORKFLOW_TRACE = 'workflow' - MESSAGE_TRACE = 'message' - MODERATION_TRACE = 'moderation' - 
SUGGESTED_QUESTION_TRACE = 'suggested_question' - DATASET_RETRIEVAL_TRACE = 'dataset_retrieval' - TOOL_TRACE = 'tool' - GENERATE_NAME_TRACE = 'generate_conversation_name' + CONVERSATION_TRACE = "conversation" + WORKFLOW_TRACE = "workflow" + MESSAGE_TRACE = "message" + MODERATION_TRACE = "moderation" + SUGGESTED_QUESTION_TRACE = "suggested_question" + DATASET_RETRIEVAL_TRACE = "dataset_retrieval" + TOOL_TRACE = "tool" + GENERATE_NAME_TRACE = "generate_conversation_name" diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index af7661f0afc9c8..447b799f1f1187 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel): ) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -196,6 +198,7 @@ class GenerationUsage(BaseModel): totalCost: Optional[float] = None @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) @@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel): model_config = ConfigDict(protected_namespaces=()) @field_validator("input", "output") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name return validate_input_output(v, field_name) diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 698398e0cb8c16..0cba40c51a0d19 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -65,7 +65,7 @@ def trace(self, trace_info: BaseTraceInfo): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id + trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id user_id = trace_info.metadata.get("user_id") if trace_info.message_id: trace_id = trace_info.message_id @@ -84,7 +84,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): ) self.add_trace(langfuse_trace_data=trace_data) workflow_span_data = LangfuseSpan( - id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id), + id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id), name=TraceTaskName.WORKFLOW_TRACE.value, input=trace_info.workflow_run_inputs, output=trace_info.workflow_run_outputs, @@ -93,7 +93,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): end_time=trace_info.end_time, metadata=trace_info.metadata, level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR, - status_message=trace_info.error if trace_info.error else "", + status_message=trace_info.error or "", ) self.add_span(langfuse_span_data=workflow_span_data) else: @@ -110,26 +110,35 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): self.add_trace(langfuse_trace_data=trace_data) # through workflow_run_id get all_nodes_execution - workflow_nodes_executions = ( - db.session.query( - 
WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) + workflow_nodes_execution_id_records = ( + db.session.query(WorkflowNodeExecution.id) .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) .all() ) - for node_execution in workflow_nodes_executions: + for node_execution_id_record in workflow_nodes_execution_id_records: + node_execution = ( + db.session.query( + WorkflowNodeExecution.id, + WorkflowNodeExecution.tenant_id, + WorkflowNodeExecution.app_id, + WorkflowNodeExecution.title, + WorkflowNodeExecution.node_type, + WorkflowNodeExecution.status, + WorkflowNodeExecution.inputs, + WorkflowNodeExecution.outputs, + WorkflowNodeExecution.created_at, + WorkflowNodeExecution.elapsed_time, + WorkflowNodeExecution.process_data, + WorkflowNodeExecution.execution_metadata, + ) + .filter(WorkflowNodeExecution.id == node_execution_id_record.id) + .first() + ) + + if not node_execution: + continue + node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id @@ -143,7 +152,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): else: inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} - created_at = node_execution.created_at if node_execution.created_at else datetime.now() + created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) @@ -159,6 +168,16 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): "status": status, } ) + process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + model_provider = process_data.get("model_provider", None) + model_name = process_data.get("model_name", None) + if model_provider is not None and model_name is not None: + metadata.update( + { + "model_provider": model_provider, + "model_name": model_name, + } + ) # add span if trace_info.message_id: @@ -172,10 +191,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): end_time=finished_at, metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), - status_message=trace_info.error if trace_info.error else "", - parent_observation_id=( - trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id - ), + status_message=trace_info.error or "", + parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id), ) else: span_data = LangfuseSpan( @@ -188,12 +205,11 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): end_time=finished_at, metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), - status_message=trace_info.error if trace_info.error else "", + status_message=trace_info.error or "", ) self.add_span(langfuse_span_data=span_data) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": total_token = metadata.get("total_tokens", 0) # add generation @@ -204,6 +220,7 @@ def workflow_trace(self, 
trace_info: WorkflowTraceInfo): node_generation_data = LangfuseGeneration( name="llm", trace_id=trace_id, + model=process_data.get("model_name"), parent_observation_id=node_execution_id, start_time=created_at, end_time=finished_at, @@ -211,7 +228,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): output=outputs, metadata=metadata, level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR), - status_message=trace_info.error if trace_info.error else "", + status_message=trace_info.error or "", usage=generation_usage, ) @@ -276,7 +293,7 @@ def message_trace(self, trace_info: MessageTraceInfo, **kwargs): output=message_data.answer, metadata=metadata, level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), - status_message=message_data.error if message_data.error else "", + status_message=message_data.error or "", usage=generation_usage, ) @@ -318,7 +335,7 @@ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): end_time=trace_info.end_time, metadata=trace_info.metadata, level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), - status_message=message_data.error if message_data.error else "", + status_message=message_data.error or "", usage=generation_usage, ) @@ -419,3 +436,11 @@ def api_check(self): except Exception as e: logger.debug(f"LangFuse API check failed: {str(e)}") raise ValueError(f"LangFuse API check failed: {str(e)}") + + def get_project_key(self): + try: + projects = self.langfuse_client.client.projects.get() + return projects.data[0].id + except Exception as e: + logger.debug(f"LangFuse get project key failed: {str(e)}") + raise ValueError(f"LangFuse get project key failed: {str(e)}") diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index f3fc46d99a8692..05c932fb99b424 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -35,49 +35,32 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") error: Optional[str] = Field(None, description="Error message of the run") - serialized: Optional[dict[str, Any]] = Field( - None, description="Serialized data of the run" - ) + serialized: Optional[dict[str, Any]] = Field(None, description="Serialized data of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") id: Optional[str] = Field(None, description="ID of the run") - session_id: Optional[str] 
= Field( - None, description="Session ID associated with the run" - ) - session_name: Optional[str] = Field( - None, description="Session name associated with the run" - ) - reference_example_id: Optional[str] = Field( - None, description="Reference example ID associated with the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + session_id: Optional[str] = Field(None, description="Session ID associated with the run") + session_name: Optional[str] = Field(None, description="Session name associated with the run") + reference_example_id: Optional[str] = Field(None, description="Reference example ID associated with the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") @field_validator("inputs", "outputs") + @classmethod def ensure_dict(cls, v, info: ValidationInfo): field_name = info.field_name values = info.data if v == {} or v is None: return v usage_metadata = { - "input_tokens": values.get('input_tokens', 0), - "output_tokens": values.get('output_tokens', 0), - "total_tokens": values.get('total_tokens', 0), + "input_tokens": values.get("input_tokens", 0), + "output_tokens": values.get("output_tokens", 0), + "total_tokens": values.get("total_tokens", 0), } file_list = values.get("file_list", []) if isinstance(v, str): @@ -133,6 +116,7 @@ def ensure_dict(cls, v, info: ValidationInfo): return v return v @field_validator("start_time", "end_time") + @classmethod def format_time(cls, v, info: ValidationInfo): if not isinstance(v, datetime): @@ -143,25 +127,15 @@ def format_time(cls, v, info: ValidationInfo): class LangSmithRunUpdateModel(BaseModel): run_id: str = Field(..., description="ID of the run") - trace_id: Optional[str] = Field( - None, description="Trace ID associated with the run" - ) + trace_id: Optional[str] = Field(None, description="Trace ID associated with the run") dotted_order: Optional[str] = Field(None, description="Dotted order of the run") parent_run_id: Optional[str] = Field(None, description="Parent run ID") end_time: Optional[datetime | str] = Field(None, description="End time of the run") error: Optional[str] = Field(None, description="Error message of the run") inputs: Optional[dict[str, Any]] = Field(None, description="Inputs of the run") outputs: Optional[dict[str, Any]] = Field(None, description="Outputs of the run") - events: Optional[list[dict[str, Any]]] = Field( - None, description="Events associated with the run" - ) + events: Optional[list[dict[str, Any]]] = Field(None, description="Events associated with the run") tags: Optional[list[str]] = Field(None, description="Tags associated with the run") - extra: Optional[dict[str, Any]] = Field( - None, description="Extra information of the run" - ) - input_attachments: Optional[dict[str, Any]] = Field( - None, description="Input attachments of the run" - ) - output_attachments: Optional[dict[str, Any]] = Field( - None, description="Output attachments of the run" - ) + extra: Optional[dict[str, Any]] = Field(None, description="Extra information of the run") + input_attachments: Optional[dict[str, Any]] = Field(None, description="Input attachments of the run") + output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") diff --git
a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index fde8a06c612dd9..ad450504057bef 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -1,9 +1,11 @@ import json import logging import os +import uuid from datetime import datetime, timedelta from langsmith import Client +from langsmith.schemas import RunBase from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangSmithConfig @@ -80,7 +82,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): langsmith_run = LangSmithRunModel( file_list=trace_info.file_list, total_tokens=trace_info.total_tokens, - id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id, + id=trace_info.workflow_app_log_id or trace_info.workflow_run_id, name=TraceTaskName.WORKFLOW_TRACE.value, inputs=trace_info.workflow_run_inputs, run_type=LangSmithRunType.tool, @@ -92,32 +94,41 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): }, error=trace_info.error, tags=["workflow"], - parent_run_id=trace_info.message_id if trace_info.message_id else None, + parent_run_id=trace_info.message_id or None, ) self.add_run(langsmith_run) # through workflow_run_id get all_nodes_execution - workflow_nodes_executions = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) + workflow_nodes_execution_id_records = ( + db.session.query(WorkflowNodeExecution.id) .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) .all() ) - for node_execution in workflow_nodes_executions: + for node_execution_id_record in workflow_nodes_execution_id_records: + node_execution = ( + db.session.query( + WorkflowNodeExecution.id, + WorkflowNodeExecution.tenant_id, + WorkflowNodeExecution.app_id, + WorkflowNodeExecution.title, + WorkflowNodeExecution.node_type, + WorkflowNodeExecution.status, + WorkflowNodeExecution.inputs, + WorkflowNodeExecution.outputs, + WorkflowNodeExecution.created_at, + WorkflowNodeExecution.elapsed_time, + WorkflowNodeExecution.process_data, + WorkflowNodeExecution.execution_metadata, + ) + .filter(WorkflowNodeExecution.id == node_execution_id_record.id) + .first() + ) + + if not node_execution: + continue + node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id @@ -131,7 +142,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): else: inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} - created_at = node_execution.created_at if node_execution.created_at else datetime.now() + created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) @@ -139,8 +150,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} ) node_total_tokens = execution_metadata.get("total_tokens", 0) - - metadata = 
json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} + metadata = execution_metadata.copy() metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -156,6 +166,12 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm + metadata.update( + { + "ls_provider": process_data.get("model_provider", ""), + "ls_model_name": process_data.get("model_name", ""), + } + ) elif node_type == "knowledge-retrieval": run_type = LangSmithRunType.retriever else: @@ -173,9 +189,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo): extra={ "metadata": metadata, }, - parent_run_id=trace_info.workflow_app_log_id - if trace_info.workflow_app_log_id - else trace_info.workflow_run_id, + parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id, tags=["node_execution"], ) @@ -366,3 +380,22 @@ def api_check(self): except Exception as e: logger.debug(f"LangSmith API check failed: {str(e)}") raise ValueError(f"LangSmith API check failed: {str(e)}") + + def get_project_url(self): + try: + run_data = RunBase( + id=uuid.uuid4(), + name="tool", + inputs={"input": "test"}, + outputs={"output": "test"}, + run_type=LangSmithRunType.tool, + start_time=datetime.now(), + ) + + project_url = self.langsmith_client.get_run_url( + run=run_data, project_id=self.project_id, project_name=self.project_name + ) + return project_url.split("/r/")[0] + except Exception as e: + logger.debug(f"LangSmith get run url failed: {str(e)}") + raise ValueError(f"LangSmith get run url failed: {str(e)}") diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 068b490ec887bd..79704c115f1d0a 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -6,12 +6,13 @@ import time from datetime import timedelta from typing import Any, Optional, Union -from uuid import UUID +from uuid import UUID, uuid4 from flask import current_app from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( + OPS_FILE_PATH, LangfuseConfig, LangSmithConfig, TracingProviderEnum, @@ -22,6 +23,7 @@ MessageTraceInfo, ModerationTraceInfo, SuggestedQuestionTraceInfo, + TaskData, ToolTraceInfo, TraceTaskName, WorkflowTraceInfo, @@ -30,23 +32,24 @@ from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace from core.ops.utils import get_message_data from extensions.ext_database import db +from extensions.ext_storage import storage from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig from models.workflow import WorkflowAppLog, WorkflowRun from tasks.ops_trace_task import process_trace_tasks provider_config_map = { TracingProviderEnum.LANGFUSE.value: { - 'config_class': LangfuseConfig, - 'secret_keys': ['public_key', 'secret_key'], - 'other_keys': ['host'], - 'trace_instance': LangFuseDataTrace + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, }, TracingProviderEnum.LANGSMITH.value: { - 'config_class': LangSmithConfig, - 'secret_keys': ['api_key'], - 'other_keys': ['project', 'endpoint'], - 'trace_instance': LangSmithDataTrace - } + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": 
["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + }, } @@ -64,14 +67,17 @@ def encrypt_tracing_config( :return: encrypted tracing configuration """ # Get the configuration class and the keys that require encryption - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} # Encrypt necessary keys for key in secret_keys: if key in tracing_config: - if '*' in tracing_config[key]: + if "*" in tracing_config[key]: # If the key contains '*', retain the original value from the current config new_config[key] = current_trace_config.get(key, tracing_config[key]) else: @@ -94,8 +100,11 @@ def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_c :param tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in tracing_config: @@ -114,8 +123,11 @@ def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: :param decrypt_tracing_config: tracing config :return: """ - config_class, secret_keys, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['secret_keys'], provider_config_map[tracing_provider]['other_keys'] + config_class, secret_keys, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["secret_keys"], + provider_config_map[tracing_provider]["other_keys"], + ) new_config = {} for key in secret_keys: if key in decrypt_tracing_config: @@ -123,7 +135,6 @@ def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: for key in other_keys: new_config[key] = decrypt_tracing_config.get(key, "") - return config_class(**new_config).model_dump() @classmethod @@ -134,9 +145,11 @@ def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None @@ -165,21 +178,28 @@ def get_ops_trace_instance( if app_id is None: return None - app: App = db.session.query(App).filter( - App.id == app_id - ).first() + app: App = db.session.query(App).filter(App.id == app_id).first() + + if app is None: + return None + app_ops_trace_config = json.loads(app.tracing) if app.tracing else None - if app_ops_trace_config is not None: - tracing_provider = app_ops_trace_config.get('tracing_provider') - else: + if app_ops_trace_config is None: + return None + + 
tracing_provider = app_ops_trace_config.get("tracing_provider") + + if tracing_provider is None or tracing_provider not in provider_config_map: return None # decrypt_token decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider) - if app_ops_trace_config.get('enabled'): - trace_instance, config_class = provider_config_map[tracing_provider]['trace_instance'], \ - provider_config_map[tracing_provider]['config_class'] + if app_ops_trace_config.get("enabled"): + trace_instance, config_class = ( + provider_config_map[tracing_provider]["trace_instance"], + provider_config_map[tracing_provider]["config_class"], + ) tracing_instance = trace_instance(config_class(**decrypt_trace_config)) return tracing_instance @@ -193,9 +213,11 @@ def get_app_config_through_message_id(cls, message_id: str): conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() if conversation_data.app_model_config_id: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation_data.app_model_config_id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter(AppModelConfig.id == conversation_data.app_model_config_id) + .first() + ) elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: app_model_config = conversation_data.override_model_configs @@ -211,7 +233,7 @@ def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: :return: """ # auth check - if tracing_provider not in provider_config_map.keys() and tracing_provider is not None: + if tracing_provider not in provider_config_map and tracing_provider is not None: raise ValueError(f"Invalid tracing provider: {tracing_provider}") app_config: App = db.session.query(App).filter(App.id == app_id).first() @@ -232,10 +254,7 @@ def get_app_tracing_config(cls, app_id: str): """ app: App = db.session.query(App).filter(App.id == app_id).first() if not app.tracing: - return { - "enabled": False, - "tracing_provider": None - } + return {"enabled": False, "tracing_provider": None} app_trace_config = json.loads(app.tracing) return app_trace_config @@ -247,11 +266,43 @@ def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str) :param tracing_config: tracing config :param tracing_provider: tracing provider :return: """ - config_type, trace_instance = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['trace_instance'] + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) tracing_config = config_type(**tracing_config) return trace_instance(tracing_config).api_check() + + @staticmethod + def get_trace_config_project_key(tracing_config: dict, tracing_provider: str): + """ + get project key from trace config + :param tracing_config: tracing config + :param tracing_provider: tracing provider + :return: + """ + config_type, trace_instance = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) + tracing_config = config_type(**tracing_config) + return trace_instance(tracing_config).get_project_key() + + @staticmethod + def get_trace_config_project_url(tracing_config: dict, tracing_provider: str): + """ + get project url from trace config + :param tracing_config: tracing config + :param tracing_provider: tracing provider + :return: + """ + config_type, trace_instance = ( + 
provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["trace_instance"], + ) + tracing_config = config_type(**tracing_config) + return trace_instance(tracing_config).get_project_url() + class TraceTask: def __init__( @@ -262,7 +313,7 @@ def __init__( conversation_id: Optional[str] = None, user_id: Optional[str] = None, timer: Optional[Any] = None, - **kwargs + **kwargs, ): self.trace_type = trace_type self.message_id = message_id @@ -285,9 +336,7 @@ def preprocess(self): self.workflow_run, self.conversation_id, self.user_id ), TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id), - TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( - self.message_id, self.timer, **self.kwargs - ), + TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs), TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( self.message_id, self.timer, **self.kwargs ), @@ -312,32 +361,29 @@ def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id): workflow_run_id = workflow_run.id workflow_run_elapsed_time = workflow_run.elapsed_time workflow_run_status = workflow_run.status - workflow_run_inputs = ( - json.loads(workflow_run.inputs) if workflow_run.inputs else {} - ) - workflow_run_outputs = ( - json.loads(workflow_run.outputs) if workflow_run.outputs else {} - ) + workflow_run_inputs = workflow_run.inputs_dict + workflow_run_outputs = workflow_run.outputs_dict workflow_run_version = workflow_run.version - error = workflow_run.error if workflow_run.error else "" + error = workflow_run.error or "" total_tokens = workflow_run.total_tokens - file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else [] + file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" # get workflow_app_log_id - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - tenant_id=tenant_id, - app_id=workflow_run.app_id, - workflow_run_id=workflow_run.id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog) + .filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id) + .first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None # get message_id - message_data = db.session.query(Message.id).filter_by( - conversation_id=conversation_id, - workflow_run_id=workflow_run_id - ).first() + message_data = ( + db.session.query(Message.id) + .filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id) + .first() + ) message_id = str(message_data.id) if message_data else None metadata = { @@ -416,7 +462,7 @@ def message_trace(self, message_id): message_tokens=message_tokens, answer_tokens=message_data.answer_tokens, total_tokens=message_tokens + message_data.answer_tokens, - error=message_data.error if message_data.error else "", + error=message_data.error or "", inputs=inputs, outputs=message_data.answer, file_list=file_list, @@ -445,13 +491,13 @@ def moderation_trace(self, message_id, timer, **kwargs): # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) 
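# The workflow app log id, when this message belongs to a workflow run, is preferred over the raw message id below; `timer` is presumably the {"start": ..., "end": ...} dict yielded by the measure_time() contextmanager in core/ops/utils.py.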
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None moderation_trace_info = ModerationTraceInfo( - message_id=workflow_app_log_id if workflow_app_log_id else message_id, + message_id=workflow_app_log_id or message_id, inputs=inputs, message_data=message_data.to_dict(), flagged=moderation_result.flagged, @@ -485,13 +531,13 @@ def suggested_question_trace(self, message_id, timer, **kwargs): # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = db.session.query(WorkflowAppLog).filter_by( - workflow_run_id=message_data.workflow_run_id - ).first() + workflow_app_log_data = ( + db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None suggested_question_trace_info = SuggestedQuestionTraceInfo( - message_id=workflow_app_log_id if workflow_app_log_id else message_id, + message_id=workflow_app_log_id or message_id, message_data=message_data.to_dict(), inputs=message_data.message, outputs=message_data.answer, @@ -533,7 +579,7 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( message_id=message_id, - inputs=message_data.query if message_data.query else message_data.inputs, + inputs=message_data.query or message_data.inputs, documents=[doc.model_dump() for doc in documents], start_time=timer.get("start"), end_time=timer.get("end"), @@ -544,9 +590,9 @@ def dataset_retrieval_trace(self, message_id, timer, **kwargs): return dataset_retrieval_trace_info def tool_trace(self, message_id, timer, **kwargs): - tool_name = kwargs.get('tool_name') - tool_inputs = kwargs.get('tool_inputs') - tool_outputs = kwargs.get('tool_outputs') + tool_name = kwargs.get("tool_name") + tool_inputs = kwargs.get("tool_inputs") + tool_outputs = kwargs.get("tool_outputs") message_data = get_message_data(message_id) if not message_data: return {} @@ -561,11 +607,11 @@ def tool_trace(self, message_id, timer, **kwargs): if tool_name in agent_thought.tools: created_time = agent_thought.created_at tool_meta_data = agent_thought.tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - time_cost = tool_meta_data.get('time_cost', 0) + tool_config = tool_meta_data.get("tool_config", {}) + time_cost = tool_meta_data.get("time_cost", 0) end_time = created_time + timedelta(seconds=time_cost) - error = tool_meta_data.get('error', "") - tool_parameters = tool_meta_data.get('tool_parameters', {}) + error = tool_meta_data.get("error", "") + tool_parameters = tool_meta_data.get("tool_parameters", {}) metadata = { "message_id": message_id, "tool_name": tool_name, @@ -659,14 +705,13 @@ def __init__(self, app_id=None, user_id=None): self.start_timer() def add_trace_task(self, trace_task: TraceTask): - global trace_manager_timer - global trace_manager_queue + global trace_manager_timer, trace_manager_queue try: if self.trace_instance: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception as e: - logging.debug(f"Error adding trace task: {e}") + logging.exception(f"Error adding trace task: {e}") finally: self.start_timer() @@ -685,14 +730,12 @@ def run(self): if tasks: self.send_to_celery(tasks) except Exception as e: - logging.debug(f"Error processing trace tasks: {e}") + logging.exception(f"Error processing trace tasks: {e}") def start_timer(self): global trace_manager_timer if trace_manager_timer is 
None or not trace_manager_timer.is_alive(): - trace_manager_timer = threading.Timer( - trace_manager_interval, self.run - ) + trace_manager_timer = threading.Timer(trace_manager_interval, self.run) trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}" trace_manager_timer.daemon = False trace_manager_timer.start() @@ -700,10 +743,17 @@ def start_timer(self): def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: + file_id = uuid4().hex trace_info = task.execute() - task_data = { + task_data = TaskData( + app_id=task.app_id, + trace_info_type=type(trace_info).__name__, + trace_info=trace_info.model_dump() if trace_info else None, + ) + file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" + storage.save(file_path, task_data.model_dump_json().encode("utf-8")) + file_info = { + "file_id": file_id, "app_id": task.app_id, - "trace_info_type": type(trace_info).__name__, - "trace_info": trace_info.model_dump() if trace_info else {}, } - process_trace_tasks.delay(task_data) + process_trace_tasks.delay(file_info) diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index 3b2e04abb73288..3cd3fb57565a5b 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -6,12 +6,15 @@ def filter_none_values(data: dict): + new_data = {} for key, value in data.items(): if value is None: continue if isinstance(value, datetime): - data[key] = value.isoformat() - return {key: value for key, value in data.items() if value is not None} + new_data[key] = value.isoformat() + else: + new_data[key] = value + return new_data def get_message_data(message_id): @@ -20,19 +23,19 @@ def get_message_data(message_id): @contextmanager def measure_time(): - timing_info = {'start': datetime.now(), 'end': None} + timing_info = {"start": datetime.now(), "end": None} try: yield timing_info finally: - timing_info['end'] = datetime.now() + timing_info["end"] = datetime.now() def replace_text_with_content(data): if isinstance(data, dict): new_data = {} for key, value in data.items(): - if key == 'text': - new_data['content'] = value + if key == "text": + new_data["content"] = value else: new_data[key] = replace_text_with_content(value) return new_data diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 22420fea2cc02f..0f3f8249661bf0 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,45 +1,55 @@ -from typing import Optional, Union +from collections.abc import Sequence +from typing import Optional from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileVar +from core.file import file_manager +from core.file.models import File from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContent, PromptMessageRole, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform -from core.prompt.simple_prompt_transform import ModelMode from 
core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.variable_pool import VariablePool class AdvancedPromptTransform(PromptTransform): """ Advanced Prompt Transform for Workflow LLM Node. """ - def __init__(self, with_variable_tmpl: bool = False) -> None: - self.with_variable_tmpl = with_variable_tmpl - - def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate], - inputs: dict, - query: str, - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: - inputs = {key: str(value) for key, value in inputs.items()} + def __init__( + self, + with_variable_tmpl: bool = False, + image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, + ) -> None: + self.with_variable_tmpl = with_variable_tmpl + self.image_detail_config = image_detail_config + + def get_prompt( + self, + *, + prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate, + inputs: dict[str, str], + query: str, + files: Sequence[File], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: prompt_messages = [] - model_mode = ModelMode.value_of(model_config.mode) - if model_mode == ModelMode.COMPLETION: + if isinstance(prompt_template, CompletionModelPromptTemplate): prompt_messages = self._get_completion_model_prompt_messages( prompt_template=prompt_template, inputs=inputs, @@ -48,32 +58,33 @@ def get_prompt(self, prompt_template: Union[list[ChatModelMessage], CompletionMo context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) - elif model_mode == ModelMode.CHAT: + elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template): prompt_messages = self._get_chat_model_prompt_messages( prompt_template=prompt_template, inputs=inputs, query=query, - query_prompt_template=query_prompt_template, files=files, context=context, memory_config=memory_config, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _get_completion_model_prompt_messages(self, - prompt_template: CompletionModelPromptTemplate, - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _get_completion_model_prompt_messages( + self, + prompt_template: CompletionModelPromptTemplate, + inputs: dict, + query: Optional[str], + files: Sequence[File], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: """ Get completion model prompt messages. 
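A 'basic' (or unset) edition template is rendered with PromptTemplateParser, resolving the #context#, #histories# and #query# placeholders; a 'jinja2' edition template is rendered with Jinja2Formatter instead.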
""" @@ -81,11 +92,11 @@ def _get_completion_model_prompt_messages(self, prompt_messages = [] - if prompt_template.edition_type == 'basic' or not prompt_template.edition_type: - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + if prompt_template.edition_type == "basic" or not prompt_template.edition_type: + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) if memory and memory_config: role_prefix = memory_config.role_prefix @@ -94,17 +105,15 @@ def _get_completion_model_prompt_messages(self, memory_config=memory_config, raw_prompt=raw_prompt, role_prefix=role_prefix, - prompt_template=prompt_template, + parser=parser, prompt_inputs=prompt_inputs, - model_config=model_config + model_config=model_config, ) if query: - prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs) + prompt_inputs = self._set_query_variable(query, parser, prompt_inputs) - prompt = prompt_template.format( - prompt_inputs - ) + prompt = parser.format(prompt_inputs) else: prompt = raw_prompt prompt_inputs = inputs @@ -112,9 +121,10 @@ def _get_completion_model_prompt_messages(self, prompt = Jinja2Formatter.format(prompt, prompt_inputs) if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: @@ -122,42 +132,45 @@ def _get_completion_model_prompt_messages(self, return prompt_messages - def _get_chat_model_prompt_messages(self, - prompt_template: list[ChatModelMessage], - inputs: dict, - query: Optional[str], - files: list[FileVar], - context: Optional[str], - memory_config: Optional[MemoryConfig], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity, - query_prompt_template: Optional[str] = None) -> list[PromptMessage]: + def _get_chat_model_prompt_messages( + self, + prompt_template: list[ChatModelMessage], + inputs: dict, + query: Optional[str], + files: Sequence[File], + context: Optional[str], + memory_config: Optional[MemoryConfig], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: """ Get chat model prompt messages. 
""" - raw_prompt_list = prompt_template - prompt_messages = [] - - for prompt_item in raw_prompt_list: + for prompt_item in prompt_template: raw_prompt = prompt_item.text - if prompt_item.edition_type == 'basic' or not prompt_item.edition_type: - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) - - prompt = prompt_template.format( - prompt_inputs - ) - elif prompt_item.edition_type == 'jinja2': + if prompt_item.edition_type == "basic" or not prompt_item.edition_type: + if self.with_variable_tmpl: + vp = VariablePool() + for k, v in inputs.items(): + if k.startswith("#"): + vp.add(k[1:-1].split("."), v) + raw_prompt = raw_prompt.replace("{{#context#}}", context or "") + prompt = vp.convert_template(raw_prompt).text + else: + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs = self._set_context_variable( + context=context, parser=parser, prompt_inputs=prompt_inputs + ) + prompt = parser.format(prompt_inputs) + elif prompt_item.edition_type == "jinja2": prompt = raw_prompt prompt_inputs = inputs - - prompt = Jinja2Formatter.format(prompt, prompt_inputs) + prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs) else: - raise ValueError(f'Invalid edition type: {prompt_item.edition_type}') + raise ValueError(f"Invalid edition type: {prompt_item.edition_type}") if prompt_item.role == PromptMessageRole.USER: prompt_messages.append(UserPromptMessage(content=prompt)) @@ -166,28 +179,25 @@ def _get_chat_model_prompt_messages(self, elif prompt_item.role == PromptMessageRole.ASSISTANT: prompt_messages.append(AssistantPromptMessage(content=prompt)) - if query and query_prompt_template: - prompt_template = PromptTemplateParser( - template=query_prompt_template, - with_variable_tmpl=self.with_variable_tmpl + if query and memory_config and memory_config.query_prompt_template: + parser = PromptTemplateParser( + template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl ) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - prompt_inputs['#sys.query#'] = query + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + prompt_inputs["#sys.query#"] = query - prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs) + prompt_inputs = self._set_context_variable(context, parser, prompt_inputs) - query = prompt_template.format( - prompt_inputs - ) + query = parser.format(prompt_inputs) if memory and memory_config: prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) - if files: - prompt_message_contents = [TextPromptMessageContent(data=query)] + if files and query is not None: + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=query)) for file in files: - prompt_message_contents.append(file.prompt_message_content) - + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=query)) @@ -199,19 +209,19 @@ def _get_chat_model_prompt_messages(self, # get last user 
message content and add files prompt_message_contents = [TextPromptMessageContent(data=last_message.content)] for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) last_message.content = prompt_message_contents else: - prompt_message_contents = [TextPromptMessageContent(data='')] # not for query + prompt_message_contents = [TextPromptMessageContent(data="")] # not for query for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_message_contents = [TextPromptMessageContent(data=query)] for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) elif query: @@ -219,39 +229,40 @@ def _get_chat_model_prompt_messages(self, return prompt_messages - def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#context#' in prompt_template.variable_keys: + def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + if "#context#" in parser.variable_keys: if context: - prompt_inputs['#context#'] = context + prompt_inputs["#context#"] = context else: - prompt_inputs['#context#'] = '' + prompt_inputs["#context#"] = "" return prompt_inputs - def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict: - if '#query#' in prompt_template.variable_keys: + def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict: + if "#query#" in parser.variable_keys: if query: - prompt_inputs['#query#'] = query + prompt_inputs["#query#"] = query else: - prompt_inputs['#query#'] = '' + prompt_inputs["#query#"] = "" return prompt_inputs - def _set_histories_variable(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - raw_prompt: str, - role_prefix: MemoryConfig.RolePrefix, - prompt_template: PromptTemplateParser, - prompt_inputs: dict, - model_config: ModelConfigWithCredentialsEntity) -> dict: - if '#histories#' in prompt_template.variable_keys: + def _set_histories_variable( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + raw_prompt: str, + role_prefix: MemoryConfig.RolePrefix, + parser: PromptTemplateParser, + prompt_inputs: dict, + model_config: ModelConfigWithCredentialsEntity, + ) -> dict: + if "#histories#" in parser.variable_keys: if memory: - inputs = {'#histories#': '', **prompt_inputs} - prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - tmp_human_message = UserPromptMessage( - content=prompt_template.format(prompt_inputs) - ) + inputs = {"#histories#": "", **prompt_inputs} + parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) + prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} + tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs)) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) @@ -260,10 +271,10 @@ def _set_histories_variable(self, memory: 
TokenBufferMemory, memory_config=memory_config, max_token_limit=rest_tokens, human_prefix=role_prefix.user, - ai_prefix=role_prefix.assistant + ai_prefix=role_prefix.assistant, ) - prompt_inputs['#histories#'] = histories + prompt_inputs["#histories#"] = histories else: - prompt_inputs['#histories#'] = '' + prompt_inputs["#histories#"] = "" return prompt_inputs diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index af0075ea9154fc..caa1793ea8c039 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -17,12 +17,14 @@ class AgentHistoryPromptTransform(PromptTransform): """ History Prompt Transform for Agent App """ - def __init__(self, - model_config: ModelConfigWithCredentialsEntity, - prompt_messages: list[PromptMessage], - history_messages: list[PromptMessage], - memory: Optional[TokenBufferMemory] = None, - ): + + def __init__( + self, + model_config: ModelConfigWithCredentialsEntity, + prompt_messages: list[PromptMessage], + history_messages: list[PromptMessage], + memory: Optional[TokenBufferMemory] = None, + ): self.model_config = model_config self.prompt_messages = prompt_messages self.history_messages = history_messages @@ -45,9 +47,7 @@ def get_prompt(self) -> list[PromptMessage]: model_type_instance = cast(LargeLanguageModel, model_type_instance) curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - self.history_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages ) if curr_message_tokens <= max_token_limit: return self.history_messages @@ -63,9 +63,7 @@ def get_prompt(self) -> list[PromptMessage]: # a message is start with UserPromptMessage if isinstance(prompt_message, UserPromptMessage): curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, - self.memory.model_instance.credentials, - prompt_messages + self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages ) # if current message token is overflow, drop all the prompts in current message and break if curr_message_tokens > max_token_limit: diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 61df69163cba1c..c8e7b414dffe85 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -9,27 +9,31 @@ class ChatModelMessage(BaseModel): """ Chat Message. """ + text: str role: PromptMessageRole - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class CompletionModelPromptTemplate(BaseModel): """ Completion Model Prompt Template. """ + text: str - edition_type: Optional[Literal['basic', 'jinja2']] = None + edition_type: Optional[Literal["basic", "jinja2"]] = None class MemoryConfig(BaseModel): """ Memory Config. """ + class RolePrefix(BaseModel): """ Role Prefix. """ + user: str assistant: str @@ -37,6 +41,7 @@ class WindowConfig(BaseModel): """ Window Config. 
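When enabled with a positive size, the window caps how many history messages are pulled from memory.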
""" + enabled: bool size: Optional[int] = None diff --git a/api/core/prompt/prompt_templates/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py index da40534d99485b..0ab7f526cc4fee 100644 --- a/api/core/prompt/prompt_templates/advanced_prompt_templates.py +++ b/api/core/prompt/prompt_templates/advanced_prompt_templates.py @@ -1,83 +1,45 @@ -CONTEXT = "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{#context#}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n" +CONTEXT = "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{#context#}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n" # noqa: E501 -BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n" +BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n" # noqa: E501 CHAT_APP_COMPLETION_PROMPT_CONFIG = { "completion_prompt_config": { "prompt": { - "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " + "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501 }, - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - } + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, }, - "stop": ["Human:"] + "stop": ["Human:"], } -CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } -} +CHAT_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "system", "text": "{{#pre_prompt#}}"}]}} -COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } -} +COMPLETION_APP_CHAT_PROMPT_CONFIG = {"chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]}} COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["Human:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["Human:"], } BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { "completion_prompt_config": { "prompt": { - "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" + "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" # noqa: E501 }, - "conversation_histories_role": { - "user_prefix": "用户", - "assistant_prefix": "助手" - } + "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"}, }, - "stop": ["用户:"] + "stop": ["用户:"], } -BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "system", - "text": "{{#pre_prompt#}}" - }] - } +BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": {"prompt": 
[{"role": "system", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = { - "chat_prompt_config": { - "prompt": [{ - "role": "user", - "text": "{{#pre_prompt#}}" - }] - } + "chat_prompt_config": {"prompt": [{"role": "user", "text": "{{#pre_prompt#}}"}]} } BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { - "completion_prompt_config": { - "prompt": { - "text": "{{#pre_prompt#}}" - } - }, - "stop": ["用户:"] + "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, + "stop": ["用户:"], } diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index b86d3fa815d2ce..87acdb3c49cc01 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -9,75 +9,78 @@ class PromptTransform: - def _append_chat_histories(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]: + def _append_chat_histories( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + prompt_messages: list[PromptMessage], + model_config: ModelConfigWithCredentialsEntity, + ) -> list[PromptMessage]: rest_tokens = self._calculate_rest_token(prompt_messages, model_config) histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages - def _calculate_rest_token(self, prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity) -> int: + def _calculate_rest_token( + self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + ) -> int: rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template) + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_history_messages_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int, - human_prefix: Optional[str] = None, - ai_prefix: Optional[str] = None) -> str: + def _get_history_messages_from_memory( + self, + memory: TokenBufferMemory, + memory_config: MemoryConfig, + max_token_limit: int, + human_prefix: Optional[str] = None, + ai_prefix: Optional[str] = None, + ) -> str: """Get memory messages.""" - kwargs = { - "max_token_limit": max_token_limit - } + kwargs = {"max_token_limit": max_token_limit} if human_prefix: - kwargs['human_prefix'] = human_prefix + 
kwargs["human_prefix"] = human_prefix if ai_prefix: - kwargs['ai_prefix'] = ai_prefix + kwargs["ai_prefix"] = ai_prefix if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0: - kwargs['message_limit'] = memory_config.window.size + kwargs["message_limit"] = memory_config.window.size - return memory.get_history_prompt_text( - **kwargs - ) + return memory.get_history_prompt_text(**kwargs) - def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - memory_config: MemoryConfig, - max_token_limit: int) -> list[PromptMessage]: + def _get_history_messages_list_from_memory( + self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int + ) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( max_token_limit=max_token_limit, message_limit=memory_config.window.size - if (memory_config.window.enabled - and memory_config.window.size is not None - and memory_config.window.size > 0) - else None + if ( + memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0 + ) + else None, ) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index fd7ed0181be2f2..5a3481b96388e4 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -5,9 +5,11 @@ from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.file import file_manager from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( PromptMessage, + PromptMessageContent, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, @@ -18,15 +20,15 @@ from models.model import AppMode if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File -class ModelMode(enum.Enum): - COMPLETION = 'completion' - CHAT = 'chat' +class ModelMode(str, enum.Enum): + COMPLETION = "completion" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'ModelMode': + def value_of(cls, value: str) -> "ModelMode": """ Get value of given mode. @@ -36,7 +38,7 @@ def value_of(cls, value: str) -> 'ModelMode': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") prompt_file_contents = {} @@ -47,16 +49,17 @@ class SimplePromptTransform(PromptTransform): Simple Prompt Transform for Chatbot App Basic Mode. 
""" - def get_prompt(self, - app_mode: AppMode, - prompt_template_entity: PromptTemplateEntity, - inputs: dict, - query: str, - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) -> \ - tuple[list[PromptMessage], Optional[list[str]]]: + def get_prompt( + self, + app_mode: AppMode, + prompt_template_entity: PromptTemplateEntity, + inputs: dict, + query: str, + files: list["File"], + context: Optional[str], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} model_mode = ModelMode.value_of(model_config.mode) @@ -69,7 +72,7 @@ def get_prompt(self, files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) else: prompt_messages, stops = self._get_completion_model_prompt_messages( @@ -80,19 +83,21 @@ def get_prompt(self, files=files, context=context, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages, stops - def get_prompt_str_and_rules(self, app_mode: AppMode, - model_config: ModelConfigWithCredentialsEntity, - pre_prompt: str, - inputs: dict, - query: Optional[str] = None, - context: Optional[str] = None, - histories: Optional[str] = None, - ) -> tuple[str, dict]: + def get_prompt_str_and_rules( + self, + app_mode: AppMode, + model_config: ModelConfigWithCredentialsEntity, + pre_prompt: str, + inputs: dict, + query: Optional[str] = None, + context: Optional[str] = None, + histories: Optional[str] = None, + ) -> tuple[str, dict]: # get prompt template prompt_template_config = self.get_prompt_template( app_mode=app_mode, @@ -101,74 +106,75 @@ def get_prompt_str_and_rules(self, app_mode: AppMode, pre_prompt=pre_prompt, has_context=context is not None, query_in_prompt=query is not None, - with_memory_prompt=histories is not None + with_memory_prompt=histories is not None, ) - variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs} + variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs} - for v in prompt_template_config['special_variable_keys']: + for v in prompt_template_config["special_variable_keys"]: # support #context#, #query# and #histories# - if v == '#context#': - variables['#context#'] = context if context else '' - elif v == '#query#': - variables['#query#'] = query if query else '' - elif v == '#histories#': - variables['#histories#'] = histories if histories else '' - - prompt_template = prompt_template_config['prompt_template'] + if v == "#context#": + variables["#context#"] = context or "" + elif v == "#query#": + variables["#query#"] = query or "" + elif v == "#histories#": + variables["#histories#"] = histories or "" + + prompt_template = prompt_template_config["prompt_template"] prompt = prompt_template.format(variables) - return prompt, prompt_template_config['prompt_rules'] + return prompt, prompt_template_config["prompt_rules"] - def get_prompt_template(self, app_mode: AppMode, - provider: str, - model: str, - pre_prompt: str, - has_context: bool, - query_in_prompt: bool, - with_memory_prompt: bool = False) -> dict: - prompt_rules = self._get_prompt_rule( - app_mode=app_mode, - provider=provider, - model=model - ) + def get_prompt_template( + self, + app_mode: AppMode, + provider: str, + model: str, + pre_prompt: str, + has_context: bool, + query_in_prompt: bool, + 
with_memory_prompt: bool = False, + ) -> dict: + prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model) custom_variable_keys = [] special_variable_keys = [] - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt' and has_context: - prompt += prompt_rules['context_prompt'] - special_variable_keys.append('#context#') - elif order == 'pre_prompt' and pre_prompt: - prompt += pre_prompt + '\n' + prompt = "" + for order in prompt_rules["system_prompt_orders"]: + if order == "context_prompt" and has_context: + prompt += prompt_rules["context_prompt"] + special_variable_keys.append("#context#") + elif order == "pre_prompt" and pre_prompt: + prompt += pre_prompt + "\n" pre_prompt_template = PromptTemplateParser(template=pre_prompt) custom_variable_keys = pre_prompt_template.variable_keys - elif order == 'histories_prompt' and with_memory_prompt: - prompt += prompt_rules['histories_prompt'] - special_variable_keys.append('#histories#') + elif order == "histories_prompt" and with_memory_prompt: + prompt += prompt_rules["histories_prompt"] + special_variable_keys.append("#histories#") if query_in_prompt: - prompt += prompt_rules.get('query_prompt', '{{#query#}}') - special_variable_keys.append('#query#') + prompt += prompt_rules.get("query_prompt", "{{#query#}}") + special_variable_keys.append("#query#") return { "prompt_template": PromptTemplateParser(template=prompt), "custom_variable_keys": custom_variable_keys, "special_variable_keys": special_variable_keys, - "prompt_rules": prompt_rules + "prompt_rules": prompt_rules, } - def _get_chat_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_chat_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["File"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: prompt_messages = [] # get prompt @@ -178,7 +184,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, pre_prompt=pre_prompt, inputs=inputs, query=None, - context=context + context=context, ) if prompt and query: @@ -193,7 +199,7 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, ) ), prompt_messages=prompt_messages, - model_config=model_config + model_config=model_config, ) if query: @@ -203,15 +209,17 @@ def _get_chat_model_prompt_messages(self, app_mode: AppMode, return prompt_messages, None - def _get_completion_model_prompt_messages(self, app_mode: AppMode, - pre_prompt: str, - inputs: dict, - query: str, - context: Optional[str], - files: list["FileVar"], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: + def _get_completion_model_prompt_messages( + self, + app_mode: AppMode, + pre_prompt: str, + inputs: dict, + query: str, + context: Optional[str], + files: list["File"], + memory: Optional[TokenBufferMemory], + model_config: ModelConfigWithCredentialsEntity, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: # get prompt prompt, prompt_rules = self.get_prompt_str_and_rules( app_mode=app_mode, @@ -219,13 +227,11 @@ def _get_completion_model_prompt_messages(self, 
app_mode: AppMode, pre_prompt=pre_prompt, inputs=inputs, query=query, - context=context + context=context, ) if memory: - tmp_human_message = UserPromptMessage( - content=prompt - ) + tmp_human_message = UserPromptMessage(content=prompt) rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) histories = self._get_history_messages_from_memory( @@ -236,8 +242,8 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, ) ), max_token_limit=rest_tokens, - human_prefix=prompt_rules.get('human_prefix', 'Human'), - ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant') + human_prefix=prompt_rules.get("human_prefix", "Human"), + ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"), ) # get prompt @@ -248,20 +254,21 @@ def _get_completion_model_prompt_messages(self, app_mode: AppMode, inputs=inputs, query=query, context=context, - histories=histories + histories=histories, ) - stops = prompt_rules.get('stops') + stops = prompt_rules.get("stops") if stops is not None and len(stops) == 0: stops = None return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage: if files: - prompt_message_contents = [TextPromptMessageContent(data=prompt)] + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=prompt)) for file in files: - prompt_message_contents.append(file.prompt_message_content) + prompt_message_contents.append(file_manager.to_prompt_message_content(file)) prompt_message = UserPromptMessage(content=prompt_message_contents) else: @@ -277,22 +284,18 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict :param model: model name :return: """ - prompt_file_name = self._prompt_file_name( - app_mode=app_mode, - provider=provider, - model=model - ) + prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model) # Check if the prompt file is already loaded if prompt_file_name in prompt_file_contents: return prompt_file_contents[prompt_file_name] # Get the absolute path of the subdirectory - prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates') - json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json') + prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates") + json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json") # Open the JSON file and read its content - with open(json_file_path, encoding='utf-8') as json_file: + with open(json_file_path, encoding="utf-8") as json_file: content = json.load(json_file) # Store the content of the prompt file @@ -303,21 +306,21 @@ def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str: # baichuan is_baichuan = False - if provider == 'baichuan': + if provider == "baichuan": is_baichuan = True else: baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"] - if provider in baichuan_supported_providers and 'baichuan' in model.lower(): + if provider in baichuan_supported_providers and "baichuan" in model.lower(): is_baichuan = True if is_baichuan: if app_mode == AppMode.COMPLETION: - return 'baichuan_completion' + return "baichuan_completion" else: - return 'baichuan_chat' + return "baichuan_chat" # common if 
app_mode == AppMode.COMPLETION: - return 'common_completion' + return "common_completion" else: - return 'common_chat' + return "common_chat" diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py new file mode 100644 index 00000000000000..f7aef76c87edc8 --- /dev/null +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -0,0 +1,24 @@ +from typing import Any + +from constants import UUID_NIL + + +def extract_thread_messages(messages: list[Any]): + thread_messages = [] + next_message = None + + for message in messages: + if not message.parent_message_id: + # If the message is regenerated and does not have a parent message, it is the start of a new thread + thread_messages.append(message) + break + + if not next_message: + thread_messages.append(message) + next_message = message.parent_message_id + else: + if next_message in {message.id, UUID_NIL}: + thread_messages.append(message) + next_message = message.parent_message_id + + return thread_messages diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index befdceeda505fb..5eec5e3c99a00f 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,8 @@ from typing import cast -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -21,59 +22,66 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[ :return: """ prompts = [] - if model_mode == ModelMode.CHAT.value: + if model_mode == ModelMode.CHAT: tool_calls = [] for prompt_message in prompt_messages: if prompt_message.role == PromptMessageRole.USER: - role = 'user' + role = "user" elif prompt_message.role == PromptMessageRole.ASSISTANT: - role = 'assistant' + role = "assistant" if isinstance(prompt_message, AssistantPromptMessage): - tool_calls = [{ - 'id': tool_call.id, - 'type': 'function', - 'function': { - 'name': tool_call.function.name, - 'arguments': tool_call.function.arguments, + tool_calls = [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, } - } for tool_call in prompt_message.tool_calls] + for tool_call in prompt_message.tool_calls + ] elif prompt_message.role == PromptMessageRole.SYSTEM: - role = 'system' + role = "system" elif prompt_message.role == PromptMessageRole.TOOL: - role = 'tool' + role = "tool" else: continue - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: - if content.type == PromptMessageContentType.TEXT: - content = cast(TextPromptMessageContent, content) + if isinstance(content, TextPromptMessageContent): text += content.data - else: - content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + elif isinstance(content, ImagePromptMessageContent): + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) + elif isinstance(content, AudioPromptMessageContent): + files.append( + { + "type": "audio", + "data": content.data[:10] + "...[TRUNCATED]..." 
+ content.data[-10:], + "format": content.format, + } + ) else: text = prompt_message.content - prompt = { - "role": role, - "text": text, - "files": files - } - + prompt = {"role": role, "text": text, "files": files} + if tool_calls: - prompt['tool_calls'] = tool_calls + prompt["tool_calls"] = tool_calls prompts.append(prompt) else: prompt_message = prompt_messages[0] - text = '' + text = "" files = [] if isinstance(prompt_message.content, list): for content in prompt_message.content: @@ -82,21 +90,23 @@ def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[ text += content.data else: content = cast(ImagePromptMessageContent, content) - files.append({ - "type": 'image', - "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], - "detail": content.detail.value - }) + files.append( + { + "type": "image", + "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:], + "detail": content.detail.value, + } + ) else: text = prompt_message.content params = { - "role": 'user', + "role": "user", "text": text, } if files: - params['files'] = files + params["files"] = files prompts.append(params) diff --git a/api/core/prompt/utils/prompt_template_parser.py b/api/core/prompt/utils/prompt_template_parser.py index 3e68492df2f2d7..0fd08c5d3c1a3e 100644 --- a/api/core/prompt/utils/prompt_template_parser.py +++ b/api/core/prompt/utils/prompt_template_parser.py @@ -33,13 +33,13 @@ def replacer(match): key = match.group(1) value = inputs.get(key, match.group(0)) # return original matched string if key not found - if remove_template_variables: + if remove_template_variables and isinstance(value, str): return PromptTemplateParser.remove_template_variables(value, self.with_variable_tmpl) return value prompt = re.sub(self.regex, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str, with_variable_tmpl: bool = False): - return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r'{\1}', text) + return re.sub(WITH_VARIABLE_TMPL_REGEX if with_variable_tmpl else REGEX, r"{\1}", text) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 65a4cada88f15a..3a1fe300dfd311 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -5,6 +5,7 @@ from sqlalchemy.exc import IntegrityError +from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle from core.entities.provider_entities import ( @@ -18,12 +19,9 @@ ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.helper.position_helper import is_filtered from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( - CredentialFormSchema, - FormType, - ProviderEntity, -) +from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db @@ -45,6 +43,7 @@ class ProviderManager: """ ProviderManager is a class that manages model providers, including hosting and customized model providers. 
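Per-workspace configurations are assembled from provider records, model settings and load balancing configs.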
""" + def __init__(self) -> None: self.decoding_rsa_key = None self.decoding_cipher_rsa = None @@ -91,8 +90,7 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: # Initialize trial provider records if not exist provider_name_to_provider_records_dict = self._init_trial_provider_records( - tenant_id, - provider_name_to_provider_records_dict + tenant_id, provider_name_to_provider_records_dict ) # Get all provider model records of the workspace @@ -108,33 +106,34 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) # Get All load balancing configs - provider_name_to_provider_load_balancing_model_configs_dict \ - = self._get_all_provider_load_balancing_configs(tenant_id) - - provider_configurations = ProviderConfigurations( - tenant_id=tenant_id + provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( + tenant_id ) + provider_configurations = ProviderConfigurations(tenant_id=tenant_id) + # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, + exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, + data=provider_entity, + name_func=lambda x: x.provider, + ): + continue + provider_name = provider_entity.provider provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) # Convert to custom configuration custom_configuration = self._to_custom_configuration( - tenant_id, - provider_entity, - provider_records, - provider_model_records + tenant_id, provider_entity, provider_records, provider_model_records ) # Convert to system configuration - system_configuration = self._to_system_configuration( - tenant_id, - provider_entity, - provider_records - ) + system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) # Get preferred provider type preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) @@ -164,14 +163,15 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) # Get provider load balancing configs - provider_load_balancing_configs \ - = provider_name_to_provider_load_balancing_model_configs_dict.get(provider_name) + provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get( + provider_name + ) # Convert to model settings model_settings = self._to_model_settings( provider_entity=provider_entity, provider_model_settings=provider_model_settings, - load_balancing_model_configs=provider_load_balancing_configs + load_balancing_model_configs=provider_load_balancing_configs, ) provider_configuration = ProviderConfiguration( @@ -181,7 +181,7 @@ def get_configurations(self, tenant_id: str) -> ProviderConfigurations: using_provider_type=using_provider_type, system_configuration=system_configuration, custom_configuration=custom_configuration, - model_settings=model_settings + model_settings=model_settings, ) provider_configurations[provider_name] = provider_configuration @@ -210,7 +210,7 @@ def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: M return 
ProviderModelBundle( configuration=provider_configuration, provider_instance=provider_instance, - model_type_instance=model_type_instance + model_type_instance=model_type_instance, ) def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: @@ -222,11 +222,14 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D :return: """ # Get the corresponding TenantDefaultModel record - default_model = db.session.query(TenantDefaultModel) \ + default_model = ( + db.session.query(TenantDefaultModel) .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # If it does not exist, get the first available provider model from get_configurations # and update the TenantDefaultModel record @@ -235,20 +238,18 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D provider_configurations = self.get_configurations(tenant_id) # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) if available_models: - available_model = next((model for model in available_models if model.model == "gpt-4"), - available_models[0]) + available_model = next( + (model for model in available_models if model.model == "gpt-4"), available_models[0] + ) default_model = TenantDefaultModel( tenant_id=tenant_id, model_type=model_type.to_origin_model_type(), provider_name=available_model.provider.provider, - model_name=available_model.model + model_name=available_model.model, ) db.session.add(default_model) db.session.commit() @@ -267,12 +268,28 @@ def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[D label=provider_schema.label, icon_small=provider_schema.icon_small, icon_large=provider_schema.icon_large, - supported_model_types=provider_schema.supported_model_types - ) + supported_model_types=provider_schema.supported_model_types, + ), ) - def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \ - -> TenantDefaultModel: + def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: + """ + Get the names of the first model and its provider + + :param tenant_id: workspace id + :param model_type: model type + :return: provider name, model name + """ + provider_configurations = self.get_configurations(tenant_id) + + # get available models from provider_configurations + all_models = provider_configurations.get_models(model_type=model_type, only_active=False) + + return all_models[0].provider.provider, all_models[0].model + + def update_default_model_record( + self, tenant_id: str, model_type: ModelType, provider: str, model: str + ) -> TenantDefaultModel: """ Update default model record. 
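[editor's note] The hunk above adds ProviderManager.get_first_provider_first_model, which resolves the first model of a given type together with its provider name, whether or not the model is active. A minimal usage sketch, not part of the diff: the tenant id below is a hypothetical placeholder, and the call assumes the workspace has at least one model of the requested type, since an empty result would raise IndexError on all_models[0].

    from core.model_runtime.entities.model_entities import ModelType
    from core.provider_manager import ProviderManager

    provider_manager = ProviderManager()
    # The helper queries with only_active=False, so disabled models are considered too.
    provider_name, model_name = provider_manager.get_first_provider_first_model(
        tenant_id="example-tenant-id",  # hypothetical placeholder
        model_type=ModelType.LLM,
    )
    print(provider_name, model_name)
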
@@ -287,10 +304,7 @@ def update_default_model_record(self, tenant_id: str, model_type: ModelType, pro raise ValueError(f"Provider {provider} does not exist.") # get available models from provider_configurations - available_models = provider_configurations.get_models( - model_type=model_type, - only_active=True - ) + available_models = provider_configurations.get_models(model_type=model_type, only_active=True) # check if the model exists in the available models model_names = [model.model for model in available_models] @@ -298,11 +312,14 @@ def update_default_model_record(self, tenant_id: str, model_type: ModelType, pro raise ValueError(f"Model {model} does not exist.") # Get the list of available models from get_configurations and check if it is LLM - default_model = db.session.query(TenantDefaultModel) \ .filter( - TenantDefaultModel.tenant_id == tenant_id, - TenantDefaultModel.model_type == model_type.to_origin_model_type() - ).first() + default_model = ( + db.session.query(TenantDefaultModel) .filter( + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.to_origin_model_type(), + ) + .first() + ) # create or update TenantDefaultModel record if default_model: @@ -323,18 +340,15 @@ def update_default_model_record(self, tenant_id: str, model_type: ModelType, pro return default_model - def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]: + @staticmethod + def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: """ Get all provider records of the workspace. :param tenant_id: workspace id :return: """ - providers = db.session.query(Provider) \ .filter( - Provider.tenant_id == tenant_id, - Provider.is_valid == True - ).all() + providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() provider_name_to_provider_records_dict = defaultdict(list) for provider in providers: @@ -342,7 +356,8 @@ def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]: return provider_name_to_provider_records_dict - def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]: + @staticmethod + def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: """ Get all provider model records of the workspace. @@ -350,11 +365,11 @@ def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderMod :return: """ # Get all provider model records of the workspace - provider_models = db.session.query(ProviderModel) \ .filter( - ProviderModel.tenant_id == tenant_id, - ProviderModel.is_valid == True - ).all() + provider_models = ( + db.session.query(ProviderModel) + .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + .all() + ) provider_name_to_provider_model_records_dict = defaultdict(list) for provider_model in provider_models: @@ -362,17 +377,19 @@ def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderMod return provider_name_to_provider_model_records_dict - def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]: + @staticmethod + def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]: """ Get All preferred provider types of the workspace. 
:param tenant_id: workspace id :return: """ - preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ - .filter( - TenantPreferredModelProvider.tenant_id == tenant_id - ).all() + preferred_provider_types = ( + db.session.query(TenantPreferredModelProvider) + .filter(TenantPreferredModelProvider.tenant_id == tenant_id) + .all() + ) provider_name_to_preferred_provider_type_records_dict = { preferred_provider_type.provider_name: preferred_provider_type @@ -381,26 +398,30 @@ def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, Tenant return provider_name_to_preferred_provider_type_records_dict - def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]: + @staticmethod + def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]: """ Get All provider model settings of the workspace. :param tenant_id: workspace id :return: """ - provider_model_settings = db.session.query(ProviderModelSetting) \ - .filter( - ProviderModelSetting.tenant_id == tenant_id - ).all() + provider_model_settings = ( + db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() + ) provider_name_to_provider_model_settings_dict = defaultdict(list) for provider_model_setting in provider_model_settings: - (provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name] - .append(provider_model_setting)) + ( + provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( + provider_model_setting + ) + ) return provider_name_to_provider_model_settings_dict - def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: + @staticmethod + def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: """ Get All provider load balancing configs of the workspace. 
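[editor's note] Most of the _get_all_* helpers in this file (now converted to @staticmethod) share one query-then-group shape that the reformat makes easier to read: fetch every record for the tenant, then bucket the records by provider name. A minimal sketch of that pattern, not from the diff; group_by_provider and its records argument are hypothetical names:

    from collections import defaultdict

    def group_by_provider(records) -> dict[str, list]:
        # Bucket ORM records by their provider_name attribute, mirroring the
        # defaultdict(list) loops used throughout provider_manager.py.
        grouped: dict[str, list] = defaultdict(list)
        for record in records:
            grouped[record.provider_name].append(record)
        return grouped
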
@@ -413,26 +434,30 @@ def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled redis_client.setex(cache_key, 120, str(model_load_balancing_enabled)) else: - cache_result = cache_result.decode('utf-8') - model_load_balancing_enabled = cache_result == 'True' + cache_result = cache_result.decode("utf-8") + model_load_balancing_enabled = cache_result == "True" if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ - .filter( - LoadBalancingModelConfig.tenant_id == tenant_id - ).all() + provider_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() + ) provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) for provider_load_balancing_config in provider_load_balancing_configs: - (provider_name_to_provider_load_balancing_model_configs_dict[provider_load_balancing_config.provider_name] - .append(provider_load_balancing_config)) + ( + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) + ) return provider_name_to_provider_load_balancing_model_configs_dict - def _init_trial_provider_records(self, tenant_id: str, - provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]: + @staticmethod + def _init_trial_provider_records( + tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] + ) -> dict[str, list]: """ Initialize trial provider records if not exists. @@ -456,8 +481,9 @@ def _init_trial_provider_records(self, tenant_id: str, if provider_record.provider_type != ProviderType.SYSTEM.value: continue - provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) for quota in configuration.quotas: if quota.quota_type == ProviderQuotaType.TRIAL: @@ -471,19 +497,22 @@ def _init_trial_provider_records(self, tenant_id: str, quota_type=ProviderQuotaType.TRIAL.value, quota_limit=quota.quota_limit, quota_used=0, - is_valid=True + is_valid=True, ) db.session.add(provider_record) db.session.commit() except IntegrityError: db.session.rollback() - provider_record = db.session.query(Provider) \ + provider_record = ( + db.session.query(Provider) .filter( - Provider.tenant_id == tenant_id, - Provider.provider_name == provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == ProviderQuotaType.TRIAL.value - ).first() + Provider.tenant_id == tenant_id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == ProviderQuotaType.TRIAL.value, + ) + .first() + ) if provider_record and not provider_record.is_valid: provider_record.is_valid = True @@ -493,11 +522,13 @@ def _init_trial_provider_records(self, tenant_id: str, return provider_name_to_provider_records_dict - def _to_custom_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider], - provider_model_records: list[ProviderModel]) -> CustomConfiguration: + def _to_custom_configuration( + self, + tenant_id: str, + provider_entity: ProviderEntity, + provider_records: list[Provider], + provider_model_records: 
list[ProviderModel], + ) -> CustomConfiguration: """ Convert to custom configuration. @@ -510,7 +541,8 @@ def _to_custom_configuration(self, # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get custom provider record @@ -530,7 +562,7 @@ def _to_custom_configuration(self, provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=custom_provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -539,11 +571,11 @@ def _to_custom_configuration(self, if not cached_provider_credentials: try: # fix origin data - if (custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{")): - provider_credentials = { - "openai_api_key": custom_provider_record.encrypted_config - } + if ( + custom_provider_record.encrypted_config + and not custom_provider_record.encrypted_config.startswith("{") + ): + provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) except JSONDecodeError: @@ -557,28 +589,23 @@ def _to_custom_configuration(self, if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass # cache provider credentials - provider_credentials_cache.set( - credentials=provider_credentials - ) + provider_credentials_cache.set(credentials=provider_credentials) else: provider_credentials = cached_provider_credentials - custom_provider_configuration = CustomProviderConfiguration( - credentials=provider_credentials - ) + custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) # Get custom provider model credentials @@ -588,9 +615,7 @@ def _to_custom_configuration(self, continue provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=provider_model_record.id, - cache_type=ProviderCredentialsCacheType.MODEL + tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL ) # Get cached provider model credentials @@ -612,15 +637,13 @@ def _to_custom_configuration(self, provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials @@ -628,19 +651,15 @@ def _to_custom_configuration(self, CustomModelConfiguration( 
model=provider_model_record.model_name, model_type=ModelType.value_of(provider_model_record.model_type), - credentials=provider_model_credentials + credentials=provider_model_credentials, ) ) - return CustomConfiguration( - provider=custom_provider_configuration, - models=custom_model_configurations - ) + return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) - def _to_system_configuration(self, - tenant_id: str, - provider_entity: ProviderEntity, - provider_records: list[Provider]) -> SystemConfiguration: + def _to_system_configuration( + self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] + ) -> SystemConfiguration: """ Convert to system configuration. @@ -652,11 +671,11 @@ def _to_system_configuration(self, # Get hosting configuration hosting_configuration = ext_hosting_provider.hosting_configuration - if provider_entity.provider not in hosting_configuration.provider_map \ - or not hosting_configuration.provider_map.get(provider_entity.provider).enabled: - return SystemConfiguration( - enabled=False - ) + if ( + provider_entity.provider not in hosting_configuration.provider_map + or not hosting_configuration.provider_map.get(provider_entity.provider).enabled + ): + return SystemConfiguration(enabled=False) provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) @@ -666,8 +685,9 @@ def _to_system_configuration(self, if provider_record.provider_type != ProviderType.SYSTEM.value: continue - quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] \ - = provider_record + quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( + provider_record + ) quota_configurations = [] for provider_quota in provider_hosting_configuration.quotas: @@ -679,7 +699,7 @@ def _to_system_configuration(self, quota_used=0, quota_limit=0, is_valid=False, - restrict_models=provider_quota.restrict_models + restrict_models=provider_quota.restrict_models, ) else: continue @@ -691,16 +711,15 @@ def _to_system_configuration(self, quota_unit=provider_hosting_configuration.quota_unit, quota_used=provider_record.quota_used, quota_limit=provider_record.quota_limit, - is_valid=provider_record.quota_limit > provider_record.quota_used or provider_record.quota_limit == -1, - restrict_models=provider_quota.restrict_models + is_valid=provider_record.quota_limit > provider_record.quota_used + or provider_record.quota_limit == -1, + restrict_models=provider_quota.restrict_models, ) quota_configurations.append(quota_configuration) if len(quota_configurations) == 0: - return SystemConfiguration( - enabled=False - ) + return SystemConfiguration(enabled=False) current_quota_type = self._choice_current_using_quota_type(quota_configurations) @@ -712,7 +731,7 @@ def _to_system_configuration(self, provider_credentials_cache = ProviderCredentialsCache( tenant_id=tenant_id, identity_id=provider_record.id, - cache_type=ProviderCredentialsCacheType.PROVIDER + cache_type=ProviderCredentialsCacheType.PROVIDER, ) # Get cached provider credentials @@ -727,7 +746,8 @@ def _to_system_configuration(self, # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + if provider_entity.provider_credential_schema + else [] ) # Get decoding rsa key and cipher for decrypting 
credentials @@ -738,9 +758,7 @@ def _to_system_configuration(self, if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - self.decoding_rsa_key, - self.decoding_cipher_rsa + provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa ) except ValueError: pass @@ -748,9 +766,7 @@ def _to_system_configuration(self, current_using_credentials = provider_credentials # cache provider credentials - provider_credentials_cache.set( - credentials=current_using_credentials - ) + provider_credentials_cache.set(credentials=current_using_credentials) else: current_using_credentials = cached_provider_credentials else: @@ -761,10 +777,11 @@ def _to_system_configuration(self, enabled=True, current_quota_type=current_quota_type, quota_configurations=quota_configurations, - credentials=current_using_credentials + credentials=current_using_credentials, ) - def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: + @staticmethod + def _choice_current_using_quota_type(quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: """ Choose the current quota type in use: paid quotas > provider free quotas > hosting trial quotas @@ -775,8 +792,7 @@ def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfi """ # convert to dict quota_type_to_quota_configuration_dict = { - quota_configuration.quota_type: quota_configuration - for quota_configuration in quota_configurations + quota_configuration.quota_type: quota_configuration for quota_configuration in quota_configurations } last_quota_configuration = None @@ -789,9 +805,10 @@ def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfi if last_quota_configuration: return last_quota_configuration.quota_type - raise ValueError('No quota type available') + raise ValueError("No quota type available") - def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: + @staticmethod + def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: """ Extract secret input form variables. @@ -805,10 +822,12 @@ def _extract_secret_variables(self, credential_form_schemas: list[CredentialForm return secret_input_form_variables - def _to_model_settings(self, provider_entity: ProviderEntity, - provider_model_settings: Optional[list[ProviderModelSetting]] = None, - load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None) \ - -> list[ModelSettings]: + def _to_model_settings( + self, + provider_entity: ProviderEntity, + provider_model_settings: Optional[list[ProviderModelSetting]] = None, + load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None, + ) -> list[ModelSettings]: """ Convert to model settings. 
:param provider_entity: provider entity @@ -819,7 +838,8 @@ def _to_model_settings(self, provider_entity: ProviderEntity, # Get provider model credential secret variables model_credential_secret_variables = self._extract_secret_variables( provider_entity.model_credential_schema.credential_form_schemas - if provider_entity.model_credential_schema else [] + if provider_entity.model_credential_schema + else [] ) model_settings = [] @@ -830,24 +850,28 @@ def _to_model_settings(self, provider_entity: ProviderEntity, load_balancing_configs = [] if provider_model_setting.load_balancing_enabled and load_balancing_model_configs: for load_balancing_model_config in load_balancing_model_configs: - if (load_balancing_model_config.model_name == provider_model_setting.model_name - and load_balancing_model_config.model_type == provider_model_setting.model_type): + if ( + load_balancing_model_config.model_name == provider_model_setting.model_name + and load_balancing_model_config.model_type == provider_model_setting.model_type + ): if not load_balancing_model_config.enabled: continue if not load_balancing_model_config.encrypted_config: if load_balancing_model_config.name == "__inherit__": - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials={} - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials={}, + ) + ) continue provider_model_credentials_cache = ProviderCredentialsCache( tenant_id=load_balancing_model_config.tenant_id, identity_id=load_balancing_model_config.id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL, ) # Get cached provider model credentials @@ -862,7 +886,8 @@ def _to_model_settings(self, provider_entity: ProviderEntity, # Get decoding rsa key and cipher for decrypting credentials if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding( - load_balancing_model_config.tenant_id) + load_balancing_model_config.tenant_id + ) for variable in model_credential_secret_variables: if variable in provider_model_credentials: @@ -870,30 +895,30 @@ def _to_model_settings(self, provider_entity: ProviderEntity, provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), self.decoding_rsa_key, - self.decoding_cipher_rsa + self.decoding_cipher_rsa, ) except ValueError: pass # cache provider model credentials - provider_model_credentials_cache.set( - credentials=provider_model_credentials - ) + provider_model_credentials_cache.set(credentials=provider_model_credentials) else: provider_model_credentials = cached_provider_model_credentials - load_balancing_configs.append(ModelLoadBalancingConfiguration( - id=load_balancing_model_config.id, - name=load_balancing_model_config.name, - credentials=provider_model_credentials - )) + load_balancing_configs.append( + ModelLoadBalancingConfiguration( + id=load_balancing_model_config.id, + name=load_balancing_model_config.name, + credentials=provider_model_credentials, + ) + ) model_settings.append( ModelSettings( model=provider_model_setting.model_name, model_type=ModelType.value_of(provider_model_setting.model_type), enabled=provider_model_setting.enabled, - load_balancing_configs=load_balancing_configs if 
len(load_balancing_configs) > 1 else [] + load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [], ) ) diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index eaad0e0f4c3a45..3c6ab2e4cfc56b 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -2,37 +2,35 @@ class CleanProcessor: - @classmethod def clean(cls, text: str, process_rule: dict) -> str: # default clean # remove invalid symbol - text = re.sub(r'<\|', '<', text) - text = re.sub(r'\|>', '>', text) - text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) + text = re.sub(r"<\|", "<", text) + text = re.sub(r"\|>", ">", text) + text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text) # Unicode U+FFFE - text = re.sub('\uFFFE', '', text) + text = re.sub("\ufffe", "", text) - rules = process_rule['rules'] if process_rule else None - if 'pre_processing_rules' in rules: + rules = process_rule["rules"] if process_rule else None + if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: # Remove extra spaces - pattern = r'\n{3,}' - text = re.sub(pattern, '\n\n', text) - pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' - text = re.sub(pattern, ' ', text) + pattern = r"\n{3,}" + text = re.sub(pattern, "\n\n", text) + pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" + text = re.sub(pattern, " ", text) elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: # Remove email - pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' - text = re.sub(pattern, '', text) + pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" + text = re.sub(pattern, "", text) # Remove URL - pattern = r'https?://[^\s]+' - text = re.sub(pattern, '', text) + pattern = r"https?://[^\s]+" + text = re.sub(pattern, "", text) return text def filter_string(self, text): - return text diff --git a/api/core/rag/cleaner/cleaner_base.py b/api/core/rag/cleaner/cleaner_base.py index 523bd904f272c7..d3bc2f765e9654 100644 --- a/api/core/rag/cleaner/cleaner_base.py +++ b/api/core/rag/cleaner/cleaner_base.py @@ -1,12 +1,11 @@ """Abstract interface for document cleaner implementations.""" + from abc import ABC, abstractmethod class BaseCleaner(ABC): - """Interface for clean chunk content. 
- """ + """Interface for clean chunk content.""" @abstractmethod def clean(self, content: str): raise NotImplementedError - diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py index 6a0b8c904603da..167a919e69aa31 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_extra_whitespace diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py index 6fc3a408dacbc6..9c682d29db376d 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" import re diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py index ca1ae8dfd1166c..0cdbb171e1081e 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py @@ -1,12 +1,12 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.core import clean_non_ascii_chars - # Returns "This text containsnon-ascii characters!" + # Returns "This text contains non-ascii characters!" 
return clean_non_ascii_chars(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py index 974a28fef16f3f..9f42044a2d5db8 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py @@ -1,11 +1,12 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: """Replaces unicode quote characters, such as the \x91 character in a string.""" from unstructured.cleaners.core import replace_unicode_quotes + return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py index dfaf3a27874af2..32ae7217e878a5 100644 --- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py +++ b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py @@ -1,9 +1,9 @@ """Abstract interface for document clean implementations.""" + from core.rag.cleaner.cleaner_base import BaseCleaner class UnstructuredTranslateTextCleaner(BaseCleaner): - def clean(self, content) -> str: """clean document content.""" from unstructured.cleaners.translate import translate_text diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index ad9ee4f7cfad60..992415657eced2 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,28 +1,38 @@ from typing import Optional -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.models.document import Document -from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights -from core.rag.rerank.rerank_model import RerankModelRunner -from core.rag.rerank.weight_rerank import WeightRerankRunner +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_factory import RerankRunnerFactory +from core.rag.rerank.rerank_type import RerankMode class DataPostProcessor: - """Interface for data post-processing document. 
- """ + """Interface for data post-processing document.""" - def __init__(self, tenant_id: str, reranking_mode: str, - reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - reorder_enabled: bool = False): + def __init__( + self, + tenant_id: str, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reorder_enabled: bool = False, + ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) self.reorder_runner = self._get_reorder_runner(reorder_enabled) - def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def invoke( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: if self.rerank_runner: documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user) @@ -31,36 +41,37 @@ def invoke(self, query: str, documents: list[Document], score_threshold: Optiona return documents - def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None, - weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]: + def _get_rerank_runner( + self, + reranking_mode: str, + tenant_id: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + ) -> Optional[BaseRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: - return WeightRerankRunner( - tenant_id, - Weights( + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, + tenant_id=tenant_id, + weights=Weights( vector_setting=VectorSetting( - vector_weight=weights['vector_setting']['vector_weight'], - embedding_provider_name=weights['vector_setting']['embedding_provider_name'], - embedding_model_name=weights['vector_setting']['embedding_model_name'], + vector_weight=weights["vector_setting"]["vector_weight"], + embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], + embedding_model_name=weights["vector_setting"]["embedding_model_name"], ), keyword_setting=KeywordSetting( - keyword_weight=weights['keyword_setting']['keyword_weight'], - ) - ) + keyword_weight=weights["keyword_setting"]["keyword_weight"], + ), + ), ) + return runner elif reranking_mode == RerankMode.RERANKING_MODEL.value: - if reranking_model: - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - provider=reranking_model['reranking_provider_name'], - model_type=ModelType.RERANK, - model=reranking_model['reranking_model_name'] - ) - except InvokeAuthorizationError: - return None - return RerankModelRunner(rerank_model_instance) - return None + rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model) + if rerank_model_instance is None: + return None + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, rerank_model_instance=rerank_model_instance + ) + return runner return None def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: @@ -68,4 +79,17 @@ def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: return ReorderRunner() return None - + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + if reranking_model: + try: + model_manager = 
ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model["reranking_provider_name"], + model_type=ModelType.RERANK, + model=reranking_model["reranking_model_name"], + ) + return rerank_model_instance + except InvokeAuthorizationError: + return None + return None diff --git a/api/core/rag/data_post_processor/reorder.py b/api/core/rag/data_post_processor/reorder.py index 71297588a4e7e5..a9a0885241e4fd 100644 --- a/api/core/rag/data_post_processor/reorder.py +++ b/api/core/rag/data_post_processor/reorder.py @@ -2,7 +2,6 @@ class ReorderRunner: - def run(self, documents: list[Document]) -> list[Document]: # Retrieve elements from odd indices (0, 2, 4, etc.) of the documents list odd_elements = documents[::2] diff --git a/api/core/rag/datasource/entity/embedding.py b/api/core/rag/datasource/entity/embedding.py deleted file mode 100644 index 126c1a372383bd..00000000000000 --- a/api/core/rag/datasource/entity/embedding.py +++ /dev/null @@ -1,21 +0,0 @@ -from abc import ABC, abstractmethod - - -class Embeddings(ABC): - """Interface for embedding models.""" - - @abstractmethod - def embed_documents(self, texts: list[str]) -> list[list[float]]: - """Embed search docs.""" - - @abstractmethod - def embed_query(self, text: str) -> list[float]: - """Embed query text.""" - - async def aembed_documents(self, texts: list[str]) -> list[list[float]]: - """Asynchronous Embed search docs.""" - raise NotImplementedError - - async def aembed_query(self, text: str) -> list[float]: - """Asynchronous Embed query text.""" - raise NotImplementedError diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index a3714c2fd3a38c..a0153c1e58a1a8 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -24,37 +24,42 @@ def __init__(self, dataset: Dataset): self._config = KeywordTableConfig() def create(self, texts: list[Document], **kwargs) -> BaseKeyword: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for text in texts: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) return self def add_texts(self, texts: list[Document], **kwargs): - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() - keywords_list = kwargs.get('keywords_list', None) + keywords_list = kwargs.get("keywords_list") for i in range(len(texts)): text = texts[i] if keywords_list: keywords = keywords_list[i] if not keywords: - 
keywords = keyword_table_handler.extract_keywords(text.page_content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) else: - keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk) - self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords)) - keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords)) + keywords = keyword_table_handler.extract_keywords( + text.page_content, self._config.max_keywords_per_chunk + ) + self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords)) + keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords)) self._save_dataset_keyword_table(keyword_table) @@ -63,97 +68,91 @@ def text_exists(self, id: str) -> bool: return id in set.union(*keyword_table.values()) def delete_by_ids(self, ids: list[str]) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): keyword_table = self._get_dataset_keyword_table() keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) self._save_dataset_keyword_table(keyword_table) - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> list[Document]: keyword_table = self._get_dataset_keyword_table() - k = kwargs.get('top_k', 4) + k = kwargs.get("top_k", 4) sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k) documents = [] for chunk_index in sorted_chunk_indices: - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self.dataset.id, - DocumentSegment.index_node_id == chunk_index - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index) + .first() + ) if segment: - - documents.append(Document( - page_content=segment.content, - metadata={ - "doc_id": chunk_index, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - )) + documents.append( + Document( + page_content=segment.content, + metadata={ + "doc_id": chunk_index, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + ) return documents def delete(self) -> None: - lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) + lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) with redis_client.lock(lock_name, timeout=600): dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: db.session.delete(dataset_keyword_table) db.session.commit() - if dataset_keyword_table.data_source_type != 'database': - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + if dataset_keyword_table.data_source_type != "database": + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" storage.delete(file_key) def _save_dataset_keyword_table(self, keyword_table): keyword_table_dict = { - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": keyword_table - } + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, 
"table": keyword_table}, } dataset_keyword_table = self.dataset.dataset_keyword_table keyword_data_source_type = dataset_keyword_table.data_source_type - if keyword_data_source_type == 'database': + if keyword_data_source_type == "database": dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder) db.session.commit() else: - file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt' + file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8')) + storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8")) def _get_dataset_keyword_table(self) -> Optional[dict]: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict if keyword_table_dict: - return keyword_table_dict['__data__']['table'] + return keyword_table_dict["__data__"]["table"] else: keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE dataset_keyword_table = DatasetKeywordTable( dataset_id=self.dataset.id, - keyword_table='', + keyword_table="", data_source_type=keyword_data_source_type, ) - if keyword_data_source_type == 'database': - dataset_keyword_table.keyword_table = json.dumps({ - '__type__': 'keyword_table', - '__data__': { - "index_id": self.dataset.id, - "summary": None, - "table": {} - } - }, cls=SetEncoder) + if keyword_data_source_type == "database": + dataset_keyword_table.keyword_table = json.dumps( + { + "__type__": "keyword_table", + "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, + }, + cls=SetEncoder, + ) db.session.add(dataset_keyword_table) db.session.commit() @@ -174,9 +173,7 @@ def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]) -> keywords_to_delete = set() for keyword, node_idxs in keyword_table.items(): if node_idxs_to_delete.intersection(node_idxs): - keyword_table[keyword] = node_idxs.difference( - node_idxs_to_delete - ) + keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete) if not keyword_table[keyword]: keywords_to_delete.add(keyword) @@ -202,13 +199,14 @@ def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): reverse=True, ) - return sorted_chunk_indices[: k] + return sorted_chunk_indices[:k] def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset_id, - DocumentSegment.index_node_id == node_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) + .first() + ) if document_segment: document_segment.keywords = keywords db.session.add(document_segment) @@ -224,14 +222,14 @@ def multi_create_segment_keywords(self, pre_segment_data_list: list): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for pre_segment_data in pre_segment_data_list: - segment = pre_segment_data['segment'] - if pre_segment_data['keywords']: - segment.keywords = pre_segment_data['keywords'] - keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, - pre_segment_data['keywords']) + segment = pre_segment_data["segment"] + if pre_segment_data["keywords"]: + segment.keywords = 
pre_segment_data["keywords"] + keyword_table = self._add_text_to_keyword_table( + keyword_table, segment.index_node_id, pre_segment_data["keywords"] + ) else: - keywords = keyword_table_handler.extract_keywords(segment.content, - self._config.max_keywords_per_chunk) + keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk) segment.keywords = list(keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) self._save_dataset_keyword_table(keyword_table) diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index ad669ef5150bef..4b1ade8e3fa095 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -8,7 +8,6 @@ class JiebaKeywordTableHandler: - def __init__(self): default_tfidf.stop_words = STOPWORDS @@ -30,4 +29,4 @@ def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: if len(sub_tokens) > 1: results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) - return results \ No newline at end of file + return results diff --git a/api/core/rag/datasource/keyword/jieba/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py index c616a15cf0c20f..9abe78d6ef7e8d 100644 --- a/api/core/rag/datasource/keyword/jieba/stopwords.py +++ b/api/core/rag/datasource/keyword/jieba/stopwords.py @@ -1,90 +1,1380 @@ STOPWORDS = { - "during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've", - "ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her", - "an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t", - "theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven", - "for", "you'll", "that", "we", "am", "d", "by", "having", "wasn't", "than", "weren't", "out", "from", "now", - "their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which", - "m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't", - "such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves", - "been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because", - "down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't", - "as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after", - "over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there", - "himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below", - "人民", "末##末", "啊", "阿", "哎", "哎呀", "哎哟", "唉", "俺", "俺们", "按", "按照", "吧", "吧哒", "把", "罢了", "被", "本", - "本着", "比", "比方", "比如", "鄙人", "彼", "彼此", "边", "别", "别的", "别说", "并", "并且", "不比", "不成", "不单", "不但", - "不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "朝", "朝着", - "趁", "趁着", "乘", "冲", "除", "除此之外", "除非", "除了", "此", "此间", "此外", "从", "从而", "打", "待", "但", "但是", "当", - "当着", "到", "得", "的", "的话", "等", "等等", "地", "第", "叮咚", "对", "对于", "多", "多少", "而", "而况", 
"而且", "而是", - "而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "嘎", "嘎登", "该", "赶", "个", "各", - "各个", "各位", "各种", "各自", "给", "根据", "跟", "故", "故此", "固然", "关于", "管", "归", "果然", "果真", "过", "哈", - "哈哈", "呵", "和", "何", "何处", "何况", "何时", "嘿", "哼", "哼唷", "呼哧", "乎", "哗", "还是", "还有", "换句话说", "换言之", - "或", "或是", "或者", "极了", "及", "及其", "及至", "即", "即便", "即或", "即令", "即若", "即使", "几", "几时", "己", "既", - "既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "将", "较", "较之", "叫", "接着", "结果", "借", "紧接着", - "进而", "尽", "尽管", "经", "经过", "就", "就是", "就是说", "据", "具体地说", "具体说来", "开始", "开外", "靠", "咳", "可", - "可见", "可是", "可以", "况且", "啦", "来", "来着", "离", "例如", "哩", "连", "连同", "两者", "了", "临", "另", "另外", - "另一方面", "论", "嘛", "吗", "慢说", "漫说", "冒", "么", "每", "每当", "们", "莫若", "某", "某个", "某些", "拿", "哪", - "哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "那", "那边", "那儿", "那个", "那会儿", "那里", "那么", - "那么些", "那么样", "那时", "那些", "那样", "乃", "乃至", "呢", "能", "你", "你们", "您", "宁", "宁可", "宁肯", "宁愿", "哦", - "呕", "啪达", "旁人", "呸", "凭", "凭借", "其", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "起", "起见", "岂但", - "恰恰相反", "前后", "前者", "且", "然而", "然后", "然则", "让", "人家", "任", "任何", "任凭", "如", "如此", "如果", "如何", - "如其", "如若", "如上所述", "若", "若非", "若是", "啥", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候", - "什么", "什么样", "使得", "是", "是的", "首先", "谁", "谁知", "顺", "顺着", "似的", "虽", "虽然", "虽说", "虽则", "随", "随着", - "所", "所以", "他", "他们", "他人", "它", "它们", "她", "她们", "倘", "倘或", "倘然", "倘若", "倘使", "腾", "替", "通过", "同", - "同时", "哇", "万一", "往", "望", "为", "为何", "为了", "为什么", "为着", "喂", "嗡嗡", "我", "我们", "呜", "呜呼", "乌乎", - "无论", "无宁", "毋宁", "嘻", "吓", "相对而言", "像", "向", "向着", "嘘", "呀", "焉", "沿", "沿着", "要", "要不", "要不然", - "要不是", "要么", "要是", "也", "也罢", "也好", "一", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "依", "依照", - "矣", "以", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "因", "因此", "因而", "因为", "哟", "用", "由", - "由此可见", "由于", "有", "有的", "有关", "有些", "又", "于", "于是", "于是乎", "与", "与此同时", "与否", "与其", "越是", - "云云", "哉", "再说", "再者", "在", "在下", "咱", "咱们", "则", "怎", "怎么", "怎么办", "怎么样", "怎样", "咋", "照", "照着", - "者", "这", "这边", "这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样", - "正如", "吱", "之", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "至", "至于", "诸位", "着", "着呢", "自", "自从", - "自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "纵", "纵令", - "纵然", "纵使", "遵照", "作为", "兮", "呃", "呗", "咚", "咦", "喏", "啐", "喔唷", "嗬", "嗯", "嗳", "~", "!", ".", ":", - "\"", "'", "(", ")", "*", "A", "白", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_", - "+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "(", ")", "——", "—", "¥", "·", "...", "‘", "’", "〉", "〈", "…", - " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "二", - "三", "四", "五", "六", "七", "八", "九", "零", ">", "<", "@", "#", "$", "%", "︿", "&", "*", "+", "~", "|", "[", - "]", "{", "}", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时", - "按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "半", "梆", "保管", "保险", "饱", "背地里", "背靠背", "倍感", "倍加", - "本人", "本身", "甭", "比起", "比如说", "比照", "毕竟", "必", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没", - "并没有", "并排", "并无", "勃然", "不", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭", - "不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了", - "不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要", - "不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", 
"不知不觉", "不止", - "不止一次", "不至于", "才", "才能", "策略地", "差不多", "差一点", "常", "常常", "常言道", "常言说", "常言说得好", "长此下去", - "长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心", - "乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "臭", "初", "出", "出来", "出去", - "除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "传", "传说", "传闻", "串行", "纯", "纯粹", - "此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速", - "从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "粗", "存心", "达旦", "打从", - "打开天窗说亮话", "大", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事", - "大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "带", "殆", "待到", "单", "单纯", "单单", "但愿", "弹指之间", "当场", - "当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿", - "到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "定", "动不动", "动辄", "陡然", "都", "独", - "独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等", - "二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "方", "方才", "方能", "放量", "非常", "非得", - "分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "逢", "弗", "甫", "嘎嘎", "该当", "概", "赶快", "赶早不赶晚", "敢", - "敢情", "敢于", "刚", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "更", "更加", "更进一步", "更为", - "公然", "共", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "固", "怪", "怪不得", "惯常", "光", "光是", "归根到底", - "归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须", - "何止", "很", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "互", "互相", "哗啦", "话说", "还", "恍然", "会", "豁然", - "活", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "极", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆", - "即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之", - "简直", "见", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此", - "借以", "届时", "仅", "仅仅", "谨", "进来", "进去", "近", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量", - "尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "竟", "竟然", "究竟", "就此", "就地", "就算", "居然", "局外", - "举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "绝", "绝不", "绝顶", "绝对", "绝非", - "均", "喀", "看", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "快", "快要", "来不及", "来得及", "来讲", - "来看", "拦腰", "牢牢", "老", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "历", "立", "立地", "立刻", - "立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "屡", "屡次", - "屡次三番", "屡屡", "缕缕", "率尔", "率然", "略", "略加", "略微", "略为", "论说", "马上", "蛮", "满", "没", "没有", "每逢", - "每每", "每时每刻", "猛然", "猛然间", "莫", "莫不", "莫非", "莫如", "默默地", "默然", "呐", "那末", "奈", "难道", "难得", "难怪", - "难说", "内", "年复一年", "凝神", "偶而", "偶尔", "怕", "砰", "碰巧", "譬如", "偏偏", "乒", "平素", "颇", "迫于", "扑通", - "其后", "其实", "奇", "齐", "起初", "起来", "起首", "起头", "起先", "岂", "岂非", "岂止", "迄", "恰逢", "恰好", "恰恰", "恰巧", - "恰如", "恰似", "千", "千万", "千万千万", "切", "切不可", "切莫", "切切", "切勿", "窃", "亲口", "亲身", "亲手", "亲眼", "亲自", - "顷", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "去", "权时", "全都", "全力", "全年", "全然", "全身心", "然", - "人人", "仍", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述", - "如上", "如下", "汝", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "上", "上来", "上去", "一个", "月", "日", "\n" + "during", + "when", + "but", + "then", + "further", + "isn", + "mustn't", + "until", + "own", + "i", + "couldn", + "y", + "only", + "you've", + "ours", + "who", + "where", + "ourselves", + "has", + "to", + "was", + "didn't", + "themselves", + "if", + "against", + "through", + "her", + "an", + "your", + 
"can", + "those", + "didn", + "about", + "aren't", + "shan't", + "be", + "not", + "these", + "again", + "so", + "t", + "theirs", + "weren", + "won't", + "won", + "itself", + "just", + "same", + "while", + "why", + "doesn", + "aren", + "him", + "haven", + "for", + "you'll", + "that", + "we", + "am", + "d", + "by", + "having", + "wasn't", + "than", + "weren't", + "out", + "from", + "now", + "their", + "too", + "hadn", + "o", + "needn", + "most", + "it", + "under", + "needn't", + "any", + "some", + "few", + "ll", + "hers", + "which", + "m", + "you're", + "off", + "other", + "had", + "she", + "you'd", + "do", + "you", + "does", + "s", + "will", + "each", + "wouldn't", + "hasn't", + "such", + "more", + "whom", + "she's", + "my", + "yours", + "yourself", + "of", + "on", + "very", + "hadn't", + "with", + "yourselves", + "been", + "ma", + "them", + "mightn't", + "shan", + "mustn", + "they", + "what", + "both", + "that'll", + "how", + "is", + "he", + "because", + "down", + "haven't", + "are", + "no", + "it's", + "our", + "being", + "the", + "or", + "above", + "myself", + "once", + "don't", + "doesn't", + "as", + "nor", + "here", + "herself", + "hasn", + "mightn", + "have", + "its", + "all", + "were", + "ain", + "this", + "at", + "after", + "over", + "shouldn't", + "into", + "before", + "don", + "wouldn", + "re", + "couldn't", + "wasn", + "in", + "should", + "there", + "himself", + "isn't", + "should've", + "doing", + "ve", + "shouldn", + "a", + "did", + "and", + "his", + "between", + "me", + "up", + "below", + "人民", + "末##末", + "啊", + "阿", + "哎", + "哎呀", + "哎哟", + "唉", + "俺", + "俺们", + "按", + "按照", + "吧", + "吧哒", + "把", + "罢了", + "被", + "本", + "本着", + "比", + "比方", + "比如", + "鄙人", + "彼", + "彼此", + "边", + "别", + "别的", + "别说", + "并", + "并且", + "不比", + "不成", + "不单", + "不但", + "不独", + "不管", + "不光", + "不过", + "不仅", + "不拘", + "不论", + "不怕", + "不然", + "不如", + "不特", + "不惟", + "不问", + "不只", + "朝", + "朝着", + "趁", + "趁着", + "乘", + "冲", + "除", + "除此之外", + "除非", + "除了", + "此", + "此间", + "此外", + "从", + "从而", + "打", + "待", + "但", + "但是", + "当", + "当着", + "到", + "得", + "的", + "的话", + "等", + "等等", + "地", + "第", + "叮咚", + "对", + "对于", + "多", + "多少", + "而", + "而况", + "而且", + "而是", + "而外", + "而言", + "而已", + "尔后", + "反过来", + "反过来说", + "反之", + "非但", + "非徒", + "否则", + "嘎", + "嘎登", + "该", + "赶", + "个", + "各", + "各个", + "各位", + "各种", + "各自", + "给", + "根据", + "跟", + "故", + "故此", + "固然", + "关于", + "管", + "归", + "果然", + "果真", + "过", + "哈", + "哈哈", + "呵", + "和", + "何", + "何处", + "何况", + "何时", + "嘿", + "哼", + "哼唷", + "呼哧", + "乎", + "哗", + "还是", + "还有", + "换句话说", + "换言之", + "或", + "或是", + "或者", + "极了", + "及", + "及其", + "及至", + "即", + "即便", + "即或", + "即令", + "即若", + "即使", + "几", + "几时", + "己", + "既", + "既然", + "既是", + "继而", + "加之", + "假如", + "假若", + "假使", + "鉴于", + "将", + "较", + "较之", + "叫", + "接着", + "结果", + "借", + "紧接着", + "进而", + "尽", + "尽管", + "经", + "经过", + "就", + "就是", + "就是说", + "据", + "具体地说", + "具体说来", + "开始", + "开外", + "靠", + "咳", + "可", + "可见", + "可是", + "可以", + "况且", + "啦", + "来", + "来着", + "离", + "例如", + "哩", + "连", + "连同", + "两者", + "了", + "临", + "另", + "另外", + "另一方面", + "论", + "嘛", + "吗", + "慢说", + "漫说", + "冒", + "么", + "每", + "每当", + "们", + "莫若", + "某", + "某个", + "某些", + "拿", + "哪", + "哪边", + "哪儿", + "哪个", + "哪里", + "哪年", + "哪怕", + "哪天", + "哪些", + "哪样", + "那", + "那边", + "那儿", + "那个", + "那会儿", + "那里", + "那么", + "那么些", + "那么样", + "那时", + "那些", + "那样", + "乃", + "乃至", + "呢", + "能", + "你", + "你们", + "您", + "宁", + "宁可", + "宁肯", + "宁愿", + "哦", + "呕", + "啪达", + "旁人", + "呸", + "凭", + "凭借", + "其", + "其次", + "其二", + "其他", + 
"其它", + "其一", + "其余", + "其中", + "起", + "起见", + "岂但", + "恰恰相反", + "前后", + "前者", + "且", + "然而", + "然后", + "然则", + "让", + "人家", + "任", + "任何", + "任凭", + "如", + "如此", + "如果", + "如何", + "如其", + "如若", + "如上所述", + "若", + "若非", + "若是", + "啥", + "上下", + "尚且", + "设若", + "设使", + "甚而", + "甚么", + "甚至", + "省得", + "时候", + "什么", + "什么样", + "使得", + "是", + "是的", + "首先", + "谁", + "谁知", + "顺", + "顺着", + "似的", + "虽", + "虽然", + "虽说", + "虽则", + "随", + "随着", + "所", + "所以", + "他", + "他们", + "他人", + "它", + "它们", + "她", + "她们", + "倘", + "倘或", + "倘然", + "倘若", + "倘使", + "腾", + "替", + "通过", + "同", + "同时", + "哇", + "万一", + "往", + "望", + "为", + "为何", + "为了", + "为什么", + "为着", + "喂", + "嗡嗡", + "我", + "我们", + "呜", + "呜呼", + "乌乎", + "无论", + "无宁", + "毋宁", + "嘻", + "吓", + "相对而言", + "像", + "向", + "向着", + "嘘", + "呀", + "焉", + "沿", + "沿着", + "要", + "要不", + "要不然", + "要不是", + "要么", + "要是", + "也", + "也罢", + "也好", + "一", + "一般", + "一旦", + "一方面", + "一来", + "一切", + "一样", + "一则", + "依", + "依照", + "矣", + "以", + "以便", + "以及", + "以免", + "以至", + "以至于", + "以致", + "抑或", + "因", + "因此", + "因而", + "因为", + "哟", + "用", + "由", + "由此可见", + "由于", + "有", + "有的", + "有关", + "有些", + "又", + "于", + "于是", + "于是乎", + "与", + "与此同时", + "与否", + "与其", + "越是", + "云云", + "哉", + "再说", + "再者", + "在", + "在下", + "咱", + "咱们", + "则", + "怎", + "怎么", + "怎么办", + "怎么样", + "怎样", + "咋", + "照", + "照着", + "者", + "这", + "这边", + "这儿", + "这个", + "这会儿", + "这就是说", + "这里", + "这么", + "这么点儿", + "这么些", + "这么样", + "这时", + "这些", + "这样", + "正如", + "吱", + "之", + "之类", + "之所以", + "之一", + "只是", + "只限", + "只要", + "只有", + "至", + "至于", + "诸位", + "着", + "着呢", + "自", + "自从", + "自个儿", + "自各儿", + "自己", + "自家", + "自身", + "综上所述", + "总的来看", + "总的来说", + "总的说来", + "总而言之", + "总之", + "纵", + "纵令", + "纵然", + "纵使", + "遵照", + "作为", + "兮", + "呃", + "呗", + "咚", + "咦", + "喏", + "啐", + "喔唷", + "嗬", + "嗯", + "嗳", + "~", + "!", + ".", + ":", + '"', + "'", + "(", + ")", + "*", + "A", + "白", + "社会主义", + "--", + "..", + ">>", + " [", + " ]", + "", + "<", + ">", + "/", + "\\", + "|", + "-", + "_", + "+", + "=", + "&", + "^", + "%", + "#", + "@", + "`", + ";", + "$", + "(", + ")", + "——", + "—", + "¥", + "·", + "...", + "‘", + "’", + "〉", + "〈", + "…", + " ", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "二", + "三", + "四", + "五", + "六", + "七", + "八", + "九", + "零", + ">", + "<", + "@", + "#", + "$", + "%", + "︿", + "&", + "*", + "+", + "~", + "|", + "[", + "]", + "{", + "}", + "啊哈", + "啊呀", + "啊哟", + "挨次", + "挨个", + "挨家挨户", + "挨门挨户", + "挨门逐户", + "挨着", + "按理", + "按期", + "按时", + "按说", + "暗地里", + "暗中", + "暗自", + "昂然", + "八成", + "白白", + "半", + "梆", + "保管", + "保险", + "饱", + "背地里", + "背靠背", + "倍感", + "倍加", + "本人", + "本身", + "甭", + "比起", + "比如说", + "比照", + "毕竟", + "必", + "必定", + "必将", + "必须", + "便", + "别人", + "并非", + "并肩", + "并没", + "并没有", + "并排", + "并无", + "勃然", + "不", + "不必", + "不常", + "不大", + "不但...而且", + "不得", + "不得不", + "不得了", + "不得已", + "不迭", + "不定", + "不对", + "不妨", + "不管怎样", + "不会", + "不仅...而且", + "不仅仅", + "不仅仅是", + "不经意", + "不可开交", + "不可抗拒", + "不力", + "不了", + "不料", + "不满", + "不免", + "不能不", + "不起", + "不巧", + "不然的话", + "不日", + "不少", + "不胜", + "不时", + "不是", + "不同", + "不能", + "不要", + "不外", + "不外乎", + "不下", + "不限", + "不消", + "不已", + "不亦乐乎", + "不由得", + "不再", + "不择手段", + "不怎么", + "不曾", + "不知不觉", + "不止", + "不止一次", + "不至于", + "才", + "才能", + "策略地", + "差不多", + "差一点", + "常", + "常常", + "常言道", + "常言说", + "常言说得好", + "长此下去", + "长话短说", + "长期以来", + "长线", + "敞开儿", + "彻夜", + "陈年", + "趁便", + "趁机", + "趁热", + "趁势", + "趁早", + "成年", + "成年累月", + "成心", 
+ "乘机", + "乘胜", + "乘势", + "乘隙", + "乘虚", + "诚然", + "迟早", + "充分", + "充其极", + "充其量", + "抽冷子", + "臭", + "初", + "出", + "出来", + "出去", + "除此", + "除此而外", + "除此以外", + "除开", + "除去", + "除却", + "除外", + "处处", + "川流不息", + "传", + "传说", + "传闻", + "串行", + "纯", + "纯粹", + "此后", + "此中", + "次第", + "匆匆", + "从不", + "从此", + "从此以后", + "从古到今", + "从古至今", + "从今以后", + "从宽", + "从来", + "从轻", + "从速", + "从头", + "从未", + "从无到有", + "从小", + "从新", + "从严", + "从优", + "从早到晚", + "从中", + "从重", + "凑巧", + "粗", + "存心", + "达旦", + "打从", + "打开天窗说亮话", + "大", + "大不了", + "大大", + "大抵", + "大都", + "大多", + "大凡", + "大概", + "大家", + "大举", + "大略", + "大面儿上", + "大事", + "大体", + "大体上", + "大约", + "大张旗鼓", + "大致", + "呆呆地", + "带", + "殆", + "待到", + "单", + "单纯", + "单单", + "但愿", + "弹指之间", + "当场", + "当儿", + "当即", + "当口儿", + "当然", + "当庭", + "当头", + "当下", + "当真", + "当中", + "倒不如", + "倒不如说", + "倒是", + "到处", + "到底", + "到了儿", + "到目前为止", + "到头", + "到头来", + "得起", + "得天独厚", + "的确", + "等到", + "叮当", + "顶多", + "定", + "动不动", + "动辄", + "陡然", + "都", + "独", + "独自", + "断然", + "顿时", + "多次", + "多多", + "多多少少", + "多多益善", + "多亏", + "多年来", + "多年前", + "而后", + "而论", + "而又", + "尔等", + "二话不说", + "二话没说", + "反倒", + "反倒是", + "反而", + "反手", + "反之亦然", + "反之则", + "方", + "方才", + "方能", + "放量", + "非常", + "非得", + "分期", + "分期分批", + "分头", + "奋勇", + "愤然", + "风雨无阻", + "逢", + "弗", + "甫", + "嘎嘎", + "该当", + "概", + "赶快", + "赶早不赶晚", + "敢", + "敢情", + "敢于", + "刚", + "刚才", + "刚好", + "刚巧", + "高低", + "格外", + "隔日", + "隔夜", + "个人", + "各式", + "更", + "更加", + "更进一步", + "更为", + "公然", + "共", + "共总", + "够瞧的", + "姑且", + "古来", + "故而", + "故意", + "固", + "怪", + "怪不得", + "惯常", + "光", + "光是", + "归根到底", + "归根结底", + "过于", + "毫不", + "毫无", + "毫无保留地", + "毫无例外", + "好在", + "何必", + "何尝", + "何妨", + "何苦", + "何乐而不为", + "何须", + "何止", + "很", + "很多", + "很少", + "轰然", + "后来", + "呼啦", + "忽地", + "忽然", + "互", + "互相", + "哗啦", + "话说", + "还", + "恍然", + "会", + "豁然", + "活", + "伙同", + "或多或少", + "或许", + "基本", + "基本上", + "基于", + "极", + "极大", + "极度", + "极端", + "极力", + "极其", + "极为", + "急匆匆", + "即将", + "即刻", + "即是说", + "几度", + "几番", + "几乎", + "几经", + "既...又", + "继之", + "加上", + "加以", + "间或", + "简而言之", + "简言之", + "简直", + "见", + "将才", + "将近", + "将要", + "交口", + "较比", + "较为", + "接连不断", + "接下来", + "皆可", + "截然", + "截至", + "藉以", + "借此", + "借以", + "届时", + "仅", + "仅仅", + "谨", + "进来", + "进去", + "近", + "近几年来", + "近来", + "近年来", + "尽管如此", + "尽可能", + "尽快", + "尽量", + "尽然", + "尽如人意", + "尽心竭力", + "尽心尽力", + "尽早", + "精光", + "经常", + "竟", + "竟然", + "究竟", + "就此", + "就地", + "就算", + "居然", + "局外", + "举凡", + "据称", + "据此", + "据实", + "据说", + "据我所知", + "据悉", + "具体来说", + "决不", + "决非", + "绝", + "绝不", + "绝顶", + "绝对", + "绝非", + "均", + "喀", + "看", + "看来", + "看起来", + "看上去", + "看样子", + "可好", + "可能", + "恐怕", + "快", + "快要", + "来不及", + "来得及", + "来讲", + "来看", + "拦腰", + "牢牢", + "老", + "老大", + "老老实实", + "老是", + "累次", + "累年", + "理当", + "理该", + "理应", + "历", + "立", + "立地", + "立刻", + "立马", + "立时", + "联袂", + "连连", + "连日", + "连日来", + "连声", + "连袂", + "临到", + "另方面", + "另行", + "另一个", + "路经", + "屡", + "屡次", + "屡次三番", + "屡屡", + "缕缕", + "率尔", + "率然", + "略", + "略加", + "略微", + "略为", + "论说", + "马上", + "蛮", + "满", + "没", + "没有", + "每逢", + "每每", + "每时每刻", + "猛然", + "猛然间", + "莫", + "莫不", + "莫非", + "莫如", + "默默地", + "默然", + "呐", + "那末", + "奈", + "难道", + "难得", + "难怪", + "难说", + "内", + "年复一年", + "凝神", + "偶而", + "偶尔", + "怕", + "砰", + "碰巧", + "譬如", + "偏偏", + "乒", + "平素", + "颇", + "迫于", + "扑通", + "其后", + "其实", + "奇", + "齐", + "起初", + "起来", + "起首", + "起头", + "起先", + "岂", + "岂非", + "岂止", + "迄", + "恰逢", + "恰好", + "恰恰", + "恰巧", + "恰如", + "恰似", + "千", + "千万", + "千万千万", + "切", + "切不可", + "切莫", + "切切", + "切勿", + "窃", + "亲口", + 
"亲身", + "亲手", + "亲眼", + "亲自", + "顷", + "顷刻", + "顷刻间", + "顷刻之间", + "请勿", + "穷年累月", + "取道", + "去", + "权时", + "全都", + "全力", + "全年", + "全然", + "全身心", + "然", + "人人", + "仍", + "仍旧", + "仍然", + "日复一日", + "日见", + "日渐", + "日益", + "日臻", + "如常", + "如此等等", + "如次", + "如今", + "如期", + "如前所述", + "如上", + "如下", + "汝", + "三番两次", + "三番五次", + "三天两头", + "瑟瑟", + "沙沙", + "上", + "上来", + "上去", + "一个", + "月", + "日", + "\n", } diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index b77c6562b25ce3..be00687abd5025 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -8,7 +8,6 @@ class BaseKeyword(ABC): - def __init__(self, dataset: Dataset): self.dataset = dataset @@ -28,18 +27,17 @@ def text_exists(self, id: str) -> bool: def delete_by_ids(self, ids: list[str]) -> None: raise NotImplementedError + @abstractmethod def delete(self) -> None: raise NotImplementedError - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + @abstractmethod + def search(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: - for text in texts[:]: - doc_id = text.metadata['doc_id'] + for text in texts.copy(): + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -47,4 +45,4 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index 6ac610f82b45ba..f1a6ade91f9bd1 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -1,8 +1,8 @@ from typing import Any from configs import dify_config -from core.rag.datasource.keyword.jieba.jieba import Jieba from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.datasource.keyword.keyword_type import KeyWordType from core.rag.models.document import Document from models.dataset import Dataset @@ -13,18 +13,19 @@ def __init__(self, dataset: Dataset): self._keyword_processor = self._init_keyword() def _init_keyword(self) -> BaseKeyword: - config = dify_config - keyword_type = config.KEYWORD_STORE + keyword_type = dify_config.KEYWORD_STORE + keyword_factory = self.get_keyword_factory(keyword_type) + return keyword_factory(self._dataset) - if not keyword_type: - raise ValueError("Keyword store must be specified.") + @staticmethod + def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]: + match keyword_type: + case KeyWordType.JIEBA: + from core.rag.datasource.keyword.jieba.jieba import Jieba - if keyword_type == "jieba": - return Jieba( - dataset=self._dataset - ) - else: - raise ValueError(f"Keyword store {keyword_type} is not supported.") + return Jieba + case _: + raise ValueError(f"Keyword store {keyword_type} is not supported.") def create(self, texts: list[Document], **kwargs): self._keyword_processor.create(texts, **kwargs) @@ -41,10 +42,7 @@ def delete_by_ids(self, ids: list[str]) -> None: def delete(self) -> None: self._keyword_processor.delete() - def search( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search(self, query: str, **kwargs: Any) -> 
list[Document]: return self._keyword_processor.search(query, **kwargs) def __getattr__(self, name): diff --git a/api/core/rag/datasource/keyword/keyword_type.py b/api/core/rag/datasource/keyword/keyword_type.py new file mode 100644 index 00000000000000..d6deba3fb09fdf --- /dev/null +++ b/api/core/rag/datasource/keyword/keyword_type.py @@ -0,0 +1,5 @@ +from enum import Enum + + +class KeyWordType(str, Enum): + JIEBA = "jieba" diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 3932e90042c59c..57af05861c1ad0 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,79 +6,95 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.rerank.constants.rerank_mode import RerankMode -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.rerank.rerank_type import RerankMode +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset +from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } class RetrievalService: - @classmethod - def retrieve(cls, retrival_method: str, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', - weights: Optional[dict] = None): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + def retrieve( + cls, + retrieval_method: str, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float] = 0.0, + reranking_model: Optional[dict] = None, + reranking_mode: Optional[str] = "reranking_model", + weights: Optional[dict] = None, + ): + if not query: + return [] + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + return [] + if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] all_documents = [] threads = [] exceptions = [] # retrieval_model source with keyword - if retrival_method == 'keyword_search': - keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + if retrieval_method == "keyword_search": + keyword_thread = threading.Thread( + target=RetrievalService.keyword_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(keyword_thread) keyword_thread.start() # retrieval_model source with semantic - if RetrievalMethod.is_support_semantic_search(retrival_method): - embedding_thread = 
threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'score_threshold': score_threshold, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'retrival_method': retrival_method, - 'exceptions': exceptions, - }) + if RetrievalMethod.is_support_semantic_search(retrieval_method): + embedding_thread = threading.Thread( + target=RetrievalService.embedding_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "score_threshold": score_threshold, + "reranking_model": reranking_model, + "all_documents": all_documents, + "retrieval_method": retrieval_method, + "exceptions": exceptions, + }, + ) threads.append(embedding_thread) embedding_thread.start() # retrieval source with full text - if RetrievalMethod.is_support_fulltext_search(retrival_method): - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'retrival_method': retrival_method, - 'score_threshold': score_threshold, - 'top_k': top_k, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + if RetrievalMethod.is_support_fulltext_search(retrieval_method): + full_text_index_thread = threading.Thread( + target=RetrievalService.full_text_index_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "retrieval_method": retrieval_method, + "score_threshold": score_threshold, + "top_k": top_k, + "reranking_model": reranking_model, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(full_text_index_thread) full_text_index_thread.start() @@ -86,110 +102,127 @@ def retrieve(cls, retrival_method: str, dataset_id: str, query: str, thread.join() if exceptions: - exception_message = ';\n'.join(exceptions) + exception_message = ";\n".join(exceptions) raise Exception(exception_message) - if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, - reranking_model, weights, False) + if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) return all_documents @classmethod - def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, all_documents: list, exceptions: list): + def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None): + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + return [] + all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + dataset.tenant_id, dataset_id, query, external_retrieval_model + ) + return all_documents + + @classmethod + def keyword_search( + cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id 
- ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - keyword = Keyword( - dataset=dataset - ) + keyword = Keyword(dataset=dataset) - documents = keyword.search( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k) all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrival_method: str, exceptions: list): + def embedding_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) documents = vector.search_by_vector( cls.escape_query_for_search(query), - search_type='similarity_score_threshold', + search_type="similarity_score_threshold", top_k=top_k, score_threshold=score_threshold, - filter={ - 'group_id': [dataset.id] - } + filter={"group_id": [dataset.id]}, ) if documents: - if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: exceptions.append(str(e)) @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, - top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], - all_documents: list, retrival_method: str, exceptions: list): + def full_text_index_search( + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, + ): with flask_app.app_context(): try: - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() vector_processor = Vector( dataset=dataset, ) - documents = vector_processor.search_by_full_text( - cls.escape_query_for_search(query), - top_k=top_k - ) + documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) if documents: - if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), - RerankMode.RERANKING_MODEL.value, - 
reranking_model, None, False) - all_documents.extend(data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents) - )) + if ( + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + ): + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False + ) + all_documents.extend( + data_post_processor.invoke( + query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + ) + ) else: all_documents.extend(documents) except Exception as e: @@ -197,4 +230,4 @@ def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, @staticmethod def escape_query_for_search(query: str) -> str: - return query.replace('"', '\\"') \ No newline at end of file + return query.replace('"', '\\"') diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index b78e2a59b1eb6f..c77cb873760626 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -9,10 +9,10 @@ ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -29,6 +29,7 @@ class AnalyticdbConfig(BaseModel): namespace_password: str = (None,) metrics: str = ("cosine",) read_timeout: int = 60000 + def to_analyticdb_client_params(self): return { "access_key_id": self.access_key_id, @@ -37,32 +38,19 @@ def to_analyticdb_client_params(self): "read_timeout": self.read_timeout, } -class AnalyticdbVector(BaseVector): - _instance = None - _init = False - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance +class AnalyticdbVector(BaseVector): def __init__(self, collection_name: str, config: AnalyticdbConfig): - # collection_name must be updated every time self._collection_name = collection_name.lower() - if AnalyticdbVector._init: - return try: from alibabacloud_gpdb20160503.client import Client from alibabacloud_tea_openapi import models as open_api_models except: raise ImportError(_import_err_msg) self.config = config - self._client_config = open_api_models.Config( - user_agent="dify", **config.to_analyticdb_client_params() - ) + self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params()) self._client = Client(self._client_config) self._initialize() - AnalyticdbVector._init = True def _initialize(self) -> None: cache_key = f"vector_indexing_{self.config.instance_id}" @@ -77,6 +65,7 @@ def _initialize(self) -> None: def _initialize_vector_database(self) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.InitVectorDatabaseRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -88,6 +77,7 @@ def _initialize_vector_database(self) -> None: def _create_namespace_if_not_exists(self) -> None: from 
alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + try: request = gpdb_20160503_models.DescribeNamespaceRequest( dbinstance_id=self.config.instance_id, @@ -109,13 +99,12 @@ def _create_namespace_if_not_exists(self) -> None: ) self._client.create_namespace(request) else: - raise ValueError( - f"failed to create namespace {self.config.namespace}: {e}" - ) + raise ValueError(f"failed to create namespace {self.config.namespace}: {e}") def _create_collection_if_not_exists(self, embedding_dimension: int): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models from Tea.exceptions import TeaException + cache_key = f"vector_indexing_{self._collection_name}" lock_name = f"{cache_key}_lock" with redis_client.lock(lock_name, timeout=20): @@ -149,9 +138,7 @@ def _create_collection_if_not_exists(self, embedding_dimension: int): ) self._client.create_collection(request) else: - raise ValueError( - f"failed to create collection {self._collection_name}: {e}" - ) + raise ValueError(f"failed to create collection {self._collection_name}: {e}") redis_client.set(collection_exist_cache_key, 1, ex=3600) def get_type(self) -> str: @@ -162,10 +149,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self._create_collection_if_not_exists(dimension) self.add_texts(texts, embeddings) - def add_texts( - self, documents: list[Document], embeddings: list[list[float]], **kwargs - ): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = [] for doc, embedding in zip(documents, embeddings, strict=True): metadata = { @@ -191,6 +177,7 @@ def add_texts( def text_exists(self, id: str) -> bool: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -202,13 +189,14 @@ def text_exists(self, id: str) -> bool: vector=None, content=None, top_k=1, - filter=f"ref_doc_id='{id}'" + filter=f"ref_doc_id='{id}'", ) response = self._client.query_collection_data(request) return len(response.body.matches.match) > 0 def delete_by_ids(self, ids: list[str]) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + ids_str = ",".join(f"'{id}'" for id in ids) ids_str = f"({ids_str})" request = gpdb_20160503_models.DeleteCollectionDataRequest( @@ -224,6 +212,7 @@ def delete_by_ids(self, ids: list[str]) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -235,15 +224,10 @@ def delete_by_metadata_field(self, key: str, value: str) -> None: ) self._client.delete_collection_data(request) - def search_by_vector( - self, query_vector: list[float], **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = kwargs.get("score_threshold") or 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( 
dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -261,20 +245,20 @@ def search_by_vector( documents = [] for match in response.body.matches.match: if match.score > score_threshold: + metadata = json.loads(match.metadata.get("metadata_")) + metadata["score"] = match.score doc = Document( page_content=match.metadata.get("page_content"), - metadata=json.loads(match.metadata.get("metadata_")), + metadata=metadata, ) documents.append(doc) + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) return documents def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models - score_threshold = ( - kwargs.get("score_threshold", 0.0) - if kwargs.get("score_threshold", 0.0) - else 0.0 - ) + + score_threshold = float(kwargs.get("score_threshold") or 0.0) request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, region_id=self.config.region_id, @@ -293,17 +277,20 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: for match in response.body.matches.match: if match.score > score_threshold: metadata = json.loads(match.metadata.get("metadata_")) + metadata["score"] = match.score doc = Document( page_content=match.metadata.get("page_content"), vector=match.metadata.get("vector"), metadata=metadata, ) documents.append(doc) + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) return documents def delete(self) -> None: try: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + request = gpdb_20160503_models.DeleteCollectionRequest( collection=self._collection_name, dbinstance_id=self.config.instance_id, @@ -315,19 +302,16 @@ def delete(self) -> None: except Exception as e: raise e + class AnalyticdbVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings): if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict["vector_store"][ - "class_prefix" - ] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name) - ) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)) # handle optional params if dify_config.ANALYTICDB_KEY_ID is None: diff --git a/api/core/rag/datasource/vdb/baidu/__init__.py b/api/core/rag/datasource/vdb/baidu/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py new file mode 100644 index 00000000000000..eb78e8aa698b9b --- /dev/null +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -0,0 +1,287 @@ +import json +import time +import uuid +from typing import Any + +import numpy as np +from pydantic import BaseModel, model_validator +from pymochow import MochowClient +from pymochow.auth.bce_credentials import BceCredentials +from pymochow.configuration import Configuration +from pymochow.exception import ServerError +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState +from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex +from pymochow.model.table import 
AnnSearch, HNSWSearchParams, Partition, Row + +from configs import dify_config +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + + +class BaiduConfig(BaseModel): + endpoint: str + connection_timeout_in_mills: int = 30 * 1000 + account: str + api_key: str + database: str + index_type: str = "HNSW" + metric_type: str = "L2" + shard: int = 1 + replicas: int = 3 + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["endpoint"]: + raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required") + if not values["account"]: + raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required") + if not values["api_key"]: + raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required") + if not values["database"]: + raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required") + return values + + +class BaiduVector(BaseVector): + field_id: str = "id" + field_vector: str = "vector" + field_text: str = "text" + field_metadata: str = "metadata" + field_app_id: str = "app_id" + field_annotation_id: str = "annotation_id" + index_vector: str = "vector_idx" + + def __init__(self, collection_name: str, config: BaiduConfig): + super().__init__(collection_name) + self._client_config = config + self._client = self._init_client(config) + self._db = self._init_database() + + def get_type(self) -> str: + return VectorType.BAIDU + + def to_index_struct(self) -> dict: + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self._create_table(len(embeddings[0])) + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + total_count = len(documents) + batch_size = 1000 + + # upsert texts and embeddings batch by batch + table = self._db.table(self._collection_name) + for start in range(0, total_count, batch_size): + end = min(start + batch_size, total_count) + rows = [] + for i in range(start, end, 1): + row = Row( + id=metadatas[i].get("doc_id", str(uuid.uuid4())), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadatas[i]), + app_id=metadatas[i].get("app_id", ""), + annotation_id=metadatas[i].get("annotation_id", ""), + ) + rows.append(row) + table.upsert(rows=rows) + + # rebuild vector index after upsert finished + table.rebuild_index(self.index_vector) + while True: + time.sleep(1) + index = table.describe_index(self.index_vector) + if index.state == IndexState.NORMAL: + break + + def text_exists(self, id: str) -> bool: + res = self._db.table(self._collection_name).query(primary_key={self.field_id: id}) + if res and res.code == 0: + return True + return False + + def delete_by_ids(self, ids: list[str]) -> None: + quoted_ids = [f"'{id}'" for id in ids] + self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") + + def delete_by_metadata_field(self, key: str, value: str) -> None: + self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'") + + def search_by_vector(self, 
query_vector: list[float], **kwargs: Any) -> list[Document]:
+        query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
+        anns = AnnSearch(
+            vector_field=self.field_vector,
+            vector_floats=query_vector,
+            params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
+        )
+        res = self._db.table(self._collection_name).search(
+            anns=anns,
+            projections=[self.field_id, self.field_text, self.field_metadata],
+            retrieve_vector=True,
+        )
+        score_threshold = float(kwargs.get("score_threshold") or 0.0)
+        return self._get_search_res(res, score_threshold)
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        # baidu vector database doesn't support bm25 search on current version
+        return []
+
+    def _get_search_res(self, res, score_threshold):
+        docs = []
+        for row in res.rows:
+            row_data = row.get("row", {})
+            meta = row_data.get(self.field_metadata)
+            if meta is not None:
+                meta = json.loads(meta)
+            score = row.get("score", 0.0)
+            if score > score_threshold:
+                meta["score"] = score
+                doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
+                docs.append(doc)
+
+        return docs
+
+    def delete(self) -> None:
+        try:
+            self._db.drop_table(table_name=self._collection_name)
+        except ServerError as e:
+            if e.code == ServerErrCode.TABLE_NOT_EXIST:
+                pass
+            else:
+                raise
+
+    def _init_client(self, config) -> MochowClient:
+        config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
+        client = MochowClient(config)
+        return client
+
+    def _init_database(self):
+        exists = False
+        for db in self._client.list_databases():
+            if db.database_name == self._client_config.database:
+                exists = True
+                break
+        # Create database if not existed
+        if exists:
+            return self._client.database(self._client_config.database)
+        else:
+            try:
+                self._client.create_database(database_name=self._client_config.database)
+            except ServerError as e:
+                if e.code == ServerErrCode.DB_ALREADY_EXIST:
+                    pass
+                else:
+                    raise
+            # Return a handle to the database so self._db is usable by callers.
+            return self._client.database(self._client_config.database)
+
+    def _table_existed(self) -> bool:
+        tables = self._db.list_table()
+        return any(table.table_name == self._collection_name for table in tables)
+
+    def _create_table(self, dimension: int) -> None:
+        # Try to grab distributed lock and create table
+        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=60):
+            table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            if redis_client.get(table_exist_cache_key):
+                return
+
+            if self._table_existed():
+                return
+
+            self.delete()
+
+            # check IndexType and MetricType
+            index_type = None
+            for k, v in IndexType.__members__.items():
+                if k == self._client_config.index_type:
+                    index_type = v
+            if index_type is None:
+                raise ValueError("unsupported index_type")
+            metric_type = None
+            for k, v in MetricType.__members__.items():
+                if k == self._client_config.metric_type:
+                    metric_type = v
+            if metric_type is None:
+                raise ValueError("unsupported metric_type")
+
+            # Construct field schema
+            fields = []
+            fields.append(
+                Field(
+                    self.field_id,
+                    FieldType.STRING,
+                    primary_key=True,
+                    partition_key=True,
+                    auto_increment=False,
+                    not_null=True,
+                )
+            )
+            fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
+            fields.append(Field(self.field_app_id, FieldType.STRING))
+            fields.append(Field(self.field_annotation_id, FieldType.STRING))
+            fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
+            fields.append(Field(self.field_vector,
FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension)) + + # Construct vector index params + indexes = [] + indexes.append( + VectorIndex( + index_name="vector_idx", + index_type=index_type, + field="vector", + metric_type=metric_type, + params=HNSWParams(m=16, efconstruction=200), + ) + ) + + # Create table + self._db.create_table( + table_name=self._collection_name, + replication=self._client_config.replicas, + partition=Partition(partition_num=self._client_config.shard), + schema=Schema(fields=fields, indexes=indexes), + description="Table for Dify", + ) + + # Wait for table created + while True: + time.sleep(1) + table = self._db.describe_table(self._collection_name) + if table.state == TableState.NORMAL: + break + redis_client.set(table_exist_cache_key, 1, ex=3600) + + +class BaiduVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name)) + + return BaiduVector( + collection_name=collection_name, + config=BaiduConfig( + endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT, + connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS, + account=dify_config.BAIDU_VECTOR_DB_ACCOUNT, + api_key=dify_config.BAIDU_VECTOR_DB_API_KEY, + database=dify_config.BAIDU_VECTOR_DB_DATABASE, + shard=dify_config.BAIDU_VECTOR_DB_SHARD, + replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, + ), + ) diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 3629887b448aeb..a9e1486edd25f1 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -6,10 +6,10 @@ from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -27,21 +27,20 @@ def to_chroma_params(self): settings = Settings( # auth chroma_client_auth_provider=self.auth_provider, - chroma_client_auth_credentials=self.auth_credentials + chroma_client_auth_credentials=self.auth_credentials, ) return { - 'host': self.host, - 'port': self.port, - 'ssl': False, - 'tenant': self.tenant, - 'database': self.database, - 'settings': settings, + "host": self.host, + "port": self.port, + "ssl": False, + "tenant": self.tenant, + "database": self.database, + "settings": settings, } class ChromaVector(BaseVector): - def __init__(self, collection_name: str, config: ChromaConfig): super().__init__(collection_name) self._client_config = config @@ -58,9 +57,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = 
"vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return self._client.get_or_create_collection(collection_name) @@ -76,7 +75,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) - collection.delete(where={key: {'$eq': value}}) + collection.delete(where={key: {"$eq": value}}) def delete(self): self._client.delete_collection(self._collection_name) @@ -93,26 +92,26 @@ def text_exists(self, id: str) -> bool: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: collection = self._client.get_or_create_collection(self._collection_name) results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) - ids: list[str] = results['ids'][0] - documents: list[str] = results['documents'][0] - metadatas: dict[str, Any] = results['metadatas'][0] - distances: list[float] = results['distances'][0] + ids: list[str] = results["ids"][0] + documents: list[str] = results["documents"][0] + metadatas: dict[str, Any] = results["metadatas"][0] + distances: list[float] = results["distances"][0] docs = [] for index in range(len(ids)): distance = distances[index] metadata = metadatas[index] if distance >= score_threshold: - metadata['score'] = distance + metadata["score"] = distance doc = Document( page_content=documents[index], metadata=metadata, ) docs.append(doc) - # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -123,15 +122,12 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: class ChromaVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - index_struct_dict = { - "type": VectorType.CHROMA, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}} dataset.index_struct = json.dumps(index_struct_dict) return ChromaVector( diff --git a/api/core/rag/datasource/vdb/couchbase/__init__.py b/api/core/rag/datasource/vdb/couchbase/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py new file mode 100644 index 00000000000000..98da5e3d5e91a4 --- /dev/null +++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py 
@@ -0,0 +1,378 @@
+import json
+import logging
+import time
+import uuid
+from datetime import timedelta
+from typing import Any
+
+from couchbase import search
+from couchbase.auth import PasswordAuthenticator
+from couchbase.cluster import Cluster
+from couchbase.management.search import SearchIndex
+
+# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
+from couchbase.options import ClusterOptions, SearchOptions
+from couchbase.vector_search import VectorQuery, VectorSearch
+from flask import current_app
+from pydantic import BaseModel, model_validator
+
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+class CouchbaseConfig(BaseModel):
+    connection_string: str
+    user: str
+    password: str
+    bucket_name: str
+    scope_name: str
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict) -> dict:
+        if not values.get("connection_string"):
+            raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
+        if not values.get("user"):
+            raise ValueError("config COUCHBASE_USER is required")
+        if not values.get("password"):
+            raise ValueError("config COUCHBASE_PASSWORD is required")
+        if not values.get("bucket_name"):
+            raise ValueError("config COUCHBASE_BUCKET_NAME is required")
+        if not values.get("scope_name"):
+            raise ValueError("config COUCHBASE_SCOPE_NAME is required")
+        return values
+
+
+class CouchbaseVector(BaseVector):
+    def __init__(self, collection_name: str, config: CouchbaseConfig):
+        super().__init__(collection_name)
+        self._client_config = config
+
+        # Connect to Couchbase.
+        auth = PasswordAuthenticator(config.user, config.password)
+        options = ClusterOptions(auth)
+        self._cluster = Cluster(config.connection_string, options)
+        self._bucket = self._cluster.bucket(config.bucket_name)
+        self._scope = self._bucket.scope(config.scope_name)
+        self._bucket_name = config.bucket_name
+        self._scope_name = config.scope_name
+
+        # Wait until the cluster is ready for use.
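+        # (wait_until_ready blocks at startup; if the cluster is unreachable
+        # within the timeout below, the SDK raises a timeout error.)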
+ self._cluster.wait_until_ready(timedelta(seconds=5)) + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + index_id = str(uuid.uuid4()).replace("-", "") + self._create_collection(uuid=index_id, vector_length=len(embeddings[0])) + self.add_texts(texts, embeddings) + + def _create_collection(self, vector_length: int, uuid: str): + lock_name = "vector_indexing_lock_{}".format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + if self._collection_exists(self._collection_name): + return + manager = self._bucket.collections() + manager.create_collection(self._client_config.scope_name, self._collection_name) + + index_manager = self._scope.search_indexes() + + index_definition = json.loads(""" +{ + "type": "fulltext-index", + "name": "Embeddings._default.Vector_Search", + "uuid": "26d4db528e78b716", + "sourceType": "gocbcore", + "sourceName": "Embeddings", + "sourceUUID": "2242e4a25b4decd6650c9c7b3afa1dbf", + "planParams": { + "maxPartitionsPerPIndex": 1024, + "indexPartitions": 1 + }, + "params": { + "doc_config": { + "docid_prefix_delim": "", + "docid_regexp": "", + "mode": "scope.collection.type_field", + "type_field": "type" + }, + "mapping": { + "analysis": { }, + "default_analyzer": "standard", + "default_datetime_parser": "dateTimeOptional", + "default_field": "_all", + "default_mapping": { + "dynamic": true, + "enabled": true + }, + "default_type": "_default", + "docvalues_dynamic": false, + "index_dynamic": true, + "store_dynamic": true, + "type_field": "_type", + "types": { + "collection_name": { + "dynamic": true, + "enabled": true, + "properties": { + "embedding": { + "dynamic": false, + "enabled": true, + "fields": [ + { + "dims": 1536, + "index": true, + "name": "embedding", + "similarity": "dot_product", + "type": "vector", + "vector_index_optimized_for": "recall" + } + ] + }, + "metadata": { + "dynamic": true, + "enabled": true + }, + "text": { + "dynamic": false, + "enabled": true, + "fields": [ + { + "index": true, + "name": "text", + "store": true, + "type": "text" + } + ] + } + } + } + } + }, + "store": { + "indexType": "scorch", + "segmentVersion": 16 + } + }, + "sourceParams": { } + } +""") + index_definition["name"] = self._collection_name + "_search" + index_definition["uuid"] = uuid + index_definition["params"]["mapping"]["types"]["collection_name"]["properties"]["embedding"]["fields"][0][ + "dims" + ] = vector_length + index_definition["params"]["mapping"]["types"][self._scope_name + "." 
+ self._collection_name] = ( + index_definition["params"]["mapping"]["types"].pop("collection_name") + ) + time.sleep(2) + index_manager.upsert_index( + SearchIndex( + index_definition["name"], + params=index_definition["params"], + source_name=self._bucket_name, + ), + ) + time.sleep(1) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def _collection_exists(self, name: str): + scope_collection_map: dict[str, Any] = {} + + # Get a list of all scopes in the bucket + for scope in self._bucket.collections().get_all_scopes(): + scope_collection_map[scope.name] = [] + + # Get a list of all the collections in the scope + for collection in scope.collections: + scope_collection_map[scope.name].append(collection.name) + + # Check if the collection exists in the scope + return self._collection_name in scope_collection_map[self._scope_name] + + def get_type(self) -> str: + return VectorType.COUCHBASE + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + doc_ids = [] + + documents_to_insert = [ + {"text": text, "embedding": vector, "metadata": metadata} + for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas) + ] + for doc, id in zip(documents_to_insert, uuids): + result = self._scope.collection(self._collection_name).upsert(id, doc) + + doc_ids.extend(uuids) + + return doc_ids + + def text_exists(self, id: str) -> bool: + # Use a parameterized query for safety and correctness + query = f""" + SELECT COUNT(1) AS count FROM + `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE META().id = $doc_id + """ + # Pass the id as a parameter to the query + result = self._cluster.query(query, named_parameters={"doc_id": id}).execute() + for row in result: + return row["count"] > 0 + return False # Return False if no rows are returned + + def delete_by_ids(self, ids: list[str]) -> None: + query = f""" + DELETE FROM `{self._bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE META().id IN $doc_ids; + """ + try: + self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() + except Exception as e: + logger.exception(e) + + def delete_by_document_id(self, document_id: str): + query = f""" + DELETE FROM + `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE META().id = $doc_id; + """ + self._cluster.query(query, named_parameters={"doc_id": document_id}).execute() + + # def get_ids_by_metadata_field(self, key: str, value: str): + # query = f""" + # SELECT id FROM + # `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + # WHERE `metadata.{key}` = $value; + # """ + # result = self._cluster.query(query, named_parameters={'value':value}) + # return [row['id'] for row in result.rows()] + + def delete_by_metadata_field(self, key: str, value: str) -> None: + query = f""" + DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} + WHERE metadata.{key} = $value; + """ + self._cluster.query(query, named_parameters={"value": value}).execute() + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 5) + score_threshold = kwargs.get("score_threshold") or 0.0 + + search_req = search.SearchRequest.create( + VectorSearch.from_vector_query( + 
VectorQuery(
+                    "embedding",
+                    query_vector,
+                    top_k,
+                )
+            )
+        )
+        try:
+            search_iter = self._scope.search(
+                self._collection_name + "_search",
+                search_req,
+                SearchOptions(limit=top_k, collections=[self._collection_name], fields=["*"]),
+            )
+
+            docs = []
+            # Parse the results
+            for row in search_iter.rows():
+                text = row.fields.pop("text")
+                metadata = self._format_metadata(row.fields)
+                score = row.score
+                metadata["score"] = score
+                doc = Document(page_content=text, metadata=metadata)
+                if score >= score_threshold:
+                    docs.append(doc)
+        except Exception as e:
+            raise ValueError(f"Search failed with error: {e}")
+
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        top_k = kwargs.get("top_k", 2)
+        try:
+            cb_request = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
+            search_iter = self._scope.search(
+                self._collection_name + "_search", cb_request, SearchOptions(limit=top_k, fields=["*"])
+            )
+
+            docs = []
+            for row in search_iter.rows():
+                text = row.fields.pop("text")
+                metadata = self._format_metadata(row.fields)
+                score = row.score
+                metadata["score"] = score
+                doc = Document(page_content=text, metadata=metadata)
+                docs.append(doc)
+
+        except Exception as e:
+            raise ValueError(f"Search failed with error: {e}")
+
+        return docs
+
+    def delete(self):
+        manager = self._bucket.collections()
+        scopes = manager.get_all_scopes()
+
+        for scope in scopes:
+            for collection in scope.collections:
+                if collection.name == self._collection_name:
+                    # drop the collection from the scope that actually contains it
+                    manager.drop_collection(scope.name, self._collection_name)
+
+    def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
+        """Helper method to format the metadata from the Couchbase Search API.
+        Args:
+            row_fields (Dict[str, Any]): The fields to format.
+
+        Returns:
+            Dict[str, Any]: The formatted metadata.
+ """ + metadata = {} + for key, value in row_fields.items(): + # Couchbase Search returns the metadata key with a prefix + # `metadata.` We remove it to get the original metadata key + if key.startswith("metadata"): + new_key = key.split("metadata" + ".")[-1] + metadata[new_key] = value + else: + metadata[key] = value + + return metadata + + +class CouchbaseVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name)) + + config = current_app.config + return CouchbaseVector( + collection_name=collection_name, + config=CouchbaseConfig( + connection_string=config.get("COUCHBASE_CONNECTION_STRING"), + user=config.get("COUCHBASE_USER"), + password=config.get("COUCHBASE_PASSWORD"), + bucket_name=config.get("COUCHBASE_BUCKET_NAME"), + scope_name=config.get("COUCHBASE_SCOPE_NAME"), + ), + ) diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 01ba6fb3248786..c62042af8071d1 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,34 +1,42 @@ import json -from typing import Any +import logging +import math +from typing import Any, Optional +from urllib.parse import urlparse import requests from elasticsearch import Elasticsearch from flask import current_app from pydantic import BaseModel, model_validator -from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from extensions.ext_redis import redis_client from models.dataset import Dataset +logger = logging.getLogger(__name__) + class ElasticSearchConfig(BaseModel): host: str - port: str + port: int username: str password: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PORT is required") - if not values['username']: + if not values["username"]: raise ValueError("config USERNAME is required") - if not values['password']: + if not values["password"]: raise ValueError("config PASSWORD is required") return values @@ -37,12 +45,19 @@ class ElasticSearchVector(BaseVector): def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list): super().__init__(index_name.lower()) self._client = self._init_client(config) + self._version = self._get_version() + self._check_version() self._attributes = attributes def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: + parsed_url = urlparse(config.host) + if parsed_url.scheme in {"http", "https"}: + hosts = f"{config.host}:{config.port}" + else: + hosts = 
f"http://{config.host}:{config.port}" client = Elasticsearch( - hosts=f'{config.host}:{config.port}', + hosts=hosts, basic_auth=(config.username, config.password), request_timeout=100000, retry_on_timeout=True, @@ -53,62 +68,43 @@ def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: return client + def _get_version(self) -> str: + info = self._client.info() + return info["version"]["number"] + + def _check_version(self): + if self._version < "8.0.0": + raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") + def get_type(self) -> str: - return 'elasticsearch' + return VectorType.ELASTICSEARCH def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents) - texts = [d.page_content for d in documents] - metadatas = [d.metadata for d in documents] - - if not self._client.indices.exists(index=self._collection_name): - dim = len(embeddings[0]) - mapping = { - "properties": { - "text": { - "type": "text" - }, - "vector": { - "type": "dense_vector", - "index": True, - "dims": dim, - "similarity": "l2_norm" - }, - } - } - self._client.indices.create(index=self._collection_name, mappings=mapping) - - added_ids = [] - for i, text in enumerate(texts): - self._client.index(index=self._collection_name, - id=uuids[i], - document={ - "text": text, - "vector": embeddings[i] if embeddings[i] else None, - "metadata": metadatas[i] if metadatas[i] else {}, - }) - added_ids.append(uuids[i]) - + for i in range(len(documents)): + self._client.index( + index=self._collection_name, + id=uuids[i], + document={ + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i] or None, + Field.METADATA_KEY.value: documents[i].metadata or {}, + }, + ) self._client.indices.refresh(index=self._collection_name) return uuids def text_exists(self, id: str) -> bool: - return self._client.exists(index=self._collection_name, id=id).__bool__() + return bool(self._client.exists(index=self._collection_name, id=id)) def delete_by_ids(self, ids: list[str]) -> None: for id in ids: self._client.delete(index=self._collection_name, id=id) def delete_by_metadata_field(self, key: str, value: str) -> None: - query_str = { - 'query': { - 'match': { - f'metadata.{key}': f'{value}' - } - } - } + query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} results = self._client.search(index=self._collection_name, body=query_str) - ids = [hit['_id'] for hit in results['hits']['hits']] + ids = [hit["_id"] for hit in results["hits"]["hits"]] if ids: self.delete_by_ids(ids) @@ -116,76 +112,105 @@ def delete(self) -> None: self._client.indices.delete(index=self._collection_name) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - query_str = { - "query": { - "script_score": { - "query": { - "match_all": {} - }, - "script": { - "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0", - "params": { - "query_vector": query_vector - } - } - } - } - } + top_k = kwargs.get("top_k", 4) + num_candidates = math.ceil(top_k * 1.5) + knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} - results = self._client.search(index=self._collection_name, body=query_str) + results = self._client.search(index=self._collection_name, knn=knn, size=top_k) docs_and_scores = [] - for hit in results['hits']['hits']: + for hit in results["hits"]["hits"]: docs_and_scores.append( - (Document(page_content=hit['_source']['text'], 
metadata=hit['_source']['metadata']), hit['_score'])) + ( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ), + hit["_score"], + ) + ) docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) - # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) - return docs + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = { - "match": { - "text": query - } - } - results = self._client.search(index=self._collection_name, query=query_str) + query_str = {"match": {Field.CONTENT_KEY.value: query}} + results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) docs = [] - for hit in results['hits']['hits']: - docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata'])) + for hit in results["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ) + ) return docs def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - return self.add_texts(texts, embeddings, **kwargs) + metadatas = [d.metadata for d in texts] + self.create_collection(embeddings, metadatas) + self.add_texts(texts, embeddings, **kwargs) + + def create_collection( + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + ): + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + logger.info(f"Collection {self._collection_name} already exists.") + return + + if not self._client.indices.exists(index=self._collection_name): + dim = len(embeddings[0]) + mappings = { + "properties": { + Field.CONTENT_KEY.value: {"type": "text"}, + Field.VECTOR.value: { # Make sure the dimension is correct here + "type": "dense_vector", + "dims": dim, + "similarity": "cosine", + }, + Field.METADATA_KEY.value: { + "type": "object", + "properties": { + "doc_id": {"type": "keyword"} # Map doc_id to keyword type + }, + }, + } + } + self._client.indices.create(index=self._collection_name, mappings=mappings) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) class ElasticSearchVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> ElasticSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ELASTICSEARCH, collection_name)) config = current_app.config return ElasticSearchVector( 
index_name=collection_name, config=ElasticSearchConfig( - host=config.get('ELASTICSEARCH_HOST'), - port=config.get('ELASTICSEARCH_PORT'), - username=config.get('ELASTICSEARCH_USERNAME'), - password=config.get('ELASTICSEARCH_PASSWORD'), + host=config.get("ELASTICSEARCH_HOST"), + port=config.get("ELASTICSEARCH_PORT"), + username=config.get("ELASTICSEARCH_USERNAME"), + password=config.get("ELASTICSEARCH_PASSWORD"), ), - attributes=[] + attributes=[], ) diff --git a/api/core/rag/datasource/vdb/lindorm/__init__.py b/api/core/rag/datasource/vdb/lindorm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py new file mode 100644 index 00000000000000..30d7f09ec20ccf --- /dev/null +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -0,0 +1,498 @@ +import copy +import json +import logging +from collections.abc import Iterable +from typing import Any, Optional + +from opensearchpy import OpenSearch +from opensearchpy.helpers import bulk +from pydantic import BaseModel, model_validator +from tenacity import retry, stop_after_attempt, wait_fixed + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logging.getLogger("lindorm").setLevel(logging.WARN) + + +class LindormVectorStoreConfig(BaseModel): + hosts: str + username: Optional[str] = None + password: Optional[str] = None + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["hosts"]: + raise ValueError("config URL is required") + if not values["username"]: + raise ValueError("config USERNAME is required") + if not values["password"]: + raise ValueError("config PASSWORD is required") + return values + + def to_opensearch_params(self) -> dict[str, Any]: + params = { + "hosts": self.hosts, + } + if self.username and self.password: + params["http_auth"] = (self.username, self.password) + return params + + +class LindormVectorStore(BaseVector): + def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs): + super().__init__(collection_name.lower()) + self._client_config = config + self._client = OpenSearch(**config.to_opensearch_params()) + self.kwargs = kwargs + + def get_type(self) -> str: + return VectorType.LINDORM + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self.create_collection(len(embeddings[0]), **kwargs) + self.add_texts(texts, embeddings) + + def refresh(self): + self._client.indices.refresh(index=self._collection_name) + + def __filter_existed_ids( + self, + texts: list[str], + metadatas: list[dict], + ids: list[str], + bulk_size: int = 1024, + ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]: + @retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) + def __fetch_existing_ids(batch_ids: list[str]) -> set[str]: + try: + existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, 
_source=False)
+                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
+            except Exception as e:
+                logger.exception(f"Error fetching batch {batch_ids}: {e}")
+                return set()
+
+        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
+        def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
+            try:
+                existing_docs = self._client.mget(
+                    body={
+                        "docs": [
+                            {"_index": self._collection_name, "_id": id, "routing": routing}
+                            for id, routing in zip(batch_ids, route_ids)
+                        ]
+                    },
+                    _source=False,
+                )
+                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
+            except Exception as e:
+                logger.exception(f"Error fetching batch {batch_ids}: {e}")
+                return set()
+
+        if ids is None:
+            return texts, metadatas, ids
+
+        if len(texts) != len(ids):
+            raise RuntimeError(f"texts {len(texts)} != ids {len(ids)}")
+
+        filtered_texts = []
+        filtered_metadatas = []
+        filtered_ids = []
+
+        def batch(iterable, n):
+            length = len(iterable)
+            for idx in range(0, length, n):
+                yield iterable[idx : min(idx + n, length)]
+
+        for ids_batch, texts_batch, metadatas_batch in zip(
+            batch(ids, bulk_size),
+            batch(texts, bulk_size),
+            batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
+        ):
+            existing_ids_set = __fetch_existing_ids(ids_batch)
+            for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
+                if doc_id not in existing_ids_set:
+                    filtered_texts.append(text)
+                    filtered_ids.append(doc_id)
+                    if metadatas is not None:
+                        filtered_metadatas.append(metadata)
+
+        return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
+
+    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+        actions = []
+        uuids = self._get_uuids(documents)
+        for i in range(len(documents)):
+            action = {
+                "_op_type": "index",
+                "_index": self._collection_name.lower(),
+                "_id": uuids[i],
+                "_source": {
+                    Field.CONTENT_KEY.value: documents[i].page_content,
+                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
+                    Field.METADATA_KEY.value: documents[i].metadata,
+                },
+            }
+            actions.append(action)
+        bulk(self._client, actions)
+        self.refresh()
+
+    def get_ids_by_metadata_field(self, key: str, value: str):
+        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
+        response = self._client.search(index=self._collection_name, body=query)
+        if response["hits"]["hits"]:
+            return [hit["_id"] for hit in response["hits"]["hits"]]
+        else:
+            return None
+
+    def delete_by_metadata_field(self, key: str, value: str):
+        query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
+        results = self._client.search(index=self._collection_name, body=query_str)
+        ids = [hit["_id"] for hit in results["hits"]["hits"]]
+        if ids:
+            self.delete_by_ids(ids)
+
+    def delete_by_ids(self, ids: list[str]) -> None:
+        for id in ids:
+            if self._client.exists(index=self._collection_name, id=id):
+                self._client.delete(index=self._collection_name, id=id)
+            else:
+                logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
+
+    def delete(self) -> None:
+        try:
+            if self._client.indices.exists(index=self._collection_name):
+                self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
+                logger.info("Delete index success")
+            else:
+                logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
+        except Exception as e:
+            logger.exception(f"Error occurred while deleting the index: {e}")
+            raise
+
+    def text_exists(self, id: str) -> bool:
+        try:
+            self._client.get(index=self._collection_name, id=id)
+            return True
+        except Exception:
+            return False
+
+    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+        # Make sure query_vector is a list
+        if not isinstance(query_vector, list):
+            raise ValueError("query_vector should be a list of floats")
+
+        # Check whether query_vector is a floating-point number list
+        if not all(isinstance(x, float) for x in query_vector):
+            raise ValueError("All elements in query_vector should be floats")
+
+        top_k = kwargs.get("top_k", 10)
+        query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
+        try:
+            response = self._client.search(index=self._collection_name, body=query)
+        except Exception as e:
+            logger.exception(f"Error executing search: {e}")
+            raise
+
+        docs_and_scores = []
+        for hit in response["hits"]["hits"]:
+            docs_and_scores.append(
+                (
+                    Document(
+                        page_content=hit["_source"][Field.CONTENT_KEY.value],
+                        vector=hit["_source"][Field.VECTOR.value],
+                        metadata=hit["_source"][Field.METADATA_KEY.value],
+                    ),
+                    hit["_score"],
+                )
+            )
+        docs = []
+        for doc, score in docs_and_scores:
+            score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
+            if score > score_threshold:
+                doc.metadata["score"] = score
+                docs.append(doc)
+
+        return docs
+
+    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+        must = kwargs.get("must")
+        must_not = kwargs.get("must_not")
+        should = kwargs.get("should")
+        minimum_should_match = kwargs.get("minimum_should_match", 0)
+        top_k = kwargs.get("top_k", 10)
+        filters = kwargs.get("filter")
+        routing = kwargs.get("routing")
+        full_text_query = default_text_search_query(
+            query_text=query,
+            k=top_k,
+            text_field=Field.CONTENT_KEY.value,
+            must=must,
+            must_not=must_not,
+            should=should,
+            minimum_should_match=minimum_should_match,
+            filters=filters,
+            routing=routing,
+        )
+        response = self._client.search(index=self._collection_name, body=full_text_query)
+        docs = []
+        for hit in response["hits"]["hits"]:
+            docs.append(
+                Document(
+                    page_content=hit["_source"][Field.CONTENT_KEY.value],
+                    vector=hit["_source"][Field.VECTOR.value],
+                    metadata=hit["_source"][Field.METADATA_KEY.value],
+                )
+            )
+
+        return docs
+
+    def create_collection(self, dimension: int, **kwargs):
+        lock_name = f"vector_indexing_lock_{self._collection_name}"
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+            if redis_client.get(collection_exist_cache_key):
+                logger.info(f"Collection {self._collection_name} already exists.")
+                return
+            if self._client.indices.exists(index=self._collection_name):
+                logger.info(f"{self._collection_name.lower()} already exists.")
+                return
+            if len(self.kwargs) == 0 and len(kwargs) != 0:
+                self.kwargs = copy.deepcopy(kwargs)
+            vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
+            shards = kwargs.pop("shards", 2)
+
+            engine = kwargs.pop("engine", "lvector")
+            method_name = kwargs.pop("method_name", "hnsw")
+            data_type = kwargs.pop("data_type", "float")
+            space_type = kwargs.pop("space_type", "cosinesimil")
+
+            hnsw_m = kwargs.pop("hnsw_m", 24)
+            hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
+            ivfpq_m = kwargs.pop("ivfpq_m", dimension)
+            nlist = kwargs.pop("nlist", 1000)
+            centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if
nlist >= 5000 else False) + centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24) + centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500) + centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100) + mapping = default_text_mapping( + dimension, + method_name, + shards=shards, + engine=engine, + data_type=data_type, + space_type=space_type, + vector_field=vector_field, + hnsw_m=hnsw_m, + hnsw_ef_construction=hnsw_ef_construction, + nlist=nlist, + ivfpq_m=ivfpq_m, + centroids_use_hnsw=centroids_use_hnsw, + centroids_hnsw_m=centroids_hnsw_m, + centroids_hnsw_ef_construct=centroids_hnsw_ef_construct, + centroids_hnsw_ef_search=centroids_hnsw_ef_search, + **kwargs, + ) + self._client.indices.create(index=self._collection_name.lower(), body=mapping) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + # logger.info(f"create index success: {self._collection_name}") + + +def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict: + routing_field = kwargs.get("routing_field") + excludes_from_source = kwargs.get("excludes_from_source") + analyzer = kwargs.get("analyzer", "ik_max_word") + text_field = kwargs.get("text_field", Field.CONTENT_KEY.value) + engine = kwargs["engine"] + shard = kwargs["shards"] + space_type = kwargs["space_type"] + data_type = kwargs["data_type"] + vector_field = kwargs.get("vector_field", Field.VECTOR.value) + + if method_name == "ivfpq": + ivfpq_m = kwargs["ivfpq_m"] + nlist = kwargs["nlist"] + centroids_use_hnsw = True if nlist > 10000 else False + centroids_hnsw_m = 24 + centroids_hnsw_ef_construct = 500 + centroids_hnsw_ef_search = 100 + parameters = { + "m": ivfpq_m, + "nlist": nlist, + "centroids_use_hnsw": centroids_use_hnsw, + "centroids_hnsw_m": centroids_hnsw_m, + "centroids_hnsw_ef_construct": centroids_hnsw_ef_construct, + "centroids_hnsw_ef_search": centroids_hnsw_ef_search, + } + elif method_name == "hnsw": + neighbor = kwargs["hnsw_m"] + ef_construction = kwargs["hnsw_ef_construction"] + parameters = {"m": neighbor, "ef_construction": ef_construction} + elif method_name == "flat": + parameters = {} + else: + raise RuntimeError(f"unexpected method_name: {method_name}") + + mapping = { + "settings": {"index": {"number_of_shards": shard, "knn": True}}, + "mappings": { + "properties": { + vector_field: { + "type": "knn_vector", + "dimension": dimension, + "data_type": data_type, + "method": { + "engine": engine, + "name": method_name, + "space_type": space_type, + "parameters": parameters, + }, + }, + text_field: {"type": "text", "analyzer": analyzer}, + } + }, + } + + if excludes_from_source: + mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. 
{"excludes": ["vector_field"]} + + if method_name == "ivfpq" and routing_field is not None: + mapping["settings"]["index"]["knn_routing"] = True + mapping["settings"]["index"]["knn.offline.construction"] = True + + if method_name == "flat" and routing_field is not None: + mapping["settings"]["index"]["knn_routing"] = True + + return mapping + + +def default_text_search_query( + query_text: str, + k: int = 4, + text_field: str = Field.CONTENT_KEY.value, + must: Optional[list[dict]] = None, + must_not: Optional[list[dict]] = None, + should: Optional[list[dict]] = None, + minimum_should_match: int = 0, + filters: Optional[list[dict]] = None, + routing: Optional[str] = None, + **kwargs, +) -> dict: + if routing is not None: + routing_field = kwargs.get("routing_field", "routing_field") + query_clause = { + "bool": { + "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}] + } + } + else: + query_clause = {"match": {text_field: query_text}} + # build the simplest search_query when only query_text is specified + if not must and not must_not and not should and not filters: + search_query = {"size": k, "query": query_clause} + return search_query + + # build complex search_query when either of must/must_not/should/filter is specified + if must: + if not isinstance(must, list): + raise RuntimeError(f"unexpected [must] clause with {type(filters)}") + if query_clause not in must: + must.append(query_clause) + else: + must = [query_clause] + + boolean_query = {"must": must} + + if must_not: + if not isinstance(must_not, list): + raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}") + boolean_query["must_not"] = must_not + + if should: + if not isinstance(should, list): + raise RuntimeError(f"unexpected [should] clause with {type(filters)}") + boolean_query["should"] = should + if minimum_should_match != 0: + boolean_query["minimum_should_match"] = minimum_should_match + + if filters: + if not isinstance(filters, list): + raise RuntimeError(f"unexpected [filter] clause with {type(filters)}") + boolean_query["filter"] = filters + + search_query = {"size": k, "query": {"bool": boolean_query}} + return search_query + + +def default_vector_search_query( + query_vector: list[float], + k: int = 4, + min_score: str = "0.0", + ef_search: Optional[str] = None, # only for hnsw + nprobe: Optional[str] = None, # "2000" + reorder_factor: Optional[str] = None, # "20" + client_refactor: Optional[str] = None, # "true" + vector_field: str = Field.VECTOR.value, + filters: Optional[list[dict]] = None, + filter_type: Optional[str] = None, + **kwargs, +) -> dict: + if filters is not None: + filter_type = "post_filter" if filter_type is None else filter_type + if not isinstance(filter, list): + raise RuntimeError(f"unexpected filter with {type(filters)}") + final_ext = {"lvector": {}} + if min_score != "0.0": + final_ext["lvector"]["min_score"] = min_score + if ef_search: + final_ext["lvector"]["ef_search"] = ef_search + if nprobe: + final_ext["lvector"]["nprobe"] = nprobe + if reorder_factor: + final_ext["lvector"]["reorder_factor"] = reorder_factor + if client_refactor: + final_ext["lvector"]["client_refactor"] = client_refactor + + search_query = { + "size": k, + "_source": True, # force return '_source' + "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}}, + } + + if filters is not None: + # when using filter, transform filter from List[Dict] to Dict as valid format + filters = {"bool": {"must": filters}} if len(filters) > 1 else 
filters[0] + search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict + if filter_type: + final_ext["lvector"]["filter_type"] = filter_type + + if final_ext != {"lvector": {}}: + search_query["ext"] = final_ext + return search_query + + +class LindormVectorStoreFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name)) + lindorm_config = LindormVectorStoreConfig( + hosts=dify_config.LINDORM_URL, + username=dify_config.LINDORM_USERNAME, + password=dify_config.LINDORM_PASSWORD, + ) + return LindormVectorStore(collection_name, lindorm_config) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index cfc533ed33a6d0..5a263d6e78c3bd 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -1,18 +1,17 @@ import json import logging from typing import Any, Optional -from uuid import uuid4 from pydantic import BaseModel, model_validator -from pymilvus import MilvusClient, MilvusException, connections +from pymilvus import MilvusClient, MilvusException from pymilvus.milvus_client import IndexParams from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -21,55 +20,47 @@ class MilvusConfig(BaseModel): - host: str - port: int + uri: str + token: Optional[str] = None user: str password: str - secure: bool = False batch_size: int = 100 database: str = "default" - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values.get('host'): - raise ValueError("config MILVUS_HOST is required") - if not values.get('port'): - raise ValueError("config MILVUS_PORT is required") - if not values.get('user'): + if not values.get("uri"): + raise ValueError("config MILVUS_URI is required") + if not values.get("user"): raise ValueError("config MILVUS_USER is required") - if not values.get('password'): + if not values.get("password"): raise ValueError("config MILVUS_PASSWORD is required") return values def to_milvus_params(self): return { - 'host': self.host, - 'port': self.port, - 'user': self.user, - 'password': self.password, - 'secure': self.secure, - 'db_name': self.database, + "uri": self.uri, + "token": self.token, + "user": self.user, + "password": self.password, + "db_name": self.database, } class MilvusVector(BaseVector): - def __init__(self, collection_name: str, config: MilvusConfig): super().__init__(collection_name) self._client_config = config self._client = self._init_client(config) - self._consistency_level = 'Session' + self._consistency_level = "Session" self._fields = [] def 
get_type(self) -> str: return VectorType.MILVUS def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): - index_params = { - 'metric_type': 'IP', - 'index_type': "HNSW", - 'params': {"M": 8, "efConstruction": 64} - } + index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}} metadatas = [d.metadata for d in texts] self.create_collection(embeddings, metadatas, index_params) self.add_texts(texts, embeddings) @@ -80,7 +71,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** insert_dict = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], - Field.METADATA_KEY.value: documents[i].metadata + Field.METADATA_KEY.value: documents[i].metadata, } insert_dict_list.append(insert_dict) # Total insert count @@ -89,111 +80,70 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** pks: list[str] = [] for i in range(0, total_count, 1000): - batch_insert_list = insert_dict_list[i:i + 1000] + batch_insert_list = insert_dict_list[i : i + 1000] # Insert into the collection. try: ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list) pks.extend(ids) except MilvusException as e: - logger.error( - "Failed to insert batch starting at entity: %s/%s", i, total_count - ) + logger.exception("Failed to insert batch starting at entity: %s/%s", i, total_count) raise e return pks def get_ids_by_metadata_field(self, key: str, value: str): - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["{key}"] == "{value}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"] + ) if result: return [item["id"] for item in result] else: return None def delete_by_metadata_field(self, key: str, value: str): - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, - db_name=self._client_config.database) - - from pymilvus import utility - if utility.has_collection(self._collection_name, using=alias): - + if self._client.has_collection(self._collection_name): ids = self.get_ids_by_metadata_field(key, value) if ids: self._client.delete(collection_name=self._collection_name, pks=ids) def delete_by_ids(self, ids: list[str]) -> None: - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, - db_name=self._client_config.database) - - from pymilvus import utility - if utility.has_collection(self._collection_name, using=alias): - - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] in {ids}', - output_fields=["id"]) + if self._client.has_collection(self._collection_name): + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"] + ) if result: ids = [item["id"] for item in result] 
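+                # Milvus delete() operates on the auto-generated INT64 primary keys,
+                # hence the query above translating metadata doc_ids into pks first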
self._client.delete(collection_name=self._collection_name, pks=ids) def delete(self) -> None: - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, - db_name=self._client_config.database) - - from pymilvus import utility - if utility.has_collection(self._collection_name, using=alias): - utility.drop_collection(self._collection_name, None, using=alias) + if self._client.has_collection(self._collection_name): + self._client.drop_collection(self._collection_name, None) def text_exists(self, id: str) -> bool: - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, - db_name=self._client_config.database) - - from pymilvus import utility - if not utility.has_collection(self._collection_name, using=alias): + if not self._client.has_collection(self._collection_name): return False - result = self._client.query(collection_name=self._collection_name, - filter=f'metadata["doc_id"] == "{id}"', - output_fields=["id"]) + result = self._client.query( + collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"] + ) return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - # Set search parameters. - results = self._client.search(collection_name=self._collection_name, - data=[query_vector], - limit=kwargs.get('top_k', 4), - output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], - ) + results = self._client.search( + collection_name=self._collection_name, + data=[query_vector], + limit=kwargs.get("top_k", 4), + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + ) # Organize results. 
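+        # results[0] holds the hits for the single query vector passed via data=[...];
+        # with the IP metric chosen in create() above, a larger "distance" means a
+        # closer match, which is why it doubles as the score below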
docs = [] for result in results[0]: - metadata = result['entity'].get(Field.METADATA_KEY.value) - metadata['score'] = result['distance'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if result['distance'] > score_threshold: - doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value), - metadata=metadata) + metadata = result["entity"].get(Field.METADATA_KEY.value) + metadata["score"] = result["distance"] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + if result["distance"] > score_threshold: + doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -202,23 +152,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return [] def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return # Grab the existing collection if it exists - from pymilvus import utility - alias = uuid4().hex - if self._client_config.secure: - uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) - else: - uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) - connections.connect(alias=alias, uri=uri, user=self._client_config.user, - password=self._client_config.password, db_name=self._client_config.database) - if not utility.has_collection(self._collection_name, using=alias): + if not self._client.has_collection(self._collection_name): from pymilvus import CollectionSchema, DataType, FieldSchema from pymilvus.orm.types import infer_dtype_bydata @@ -229,19 +171,11 @@ def create_collection( fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) # Create the text field - fields.append( - FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) - ) + fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)) # Create the primary key field - fields.append( - FieldSchema( - Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True - ) - ) + fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) - ) + fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)) # Create the schema for the collection schema = CollectionSchema(fields) @@ -257,39 +191,36 @@ def create_collection( # Create the collection collection_name = self._collection_name - self._client.create_collection(collection_name=collection_name, - schema=schema, index_params=index_params_obj, - consistency_level=self._consistency_level) + self._client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_params_obj, + consistency_level=self._consistency_level, + ) redis_client.set(collection_exist_cache_key, 
1, ex=3600) def _init_client(self, config) -> MilvusClient: - if config.secure: - uri = "https://" + str(config.host) + ":" + str(config.port) - else: - uri = "http://" + str(config.host) + ":" + str(config.port) - client = MilvusClient(uri=uri, user=config.user, password=config.password, db_name=config.database) + client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database) return client class MilvusVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MILVUS, collection_name)) return MilvusVector( collection_name=collection_name, config=MilvusConfig( - host=dify_config.MILVUS_HOST, - port=dify_config.MILVUS_PORT, + uri=dify_config.MILVUS_URI, + token=dify_config.MILVUS_TOKEN, user=dify_config.MILVUS_USER, password=dify_config.MILVUS_PASSWORD, - secure=dify_config.MILVUS_SECURE, database=dify_config.MILVUS_DATABASE, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 4ae1a3395b0749..2610b60a7799da 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -8,10 +8,10 @@ from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from models.dataset import Dataset @@ -31,12 +31,11 @@ class SortOrder(Enum): class MyScaleVector(BaseVector): - def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"): super().__init__(collection_name) self._config = config self._metric = metric - self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC + self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC self._client = get_client( host=config.host, port=config.port, @@ -80,7 +79,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** doc_id, self.escape_str(doc.page_content), embeddings[i], - json.dumps(doc.metadata) if doc.metadata else {} + json.dumps(doc.metadata) if doc.metadata else {}, ) values.append(str(row)) ids.append(doc_id) @@ -93,7 +92,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** @staticmethod def escape_str(value: Any) -> str: - return "".join(" " if c in ("\\", "'") else c for c in str(value)) + return "".join(" " if c in {"\\", "'"} else c for c in str(value)) def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") @@ -101,7 +100,8 @@ def text_exists(self, id: str) -> bool: def delete_by_ids(self, ids: list[str]) 
-> None: self._client.command( - f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}") + f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}" + ) def get_ids_by_metadata_field(self, key: str, value: str): rows = self._client.query( @@ -121,10 +121,13 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs) def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: - top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get("score_threshold", 0.0) - where_str = f"WHERE dist < {1 - score_threshold}" if \ - self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else "" + top_k = kwargs.get("top_k", 4) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + where_str = ( + f"WHERE dist < {1 - score_threshold}" + if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 + else "" + ) sql = f""" SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name} {where_str} ORDER BY dist {order.value} LIMIT {top_k} @@ -133,13 +136,13 @@ def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: return [ Document( page_content=r["text"], - vector=r['vector'], + vector=r["vector"], metadata=r["metadata"], ) for r in self._client.query(sql).named_results() ] except Exception as e: - logging.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") + logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") return [] def delete(self) -> None: @@ -149,13 +152,12 @@ def delete(self) -> None: class MyScaleVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name)) return MyScaleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/oceanbase/__init__.py b/api/core/rag/datasource/vdb/oceanbase/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py new file mode 100644 index 00000000000000..8dd26a073b27c2 --- /dev/null +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -0,0 +1,209 @@ +import json +import logging +import math +from typing import Any + +from pydantic import BaseModel, model_validator +from pyobvector import VECTOR, ObVecClient +from sqlalchemy import JSON, Column, String, func +from sqlalchemy.dialects.mysql import LONGTEXT + +from configs import dify_config +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from 
core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
+DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
+OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
+DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"
+
+
+class OceanBaseVectorConfig(BaseModel):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict) -> dict:
+        if not values["host"]:
+            raise ValueError("config OCEANBASE_VECTOR_HOST is required")
+        if not values["port"]:
+            raise ValueError("config OCEANBASE_VECTOR_PORT is required")
+        if not values["user"]:
+            raise ValueError("config OCEANBASE_VECTOR_USER is required")
+        if not values["database"]:
+            raise ValueError("config OCEANBASE_VECTOR_DATABASE is required")
+        return values
+
+
+class OceanBaseVector(BaseVector):
+    def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
+        super().__init__(collection_name)
+        self._config = config
+        self._hnsw_ef_search = -1
+        self._client = ObVecClient(
+            uri=f"{self._config.host}:{self._config.port}",
+            user=self._config.user,
+            password=self._config.password,
+            db_name=self._config.database,
+        )
+
+    def get_type(self) -> str:
+        return VectorType.OCEANBASE
+
+    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+        self._vec_dim = len(embeddings[0])
+        self._create_collection()
+        self.add_texts(texts, embeddings)
+
+    def _create_collection(self) -> None:
+        lock_name = "vector_indexing_lock_" + self._collection_name
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = "vector_indexing_" + self._collection_name
+            if redis_client.get(collection_exist_cache_key):
+                return
+
+            if self._client.check_table_exists(self._collection_name):
+                return
+
+            self.delete()
+
+            cols = [
+                Column("id", String(36), primary_key=True, autoincrement=False),
+                Column("vector", VECTOR(self._vec_dim)),
+                Column("text", LONGTEXT),
+                Column("metadata", JSON),
+            ]
+            vidx_params = self._client.prepare_index_params()
+            vidx_params.add_index(
+                field_name="vector",
+                index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
+                index_name="vector_index",
+                metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
+                params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
+            )
+
+            self._client.create_table_with_index_params(
+                table_name=self._collection_name,
+                columns=cols,
+                vidxs=vidx_params,
+            )
+            vals = []
+            params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
+            for row in params:
+                val = int(row[6])
+                vals.append(val)
+            if len(vals) == 0:
+                # raise instead of print()/exit(1): calling exit() here would kill
+                # the whole API worker over a missing database parameter
+                raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
+            if any(val == 0 for val in vals):
+                try:
+                    self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
+                except Exception as e:
+                    raise Exception(
+                        "Failed to set ob_vector_memory_limit_percentage.
" + + "Maybe the database user has insufficient privilege.", + e, + ) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + ids = self._get_uuids(documents) + for id, doc, emb in zip(ids, documents, embeddings): + self._client.insert( + table_name=self._collection_name, + data={ + "id": id, + "vector": emb, + "text": doc.page_content, + "metadata": doc.metadata, + }, + ) + + def text_exists(self, id: str) -> bool: + cur = self._client.get(table_name=self._collection_name, id=id) + return cur.rowcount != 0 + + def delete_by_ids(self, ids: list[str]) -> None: + self._client.delete(table_name=self._collection_name, ids=ids) + + def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: + cur = self._client.get( + table_name=self._collection_name, + where_clause=f"metadata->>'$.{key}' = '{value}'", + output_column_name=["id"], + ) + return [row[0] for row in cur] + + def delete_by_metadata_field(self, key: str, value: str) -> None: + ids = self.get_ids_by_metadata_field(key, value) + self.delete_by_ids(ids) + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + return [] + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + ef_search = kwargs.get("ef_search", self._hnsw_ef_search) + if ef_search != self._hnsw_ef_search: + self._client.set_ob_hnsw_ef_search(ef_search) + self._hnsw_ef_search = ef_search + topk = kwargs.get("top_k", 10) + cur = self._client.ann_search( + table_name=self._collection_name, + vec_column_name="vector", + vec_data=query_vector, + topk=topk, + distance_func=func.l2_distance, + output_column_names=["text", "metadata"], + with_dist=True, + ) + docs = [] + for text, metadata, distance in cur: + metadata = json.loads(metadata) + metadata["score"] = 1 - distance / math.sqrt(2) + docs.append( + Document( + page_content=text, + metadata=metadata, + ) + ) + return docs + + def delete(self) -> None: + self._client.drop_table_if_exist(self._collection_name) + + +class OceanBaseVectorFactory(AbstractVectorFactory): + def init_vector( + self, + dataset: Dataset, + attributes: list, + embeddings: Embeddings, + ) -> BaseVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OCEANBASE, collection_name)) + return OceanBaseVector( + collection_name, + OceanBaseVectorConfig( + host=dify_config.OCEANBASE_VECTOR_HOST, + port=dify_config.OCEANBASE_VECTOR_PORT, + user=dify_config.OCEANBASE_VECTOR_USER, + password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""), + database=dify_config.OCEANBASE_VECTOR_DATABASE, + ), + ) diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index c95d202173b84d..49eb00f14009ef 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -9,11 +9,11 @@ from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory 
import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -28,11 +28,12 @@ class OpenSearchConfig(BaseModel): password: Optional[str] = None secure: bool = False - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values.get('host'): + if not values.get("host"): raise ValueError("config OPENSEARCH_HOST is required") - if not values.get('port'): + if not values.get("port"): raise ValueError("config OPENSEARCH_PORT is required") return values @@ -44,19 +45,18 @@ def create_ssl_context(self) -> ssl.SSLContext: def to_opensearch_params(self) -> dict[str, Any]: params = { - 'hosts': [{'host': self.host, 'port': self.port}], - 'use_ssl': self.secure, - 'verify_certs': self.secure, + "hosts": [{"host": self.host, "port": self.port}], + "use_ssl": self.secure, + "verify_certs": self.secure, } if self.user and self.password: - params['http_auth'] = (self.user, self.password) + params["http_auth"] = (self.user, self.password) if self.secure: - params['ssl_context'] = self.create_ssl_context() + params["ssl_context"] = self.create_ssl_context() return params class OpenSearchVector(BaseVector): - def __init__(self, collection_name: str, config: OpenSearchConfig): super().__init__(collection_name) self._client_config = config @@ -81,7 +81,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.METADATA_KEY.value: documents[i].metadata, - } + }, } actions.append(action) @@ -90,8 +90,8 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** def get_ids_by_metadata_field(self, key: str, value: str): query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}} response = self._client.search(index=self._collection_name.lower(), body=query) - if response['hits']['hits']: - return [hit['_id'] for hit in response['hits']['hits']] + if response["hits"]["hits"]: + return [hit["_id"] for hit in response["hits"]["hits"]] else: return None @@ -110,7 +110,7 @@ def delete_by_ids(self, ids: list[str]) -> None: actual_ids = [] for doc_id in ids: - es_ids = self.get_ids_by_metadata_field('doc_id', doc_id) + es_ids = self.get_ids_by_metadata_field("doc_id", doc_id) if es_ids: actual_ids.extend(es_ids) else: @@ -122,14 +122,14 @@ def delete_by_ids(self, ids: list[str]) -> None: helpers.bulk(self._client, actions) except BulkIndexError as e: for error in e.errors: - delete_error = error.get('delete', {}) - status = delete_error.get('status') - doc_id = delete_error.get('_id') + delete_error = error.get("delete", {}) + status = delete_error.get("status") + doc_id = delete_error.get("_id") if status == 404: logger.warning(f"Document not found for deletion: {doc_id}") else: - logger.error(f"Error deleting document: {error}") + logger.exception(f"Error deleting document: {error}") def delete(self) -> None: self._client.indices.delete(index=self._collection_name.lower()) @@ -151,35 +151,28 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc raise ValueError("All elements in query_vector should be floats") query = { - "size": kwargs.get('top_k', 4), - "query": { - "knn": { - Field.VECTOR.value: { - 
Field.VECTOR.value: query_vector, - "k": kwargs.get('top_k', 4) - } - } - } + "size": kwargs.get("top_k", 4), + "query": {"knn": {Field.VECTOR.value: {Field.VECTOR.value: query_vector, "k": kwargs.get("top_k", 4)}}}, } try: response = self._client.search(index=self._collection_name.lower(), body=query) except Exception as e: - logger.error(f"Error executing search: {e}") + logger.exception(f"Error executing search: {e}") raise docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value, {}) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value, {}) # Make sure metadata is a dictionary if metadata is None: metadata = {} - metadata['score'] = hit['_score'] - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 - if hit['_score'] > score_threshold: - doc = Document(page_content=hit['_source'].get(Field.CONTENT_KEY.value), metadata=metadata) + metadata["score"] = hit["_score"] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + if hit["_score"] > score_threshold: + doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) docs.append(doc) return docs @@ -190,32 +183,28 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: response = self._client.search(index=self._collection_name.lower(), body=full_text_query) docs = [] - for hit in response['hits']['hits']: - metadata = hit['_source'].get(Field.METADATA_KEY.value) - vector = hit['_source'].get(Field.VECTOR.value) - page_content = hit['_source'].get(Field.CONTENT_KEY.value) + for hit in response["hits"]["hits"]: + metadata = hit["_source"].get(Field.METADATA_KEY.value) + vector = hit["_source"].get(Field.VECTOR.value) + page_content = hit["_source"].get(Field.CONTENT_KEY.value) doc = Document(page_content=page_content, vector=vector, metadata=metadata) docs.append(doc) return docs def create_collection( - self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None + self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None ): - lock_name = f'vector_indexing_lock_{self._collection_name.lower()}' + lock_name = f"vector_indexing_lock_{self._collection_name.lower()}" with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = f'vector_indexing_{self._collection_name.lower()}' + collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}" if redis_client.get(collection_exist_cache_key): logger.info(f"Collection {self._collection_name.lower()} already exists.") return if not self._client.indices.exists(index=self._collection_name.lower()): index_body = { - "settings": { - "index": { - "knn": True - } - }, + "settings": {"index": {"knn": True}}, "mappings": { "properties": { Field.CONTENT_KEY.value: {"type": "text"}, @@ -226,20 +215,17 @@ def create_collection( "name": "hnsw", "space_type": "l2", "engine": "faiss", - "parameters": { - "ef_construction": 64, - "m": 8 - } - } + "parameters": {"ef_construction": 64, "m": 8}, + }, }, Field.METADATA_KEY.value: { "type": "object", "properties": { "doc_id": {"type": "keyword"} # Map doc_id to keyword type - } - } + }, + }, } - } + }, } self._client.indices.create(index=self._collection_name.lower(), body=index_body) @@ -248,17 +234,14 @@ def create_collection( class OpenSearchVectorFactory(AbstractVectorFactory): - def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) 
-> OpenSearchVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) - + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name)) open_search_config = OpenSearchConfig( host=dify_config.OPENSEARCH_HOST, @@ -268,7 +251,4 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings secure=dify_config.OPENSEARCH_SECURE, ) - return OpenSearchVector( - collection_name=collection_name, - config=open_search_config - ) + return OpenSearchVector(collection_name=collection_name, config=open_search_config) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index aa2c6171c33367..4ced5d61e5748c 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -13,10 +13,10 @@ from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -31,7 +31,8 @@ class OracleVectorConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: if not values["host"]: raise ValueError("config ORACLE_HOST is required") @@ -103,9 +104,16 @@ def output_type_handler(self, cursor, metadata): arraysize=cursor.arraysize, outconverter=self.numpy_converter_out, ) - def _create_connection_pool(self, config: OracleVectorConfig): - return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1) + def _create_connection_pool(self, config: OracleVectorConfig): + return oracledb.create_pool( + user=config.user, + password=config.password, + dsn="{}:{}/{}".format(config.host, config.port, config.database), + min=1, + max=50, + increment=1, + ) @contextmanager def _get_cursor(self): @@ -136,13 +144,15 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** doc_id, doc.page_content, json.dumps(doc.metadata), - #array.array("f", embeddings[i]), + # array.array("f", embeddings[i]), numpy.array(embeddings[i]), ) ) - #print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") + # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_cursor() as cur: - cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values) + cur.executemany( + f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values + ) return pks def text_exists(self, id: str) -> bool: @@ -157,13 +167,6 @@ def get_by_ids(self, ids: list[str]) -> 
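
A usage sketch for the reworked OpenSearch path; illustrative only, with placeholder endpoint and credentials, and assuming the config fields shown above.

    from core.rag.datasource.vdb.opensearch.opensearch_vector import (
        OpenSearchConfig,
        OpenSearchVector,
    )

    config = OpenSearchConfig(host="localhost", port=9200, user="admin", password="admin", secure=False)
    store = OpenSearchVector("sample_collection", config)
    # score_threshold is coerced with float(kwargs.get("score_threshold") or 0.0),
    # so None or a missing value falls back to 0.0 instead of breaking the comparison.
    docs = store.search_by_vector([0.1] * 1536, top_k=4, score_threshold=0.3)
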
list[Document]: for record in cur: docs.append(Document(page_content=record[1], metadata=record[0])) return docs - #def get_ids_by_metadata_field(self, key: str, value: str): - # with self._get_cursor() as cur: - # cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" ) - # idss = [] - # for record in cur: - # idss.append(record[0]) - # return idss def delete_by_ids(self, ids: list[str]) -> None: with self._get_cursor() as cur: @@ -181,13 +184,15 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc :param top_k: The number of nearest neighbors to return, default is 5. :return: List of Documents that are nearest to the query vector. """ - top_k = kwargs.get("top_k", 5) + top_k = kwargs.get("top_k", 4) with self._get_cursor() as cur: cur.execute( - f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)] + f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" + f" ORDER BY distance fetch first {top_k} rows only", + [numpy.array(query_vector)], ) docs = [] - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) for record in cur: metadata, text, distance = record score = 1 - distance @@ -199,10 +204,10 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) if len(query) > 0: # Check which language the query is in - zh_pattern = re.compile('[\u4e00-\u9fa5]+') + zh_pattern = re.compile("[\u4e00-\u9fa5]+") match = zh_pattern.search(query) entities = [] # match: query condition maybe is a chinese sentence, so using Jieba split,else using nltk split @@ -210,7 +215,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos == 'nr' or pos == 'Ng' or pos == 'eng' or pos == 'nz' or pos == 'n' or pos == 'ORG' or pos == 'v': # nr: 人名, ns: 地名, nt: 机构名 + if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: @@ -220,22 +225,23 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: entities.append(current_entity) else: try: - nltk.data.find('tokenizers/punkt') - nltk.data.find('corpora/stopwords') + nltk.data.find("tokenizers/punkt") + nltk.data.find("corpora/stopwords") except LookupError: - nltk.download('punkt') - nltk.download('stopwords') + nltk.download("punkt") + nltk.download("stopwords") print("run download") - e_str = re.sub(r'[^\w ]', '', query) + e_str = re.sub(r"[^\w ]", "", query) all_tokens = nltk.word_tokenize(e_str) - stop_words = stopwords.words('english') + stop_words = stopwords.words("english") for token in all_tokens: if token not in stop_words: entities.append(token) with self._get_cursor() as cur: cur.execute( - f"select meta, text, embedding FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", - [" ACCUM ".join(entities)] + f"select meta, text, embedding FROM {self.table_name}" + f" WHERE 
CONTAINS(text, :1, 1) > 0 order by score(1) desc fetch first {top_k} rows only", + [" ACCUM ".join(entities)], ) docs = [] for record in cur: @@ -273,8 +279,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ORACLE, collection_name)) return OracleVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a48224070fdc21..7cbbdcc81f6039 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -12,11 +12,11 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -31,27 +31,29 @@ class PgvectoRSConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config PGVECTO_RS_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config PGVECTO_RS_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config PGVECTO_RS_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config PGVECTO_RS_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config PGVECTO_RS_DATABASE is required") return values class PGVectoRS(BaseVector): - def __init__(self, collection_name: str, config: PgvectoRSConfig, dim: int): super().__init__(collection_name) self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self._client = create_engine(self._url) with Session(self._client) as session: session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors")) @@ -80,9 +82,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -133,9 +135,7 @@ def add_texts(self, documents: list[Document], embeddings: 
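
To make the Oracle full-text branch above concrete: the query is tokenized (jieba part-of-speech tagging for Chinese, nltk tokens minus stopwords otherwise) and the surviving tokens are joined with Oracle Text's ACCUM operator before being bound into CONTAINS(). A simplified, English-only sketch:

    import re

    def build_oracle_text_query(query: str) -> str:
        # Same punctuation stripping as above; nltk stopword filtering omitted for brevity.
        tokens = re.sub(r"[^\w ]", "", query).split()
        return " ACCUM ".join(tokens)

    # Bound as :1 in: CONTAINS(text, :1, 1) > 0 ... ORDER BY SCORE(1) DESC
    print(build_oracle_text_query("vector databases, explained"))
    # -> "vector ACCUM databases ACCUM explained"
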
list[list[float]], ** def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self._client) as session: - select_statement = sql_text( - f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; " - ) + select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ") result = session.execute(select_statement).fetchall() if result: return [item[0] for item in result] @@ -143,12 +143,11 @@ def get_ids_by_metadata_field(self, key: str, value: str): return None def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete_by_ids(self, ids: list[str]) -> None: @@ -156,13 +155,13 @@ def delete_by_ids(self, ids: list[str]) -> None: select_statement = sql_text( f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = ANY (:doc_ids); " ) - result = session.execute(select_statement, {'doc_ids': ids}).fetchall() + result = session.execute(select_statement, {"doc_ids": ids}).fetchall() if result: ids = [item[0] for item in result] if ids: with Session(self._client) as session: select_statement = sql_text(f"DELETE FROM {self._collection_name} WHERE id = ANY(:ids)") - session.execute(select_statement, {'ids': ids}) + session.execute(select_statement, {"ids": ids}) session.commit() def delete(self) -> None: @@ -187,7 +186,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc query_vector, ).label("distance"), ) - .limit(kwargs.get('top_k', 2)) + .limit(kwargs.get("top_k", 4)) .order_by("distance") ) res = session.execute(stmt) @@ -198,40 +197,26 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc for record, dis in results: metadata = record.meta score = 1 - dis - metadata['score'] = score - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + metadata["score"] = score + score_threshold = float(kwargs.get("score_threshold") or 0.0) if score > score_threshold: - doc = Document(page_content=record.text, - metadata=metadata) + doc = Document(page_content=record.text, metadata=metadata) docs.append(doc) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # with Session(self._client) as session: - # select_statement = sql_text( - # f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery" - # ) - # results = session.execute(select_statement).fetchall() - # if results: - # docs = [] - # for result in results: - # doc = Document(page_content=result[0], - # metadata=result[1]) - # docs.append(doc) - # return docs return [] class PGVectoRSFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = 
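
The PGVecto-RS hit filtering above boils down to a small rule; a standalone sketch of the distance-to-score conversion and the hardened threshold handling:

    from typing import Optional

    def keep_hit(distance: float, score_threshold: Optional[float]) -> bool:
        score = 1 - distance  # same conversion as in search_by_vector above
        return score > float(score_threshold or 0.0)

    assert keep_hit(0.2, None) is True    # score 0.8 clears the default 0.0 threshold
    assert keep_hit(0.95, 0.1) is False   # score 0.05 is filtered out
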
json.dumps(self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name))

         dim = len(embeddings.embed_query("pgvecto_rs"))
         return PGVectoRS(
@@ -243,5 +228,5 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
                 password=dify_config.PGVECTO_RS_PASSWORD,
                 database=dify_config.PGVECTO_RS_DATABASE,
             ),
-            dim=dim
+            dim=dim,
         )
diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py
index 33ca5bc028b052..40a9cdd136b404 100644
--- a/api/core/rag/datasource/vdb/pgvector/pgvector.py
+++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -8,10 +8,10 @@
 from pydantic import BaseModel, model_validator

 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset
@@ -23,8 +23,11 @@ class PGVectorConfig(BaseModel):
     user: str
     password: str
     database: str
+    min_connection: int
+    max_connection: int

-    @model_validator(mode='before')
+    @model_validator(mode="before")
+    @classmethod
     def validate_config(cls, values: dict) -> dict:
         if not values["host"]:
             raise ValueError("config PGVECTOR_HOST is required")
@@ -36,6 +39,12 @@ def validate_config(cls, values: dict) -> dict:
             raise ValueError("config PGVECTOR_PASSWORD is required")
         if not values["database"]:
             raise ValueError("config PGVECTOR_DATABASE is required")
+        if not values["min_connection"]:
+            raise ValueError("config PGVECTOR_MIN_CONNECTION is required")
+        if not values["max_connection"]:
+            raise ValueError("config PGVECTOR_MAX_CONNECTION is required")
+        if values["min_connection"] > values["max_connection"]:
+            raise ValueError("config PGVECTOR_MIN_CONNECTION should not be greater than PGVECTOR_MAX_CONNECTION")
         return values
@@ -60,8 +69,8 @@ def get_type(self) -> str:

     def _create_connection_pool(self, config: PGVectorConfig):
         return psycopg2.pool.SimpleConnectionPool(
-            1,
-            5,
+            config.min_connection,
+            config.max_connection,
             host=config.host,
             port=config.port,
             user=config.user,
@@ -134,15 +143,16 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
         :param top_k: The number of nearest neighbors to return, default is 5.
         :return: List of Documents that are nearest to the query vector. 
""" - top_k = kwargs.get("top_k", 5) + top_k = kwargs.get("top_k", 4) with self._get_cursor() as cur: cur.execute( - f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name} ORDER BY distance LIMIT {top_k}", + f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}" + f" ORDER BY distance LIMIT {top_k}", (json.dumps(query_vector),), ) docs = [] - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) for record in cur: metadata, text, distance = record score = 1 - distance @@ -152,8 +162,27 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # do not support bm25 search - return [] + top_k = kwargs.get("top_k", 5) + + with self._get_cursor() as cur: + cur.execute( + f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score + FROM {self.table_name} + WHERE to_tsvector(text) @@ plainto_tsquery(%s) + ORDER BY score DESC + LIMIT {top_k}""", + # f"'{query}'" is required in order to account for whitespace in query + (f"'{query}'", f"'{query}'"), + ) + + docs = [] + + for record in cur: + metadata, text, score = record + metadata["score"] = score + docs.append(Document(page_content=text, metadata=metadata)) + + return docs def delete(self) -> None: with self._get_cursor() as cur: @@ -182,8 +211,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name)) return PGVector( collection_name=collection_name, @@ -193,5 +221,7 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings user=dify_config.PGVECTOR_USER, password=dify_config.PGVECTOR_PASSWORD, database=dify_config.PGVECTOR_DATABASE, + min_connection=dify_config.PGVECTOR_MIN_CONNECTION, + max_connection=dify_config.PGVECTOR_MAX_CONNECTION, ), ) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 297bff928e8ae8..3811458e02957c 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -20,11 +20,11 @@ from qdrant_client.local.qdrant_local import QdrantLocal from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -48,28 +48,25 @@ class QdrantConfig(BaseModel): prefer_grpc: bool = False def to_qdrant_params(self): - if self.endpoint and self.endpoint.startswith('path:'): - path = self.endpoint.replace('path:', '') + if self.endpoint and self.endpoint.startswith("path:"): + path = self.endpoint.replace("path:", "") if not os.path.isabs(path): path = os.path.join(self.root_path, path) - return { - 'path': path - } + return 
{"path": path} else: return { - 'url': self.endpoint, - 'api_key': self.api_key, - 'timeout': self.timeout, - 'verify': self.endpoint.startswith('https'), - 'grpc_port': self.grpc_port, - 'prefer_grpc': self.prefer_grpc + "url": self.endpoint, + "api_key": self.api_key, + "timeout": self.timeout, + "verify": self.endpoint.startswith("https"), + "grpc_port": self.grpc_port, + "prefer_grpc": self.prefer_grpc, } class QdrantVector(BaseVector): - - def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'): + def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"): super().__init__(collection_name) self._client_config = config self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) @@ -80,10 +77,7 @@ def get_type(self) -> str: return VectorType.QDRANT def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): if texts: @@ -97,9 +91,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings, **kwargs) def create_collection(self, collection_name: str, vector_size: int): - lock_name = 'vector_indexing_lock_{}'.format(collection_name) + lock_name = "vector_indexing_lock_{}".format(collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return collection_name = collection_name or uuid.uuid4().hex @@ -110,12 +104,19 @@ def create_collection(self, collection_name: str, vector_size: int): all_collection_name.append(collection.name) if collection_name not in all_collection_name: from qdrant_client.http import models as rest + vectors_config = rest.VectorParams( size=vector_size, distance=rest.Distance[self._distance_func], ) - hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, - max_indexing_threads=0, on_disk=False) + hnsw_config = HnswConfigDiff( + m=0, + payload_m=16, + ef_construct=100, + full_scan_threshold=10000, + max_indexing_threads=0, + on_disk=False, + ) self._client.recreate_collection( collection_name=collection_name, vectors_config=vectors_config, @@ -124,21 +125,24 @@ def create_collection(self, collection_name: str, vector_size: int): ) # create group_id payload index - self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + ) # create doc_id payload index - self._client.create_payload_index(collection_name, Field.DOC_ID.value, - field_schema=PayloadSchemaType.KEYWORD) + self._client.create_payload_index( + collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD + ) # create full text index text_index_params = TextIndexParams( type=TextIndexType.TEXT, tokenizer=TokenizerType.MULTILINGUAL, min_token_len=2, max_token_len=20, - lowercase=True + lowercase=True, + ) + self._client.create_payload_index( + collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params ) - 
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, - field_schema=text_index_params) redis_client.set(collection_exist_cache_key, 1, ex=3600) def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): @@ -147,26 +151,23 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** metadatas = [d.metadata for d in documents] added_ids = [] - for batch_ids, points in self._generate_rest_batches( - texts, embeddings, metadatas, uuids, 64, self._group_id - ): - self._client.upsert( - collection_name=self._collection_name, points=points - ) + for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + self._client.upsert(collection_name=self._collection_name, points=points) added_ids.extend(batch_ids) return added_ids def _generate_rest_batches( - self, - texts: Iterable[str], - embeddings: list[list[float]], - metadatas: Optional[list[dict]] = None, - ids: Optional[Sequence[str]] = None, - batch_size: int = 64, - group_id: Optional[str] = None, + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: from qdrant_client.http import models as rest + texts_iterator = iter(texts) embeddings_iterator = iter(embeddings) metadatas_iterator = iter(metadatas or []) @@ -203,13 +204,13 @@ def _generate_rest_batches( @classmethod def _build_payloads( - cls, - texts: Iterable[str], - metadatas: Optional[list[dict]], - content_payload_key: str, - metadata_payload_key: str, - group_id: str, - group_payload_key: str + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str, ) -> list[dict]: payloads = [] for i, text in enumerate(texts): @@ -219,18 +220,11 @@ def _build_payloads( "calling .from_texts or .add_texts on Qdrant instance." 
) metadata = metadatas[i] if metadatas is not None else None - payloads.append( - { - content_payload_key: text, - metadata_payload_key: metadata, - group_payload_key: group_id - } - ) + payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) return payloads def delete_by_metadata_field(self, key: str, value: str): - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -248,9 +242,7 @@ def delete_by_metadata_field(self, key: str, value: str): self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -275,9 +267,7 @@ def delete(self): ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -288,7 +278,6 @@ def delete(self): raise e def delete_by_ids(self, ids: list[str]) -> None: - from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse @@ -304,9 +293,7 @@ def delete_by_ids(self, ids: list[str]) -> None: ) self._client.delete( collection_name=self._collection_name, - points_selector=FilterSelector( - filter=filter - ), + points_selector=FilterSelector(filter=filter), ) except UnexpectedResponse as e: # Collection does not exist, so return @@ -324,15 +311,13 @@ def text_exists(self, id: str) -> bool: all_collection_name.append(collection.name) if self._collection_name not in all_collection_name: return False - response = self._client.retrieve( - collection_name=self._collection_name, - ids=[id] - ) + response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) return len(response) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from qdrant_client.http import models + filter = models.Filter( must=[ models.FieldCondition( @@ -348,22 +333,22 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc limit=kwargs.get("top_k", 4), with_payload=True, with_vectors=True, - score_threshold=kwargs.get("score_threshold", .0) + score_threshold=float(kwargs.get("score_threshold") or 0.0), ) docs = [] for result in results: metadata = result.payload.get(Field.METADATA_KEY.value) or {} # duplicate check score threshold - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) if result.score > score_threshold: - metadata['score'] = result.score + metadata["score"] = result.score doc = Document( page_content=result.payload.get(Field.CONTENT_KEY.value), metadata=metadata, ) docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -372,6 +357,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: List of documents most similar to the query text and distance for each. 
""" from qdrant_client.http import models + scroll_filter = models.Filter( must=[ models.FieldCondition( @@ -381,24 +367,21 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: models.FieldCondition( key="page_content", match=models.MatchText(text=query), - ) + ), ] ) response = self._client.scroll( collection_name=self._collection_name, scroll_filter=scroll_filter, - limit=kwargs.get('top_k', 2), + limit=kwargs.get("top_k", 2), with_payload=True, - with_vectors=True - + with_vectors=True, ) results = response[0] documents = [] for result in results: if result: - document = self._document_from_scored_point( - result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value - ) + document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) documents.append(document) return documents @@ -410,10 +393,10 @@ def _reload_if_needed(self): @classmethod def _document_from_scored_point( - cls, - scored_point: Any, - content_payload_key: str, - metadata_payload_key: str, + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), @@ -425,35 +408,35 @@ def _document_from_scored_point( class QdrantVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: if dataset.collection_binding_id: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ - one_or_none() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) + .one_or_none() + ) if dataset_collection_binding: collection_name = dataset_collection_binding.collection_name else: - raise ValueError('Dataset Collection Bindings is not exist!') + raise ValueError("Dataset Collection Bindings is not exist!") else: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) if not dataset.index_struct_dict: - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.QDRANT, collection_name)) - config = current_app.config return QdrantVector( collection_name=collection_name, group_id=dataset.id, config=QdrantConfig( endpoint=dify_config.QDRANT_URL, api_key=dify_config.QDRANT_API_KEY, - root_path=config.root_path, + root_path=current_app.config.root_path, timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED - ) + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, + ), ) diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 63ad0682d729fe..f373dcfeabef92 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -8,9 +8,9 @@ from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from 
core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from models.dataset import Dataset try: @@ -33,28 +33,30 @@ class RelytConfig(BaseModel): password: str database: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config RELYT_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config RELYT_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config RELYT_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config RELYT_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config RELYT_DATABASE is required") return values class RelytVector(BaseVector): - def __init__(self, collection_name: str, config: RelytConfig, group_id: str): super().__init__(collection_name) self.embedding_dimension = 1536 self._client_config = config - self._url = f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + self._url = ( + f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}" + ) self.client = create_engine(self._url) self._fields = [] self._group_id = group_id @@ -70,9 +72,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings) def create_collection(self, dimension: int): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return index_name = f"{self._collection_name}_embedding_index" @@ -110,7 +112,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** ids = [str(uuid.uuid1()) for _ in documents] metadatas = [d.metadata for d in documents] for metadata in metadatas: - metadata['group_id'] = self._group_id + metadata["group_id"] = self._group_id texts = [d.page_content for d in documents] # Define the table schema @@ -125,29 +127,26 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** ) chunks_table_data = [] - with self.client.connect() as conn: - with conn.begin(): - for document, metadata, chunk_id, embedding in zip( - texts, metadatas, ids, embeddings - ): - chunks_table_data.append( - { - "id": chunk_id, - "embedding": embedding, - "document": document, - "metadata": metadata, - } - ) - - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == 500: - conn.execute(insert(chunks_table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: + with self.client.connect() as conn, conn.begin(): + for document, metadata, chunk_id, embedding in zip(texts, metadatas, ids, embeddings): + chunks_table_data.append( + { + "id": chunk_id, + "embedding": embedding, + "document": document, + "metadata": metadata, + } + ) + + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 
500: conn.execute(insert(chunks_table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(chunks_table).values(chunks_table_data)) return ids @@ -163,7 +162,7 @@ def get_ids_by_metadata_field(self, key: str, value: str): else: return None - def delete_by_uuids(self, ids: list[str] = None): + def delete_by_uuids(self, ids: Optional[list[str]] = None): """Delete by vector IDs. Args: @@ -186,25 +185,22 @@ def delete_by_uuids(self, ids: list[str] = None): ) try: - with self.client.connect() as conn: - with conn.begin(): - delete_condition = chunks_table.c.id.in_(ids) - conn.execute(chunks_table.delete().where(delete_condition)) - return True + with self.client.connect() as conn, conn.begin(): + delete_condition = chunks_table.c.id.in_(ids) + conn.execute(chunks_table.delete().where(delete_condition)) + return True except Exception as e: print("Delete operation failed:", str(e)) return False def delete_by_metadata_field(self, key: str, value: str): - ids = self.get_ids_by_metadata_field(key, value) if ids: self.delete_by_uuids(ids) def delete_by_ids(self, ids: list[str]) -> None: - with Session(self.client) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """ ) @@ -228,38 +224,34 @@ def text_exists(self, id: str) -> bool: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: results = self.similarity_search_with_score_by_vector( - k=int(kwargs.get('top_k')), - embedding=query_vector, - filter=kwargs.get('filter') + k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter") ) # Organize results. docs = [] for document, score in results: - score_threshold = kwargs.get('score_threshold') if kwargs.get('score_threshold') else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) if 1 - score > score_threshold: docs.append(document) return docs def similarity_search_with_score_by_vector( - self, - embedding: list[float], - k: int = 4, - filter: Optional[dict] = None, + self, + embedding: list[float], + k: int = 4, + filter: Optional[dict] = None, ) -> list[tuple[Document, float]]: # Add the filter if provided try: from sqlalchemy.engine import Row except ImportError: - raise ImportError( - "Could not import Row from sqlalchemy.engine. " - "Please 'pip install sqlalchemy>=1.4'." - ) + raise ImportError("Could not import Row from sqlalchemy.engine. 
Please 'pip install sqlalchemy>=1.4'.")

         filter_condition = ""
         if filter is not None:
             conditions = [
-                f"metadata->>{key!r} in ({', '.join(map(repr, value))})" if len(value) > 1
+                f"metadata->>{key!r} in ({', '.join(map(repr, value))})"
+                if len(value) > 1
                 else f"metadata->>{key!r} = {value[0]!r}"
                 for key, value in filter.items()
             ]
@@ -305,13 +297,12 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
 class RelytVectorFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:
         if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
             collection_name = class_prefix
         else:
             dataset_id = dataset.id
             collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-        dataset.index_struct = json.dumps(
-            self.gen_index_struct_dict(VectorType.RELYT, collection_name))
+        dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.RELYT, collection_name))

         return RelytVector(
             collection_name=collection_name,
@@ -322,5 +313,5 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
             password=dify_config.RELYT_PASSWORD,
             database=dify_config.RELYT_DATABASE,
         ),
-        group_id=dataset.id
+        group_id=dataset.id,
     )
diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
index 3325a1028ece52..f971a9c5eb1696 100644
--- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py
+++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
@@ -8,10 +8,10 @@
 from tcvectordb.model.document import Filter

 from configs import dify_config
-from core.rag.datasource.entity.embedding import Embeddings
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset
@@ -25,16 +25,11 @@ class TencentConfig(BaseModel):
     database: Optional[str]
     index_type: str = "HNSW"
     metric_type: str = "L2"
-    shard: int = 1,
-    replicas: int = 2,
+    shard: int = 1
+    replicas: int = 2

     def to_tencent_params(self):
-        return {
-            'url': self.url,
-            'username': self.username,
-            'key': self.api_key,
-            'timeout': self.timeout
-        }
+        return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}


 class TencentVector(BaseVector):
@@ -61,25 +56,19 @@ def _init_database(self):
         return self._client.create_database(database_name=self._client_config.database)

     def get_type(self) -> str:
-        return 'tencent'
+        return VectorType.TENCENT

     def to_index_struct(self) -> dict:
-        return {
-            "type": self.get_type(),
-            "vector_store": {"class_prefix": self._collection_name}
-        }
+        return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

     def _has_collection(self) -> bool:
         collections = self._db.list_collections()
-        for collection in collections:
-            if collection.collection_name == self._collection_name:
-                return True
-        return False
+        return any(collection.collection_name == self._collection_name for collection in collections)

     def _create_collection(self, dimension: int) -> None:
-        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
+        lock_name = 
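
The corrected shard/replicas defaults above deserve a note: with the original trailing commas the assigned default was a tuple, and under Pydantic v2 defaults are not validated unless validate_default is enabled, so the bad value survives until something consumes it as an int. A small demonstration:

    from pydantic import BaseModel

    class Demo(BaseModel):
        shard: int = (1,)  # what "shard: int = 1," actually assigns

    print(Demo().shard)         # (1,) -- the tuple leaks through unvalidated
    print(Demo(shard=3).shard)  # 3   -- explicit values are validated as ints
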
"vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return @@ -101,9 +90,7 @@ def _create_collection(self, dimension: int) -> None: raise ValueError("unsupported metric_type") params = vdb_index.HNSWParams(m=16, efconstruction=200) index = vdb_index.Index( - vdb_index.FilterIndex( - self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY - ), + vdb_index.FilterIndex(self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY), vdb_index.VectorIndex( self.field_vector, dimension, @@ -111,12 +98,8 @@ def _create_collection(self, dimension: int) -> None: metric_type, params, ), - vdb_index.FilterIndex( - self.field_text, enum.FieldType.String, enum.IndexType.FILTER - ), - vdb_index.FilterIndex( - self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER - ), + vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER), + vdb_index.FilterIndex(self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER), ) self._db.create_collection( @@ -163,15 +146,14 @@ def delete_by_metadata_field(self, key: str, value: str) -> None: self._db.collection(self._collection_name).delete(filter=Filter(Filter.In(key, [value]))) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - - res = self._db.collection(self._collection_name).search(vectors=[query_vector], - params=document.HNSWSearchParams( - ef=kwargs.get("ef", 10)), - retrieve_vector=False, - limit=kwargs.get('top_k', 4), - timeout=self._client_config.timeout, - ) - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + res = self._db.collection(self._collection_name).search( + vectors=[query_vector], + params=document.HNSWSearchParams(ef=kwargs.get("ef", 10)), + retrieve_vector=False, + limit=kwargs.get("top_k", 4), + timeout=self._client_config.timeout, + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) return self._get_search_res(res, score_threshold) def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -200,15 +182,13 @@ def delete(self) -> None: class TencentVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TencentVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TENCENT, collection_name)) return TencentVector( collection_name=collection_name, @@ -220,5 +200,5 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings database=dify_config.TENCENT_VECTOR_DB_DATABASE, shard=dify_config.TENCENT_VECTOR_DB_SHARD, replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS, - ) + ), ) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git 
a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py new file mode 100644 index 00000000000000..1e62b3c58905c5 --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py @@ -0,0 +1,17 @@ +from typing import Optional + +from pydantic import BaseModel + + +class ClusterEntity(BaseModel): + """ + Model Config Entity. + """ + + name: str + cluster_id: str + displayName: str + region: str + spendingLimit: Optional[int] = 1000 + version: str + createdBy: str diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py new file mode 100644 index 00000000000000..a38f84636e9135 --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -0,0 +1,526 @@ +import json +import os +import uuid +from collections.abc import Generator, Iterable, Sequence +from itertools import islice +from typing import TYPE_CHECKING, Any, Optional, Union, cast + +import qdrant_client +import requests +from flask import current_app +from pydantic import BaseModel +from qdrant_client.http import models as rest +from qdrant_client.http.models import ( + FilterSelector, + HnswConfigDiff, + PayloadSchemaType, + TextIndexParams, + TextIndexType, + TokenizerType, +) +from qdrant_client.local.qdrant_local import QdrantLocal +from requests.auth import HTTPDigestAuth + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, TidbAuthBinding + +if TYPE_CHECKING: + from qdrant_client import grpc # noqa + from qdrant_client.conversions import common_types + from qdrant_client.http import models as rest + + DictFilter = dict[str, Union[str, int, bool, dict, list]] + MetadataFilter = Union[DictFilter, common_types.Filter] + + +class TidbOnQdrantConfig(BaseModel): + endpoint: str + api_key: Optional[str] = None + timeout: float = 20 + root_path: Optional[str] = None + grpc_port: int = 6334 + prefer_grpc: bool = False + + def to_qdrant_params(self): + if self.endpoint and self.endpoint.startswith("path:"): + path = self.endpoint.replace("path:", "") + if not os.path.isabs(path): + path = os.path.join(self.root_path, path) + + return {"path": path} + else: + return { + "url": self.endpoint, + "api_key": self.api_key, + "timeout": self.timeout, + "verify": False, + "grpc_port": self.grpc_port, + "prefer_grpc": self.prefer_grpc, + } + + +class TidbConfig(BaseModel): + api_url: str + public_key: str + private_key: str + + +class TidbOnQdrantVector(BaseVector): + def __init__(self, collection_name: str, group_id: str, config: TidbOnQdrantConfig, distance_func: str = "Cosine"): + super().__init__(collection_name) + self._client_config = config + self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params()) + self._distance_func = distance_func.upper() + self._group_id = group_id + + def get_type(self) -> str: + return VectorType.TIDB_ON_QDRANT + + def to_index_struct(self) -> dict: + return {"type": 
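
A quick illustration of the endpoint handling in TidbOnQdrantConfig above: a "path:" endpoint resolves to an embedded on-disk Qdrant store (relative paths joined to root_path), while anything else becomes a remote client with TLS verification hard-coded off, unlike the stock QdrantConfig.

    from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantConfig

    local = TidbOnQdrantConfig(endpoint="path:storage/qdrant", root_path="/app/api")
    assert local.to_qdrant_params() == {"path": "/app/api/storage/qdrant"}

    remote = TidbOnQdrantConfig(endpoint="https://qdrant.example.com:6333", api_key="<key>")
    assert remote.to_qdrant_params()["verify"] is False  # hard-coded in this config
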
self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + if texts: + # get embedding vector size + vector_size = len(embeddings[0]) + # get collection name + collection_name = self._collection_name + # create collection + self.create_collection(collection_name, vector_size) + + self.add_texts(texts, embeddings, **kwargs) + + def create_collection(self, collection_name: str, vector_size: int): + lock_name = "vector_indexing_lock_{}".format(collection_name) + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + if redis_client.get(collection_exist_cache_key): + return + collection_name = collection_name or uuid.uuid4().hex + all_collection_name = [] + collections_response = self._client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if collection_name not in all_collection_name: + from qdrant_client.http import models as rest + + vectors_config = rest.VectorParams( + size=vector_size, + distance=rest.Distance[self._distance_func], + ) + hnsw_config = HnswConfigDiff( + m=0, + payload_m=16, + ef_construct=100, + full_scan_threshold=10000, + max_indexing_threads=0, + on_disk=False, + ) + self._client.recreate_collection( + collection_name=collection_name, + vectors_config=vectors_config, + hnsw_config=hnsw_config, + timeout=int(self._client_config.timeout), + ) + + # create group_id payload index + self._client.create_payload_index( + collection_name, Field.GROUP_KEY.value, field_schema=PayloadSchemaType.KEYWORD + ) + # create doc_id payload index + self._client.create_payload_index( + collection_name, Field.DOC_ID.value, field_schema=PayloadSchemaType.KEYWORD + ) + # create full text index + text_index_params = TextIndexParams( + type=TextIndexType.TEXT, + tokenizer=TokenizerType.MULTILINGUAL, + min_token_len=2, + max_token_len=20, + lowercase=True, + ) + self._client.create_payload_index( + collection_name, Field.CONTENT_KEY.value, field_schema=text_index_params + ) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + uuids = self._get_uuids(documents) + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + + added_ids = [] + for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id): + self._client.upsert(collection_name=self._collection_name, points=points) + added_ids.extend(batch_ids) + + return added_ids + + def _generate_rest_batches( + self, + texts: Iterable[str], + embeddings: list[list[float]], + metadatas: Optional[list[dict]] = None, + ids: Optional[Sequence[str]] = None, + batch_size: int = 64, + group_id: Optional[str] = None, + ) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]: + from qdrant_client.http import models as rest + + texts_iterator = iter(texts) + embeddings_iterator = iter(embeddings) + metadatas_iterator = iter(metadatas or []) + ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)]) + while batch_texts := list(islice(texts_iterator, batch_size)): + # Take the corresponding metadata and id for each text in a batch + batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None + batch_ids = list(islice(ids_iterator, batch_size)) + + # Generate the 
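
The batching generator here leans on a compact itertools pattern; a standalone sketch:

    from itertools import islice

    def batched(items, batch_size=64):
        # Same idea as _generate_rest_batches: drain an iterator batch_size items
        # at a time until islice comes back empty.
        it = iter(items)
        while batch := list(islice(it, batch_size)):
            yield batch

    print(list(batched(range(5), 2)))  # [[0, 1], [2, 3], [4]]
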
embeddings for all the texts in a batch + batch_embeddings = list(islice(embeddings_iterator, batch_size)) + + points = [ + rest.PointStruct( + id=point_id, + vector=vector, + payload=payload, + ) + for point_id, vector, payload in zip( + batch_ids, + batch_embeddings, + self._build_payloads( + batch_texts, + batch_metadatas, + Field.CONTENT_KEY.value, + Field.METADATA_KEY.value, + group_id, + Field.GROUP_KEY.value, + ), + ) + ] + + yield batch_ids, points + + @classmethod + def _build_payloads( + cls, + texts: Iterable[str], + metadatas: Optional[list[dict]], + content_payload_key: str, + metadata_payload_key: str, + group_id: str, + group_payload_key: str, + ) -> list[dict]: + payloads = [] + for i, text in enumerate(texts): + if text is None: + raise ValueError( + "At least one of the texts is None. Please remove it before " + "calling .from_texts or .add_texts on Qdrant instance." + ) + metadata = metadatas[i] if metadatas is not None else None + payloads.append({content_payload_key: text, metadata_payload_key: metadata, group_payload_key: group_id}) + + return payloads + + def delete_by_metadata_field(self, key: str, value: str): + from qdrant_client.http import models + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key=f"metadata.{key}", + match=models.MatchValue(value=value), + ), + ], + ) + + self._reload_if_needed() + + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def delete(self): + from qdrant_client.http.exceptions import UnexpectedResponse + + try: + self._client.delete_collection(collection_name=self._collection_name) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def delete_by_ids(self, ids: list[str]) -> None: + from qdrant_client.http import models + from qdrant_client.http.exceptions import UnexpectedResponse + + for node_id in ids: + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchValue(value=node_id), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e + + def text_exists(self, id: str) -> bool: + all_collection_name = [] + collections_response = self._client.get_collections() + collection_list = collections_response.collections + for collection in collection_list: + all_collection_name.append(collection.name) + if self._collection_name not in all_collection_name: + return False + response = self._client.retrieve(collection_name=self._collection_name, ids=[id]) + + return len(response) > 0 + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + from qdrant_client.http import models + + filter = models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self._group_id), + ), + ], + ) + results = self._client.search( + collection_name=self._collection_name, + query_vector=query_vector, + 
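The delete helpers that follow all share one idiom: a 404 from Qdrant means the collection is already gone, so it is swallowed rather than propagated. In isolation, assuming a client built as above:

from qdrant_client.http.exceptions import UnexpectedResponse

try:
    client.delete_collection(collection_name="already_gone")
except UnexpectedResponse as e:
    if e.status_code != 404:  # 404: nothing to delete, treat as success
        raise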
query_filter=filter, + limit=kwargs.get("top_k", 4), + with_payload=True, + with_vectors=True, + score_threshold=kwargs.get("score_threshold", 0.0), + ) + docs = [] + for result in results: + metadata = result.payload.get(Field.METADATA_KEY.value) or {} + # duplicate check score threshold + score_threshold = kwargs.get("score_threshold") or 0.0 + if result.score > score_threshold: + metadata["score"] = result.score + doc = Document( + page_content=result.payload.get(Field.CONTENT_KEY.value), + metadata=metadata, + ) + docs.append(doc) + # Sort the documents by score in descending order + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Return docs most similar by bm25. + Returns: + List of documents most similar to the query text and distance for each. + """ + from qdrant_client.http import models + + scroll_filter = models.Filter( + must=[ + models.FieldCondition( + key="page_content", + match=models.MatchText(text=query), + ) + ] + ) + response = self._client.scroll( + collection_name=self._collection_name, + scroll_filter=scroll_filter, + limit=kwargs.get("top_k", 2), + with_payload=True, + with_vectors=True, + ) + results = response[0] + documents = [] + for result in results: + if result: + document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) + document.metadata["vector"] = result.vector + documents.append(document) + + return documents + + def _reload_if_needed(self): + if isinstance(self._client, QdrantLocal): + self._client = cast(QdrantLocal, self._client) + self._client._load() + + @classmethod + def _document_from_scored_point( + cls, + scored_point: Any, + content_payload_key: str, + metadata_payload_key: str, + ) -> Document: + return Document( + page_content=scored_point.payload.get(content_payload_key), + metadata=scored_point.payload.get(metadata_payload_key) or {}, + ) + + +class TidbOnQdrantVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: + tidb_auth_binding = ( + db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() + ) + if not tidb_auth_binding: + idle_tidb_auth_binding = ( + db.session.query(TidbAuthBinding) + .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") + .limit(1) + .one_or_none() + ) + if idle_tidb_auth_binding: + idle_tidb_auth_binding.active = True + idle_tidb_auth_binding.tenant_id = dataset.tenant_id + db.session.commit() + TIDB_ON_QDRANT_API_KEY = f"{idle_tidb_auth_binding.account}:{idle_tidb_auth_binding.password}" + else: + with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): + tidb_auth_binding = ( + db.session.query(TidbAuthBinding) + .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) + .one_or_none() + ) + if tidb_auth_binding: + TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + + else: + new_cluster = TidbService.create_tidb_serverless_cluster( + dify_config.TIDB_PROJECT_ID, + dify_config.TIDB_API_URL, + dify_config.TIDB_IAM_API_URL, + dify_config.TIDB_PUBLIC_KEY, + dify_config.TIDB_PRIVATE_KEY, + dify_config.TIDB_REGION, + ) + new_tidb_auth_binding = TidbAuthBinding( + cluster_id=new_cluster["cluster_id"], + cluster_name=new_cluster["cluster_name"], + account=new_cluster["account"], + password=new_cluster["password"], + tenant_id=dataset.tenant_id, + 
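Note the shape of the factory's fallback branch: the tenant's binding is looked up once without the lock, then again after acquiring create_tidb_serverless_cluster_lock, so a concurrent worker that won the race is detected instead of a second cluster being provisioned. The skeleton of that double-checked pattern, with names as in the code above:

with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
    # Re-check inside the lock: another worker may have bound a cluster
    # to this tenant while this request waited for the lock.
    binding = (
        db.session.query(TidbAuthBinding)
        .filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
        .one_or_none()
    )
    if binding is None:
        ...  # only now is it safe to provision a fresh cluster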
active=True, + status="ACTIVE", + ) + db.session.add(new_tidb_auth_binding) + db.session.commit() + TIDB_ON_QDRANT_API_KEY = f"{new_tidb_auth_binding.account}:{new_tidb_auth_binding.password}" + + else: + TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" + + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_ON_QDRANT, collection_name)) + + config = current_app.config + + return TidbOnQdrantVector( + collection_name=collection_name, + group_id=dataset.id, + config=TidbOnQdrantConfig( + endpoint=dify_config.TIDB_ON_QDRANT_URL, + api_key=TIDB_ON_QDRANT_API_KEY, + root_path=config.root_path, + timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, + grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, + prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, + ), + ) + + def create_tidb_serverless_cluster(self, tidb_config: TidbConfig, display_name: str, region: str): + """ + Creates a new TiDB Serverless cluster. + :param tidb_config: The configuration for the TiDB Cloud API. + :param display_name: The user-friendly display name of the cluster (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": "1372813089454548012", + } + cluster_data = {"displayName": display_name, "region": region_object, "labels": labels} + + response = requests.post( + f"{tidb_config.api_url}/clusters", + json=cluster_data, + auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + def change_tidb_serverless_root_password(self, tidb_config: TidbConfig, cluster_id: str, new_password: str): + """ + Changes the root password of a specific TiDB Serverless cluster. + + :param tidb_config: The configuration for the TiDB Cloud API. + :param cluster_id: The ID of the cluster for which the password is to be changed (required). + :param new_password: The new password for the root user (required). + :return: The response from the API. + """ + + body = {"password": new_password} + + response = requests.put( + f"{tidb_config.api_url}/clusters/{cluster_id}/password", + json=body, + auth=HTTPDigestAuth(tidb_config.public_key, tidb_config.private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py new file mode 100644 index 00000000000000..a6f3ad7fef0c45 --- /dev/null +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -0,0 +1,251 @@ +import time +import uuid + +import requests +from requests.auth import HTTPDigestAuth + +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import TidbAuthBinding + + +class TidbService: + @staticmethod + def create_tidb_serverless_cluster( + project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str + ): + """ + Creates a new TiDB Serverless cluster. 
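All TiDB Cloud calls in this file authenticate with HTTP digest auth rather than a bearer token. A minimal request sketch; the endpoint and keys are placeholders:

import requests
from requests.auth import HTTPDigestAuth

resp = requests.get(
    "https://api.tidbcloud.com/api/v1beta/clusters",  # placeholder endpoint
    auth=HTTPDigestAuth("public_key", "private_key"),
)
resp.raise_for_status()
print(resp.json())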
+ :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param region: The region where the cluster will be created (required). + + :return: The response from the API. + """ + + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": project_id, + } + + spending_limit = { + "monthly": dify_config.TIDB_SPEND_LIMIT, + } + password = str(uuid.uuid4()).replace("-", "")[:16] + display_name = str(uuid.uuid4()).replace("-", "")[:16] + cluster_data = { + "displayName": display_name, + "region": region_object, + "labels": labels, + "spendingLimit": spending_limit, + "rootPassword": password, + } + + response = requests.post(f"{api_url}/clusters", json=cluster_data, auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + response_data = response.json() + cluster_id = response_data["clusterId"] + retry_count = 0 + max_retries = 30 + while retry_count < max_retries: + cluster_response = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id) + if cluster_response["state"] == "ACTIVE": + user_prefix = cluster_response["userPrefix"] + return { + "cluster_id": cluster_id, + "cluster_name": display_name, + "account": f"{user_prefix}.root", + "password": password, + } + time.sleep(30) # wait 30 seconds before retrying + retry_count += 1 + else: + response.raise_for_status() + + @staticmethod + def delete_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): + """ + Deletes a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster to be deleted (required). + :return: The response from the API. + """ + + response = requests.delete(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def get_tidb_serverless_cluster(api_url: str, public_key: str, private_key: str, cluster_id: str): + """ + Retrieves a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param cluster_id: The ID of the cluster to be retrieved (required). + :return: The response from the API. + """ + + response = requests.get(f"{api_url}/clusters/{cluster_id}", auth=HTTPDigestAuth(public_key, private_key)) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def change_tidb_serverless_root_password( + api_url: str, public_key: str, private_key: str, cluster_id: str, account: str, new_password: str + ): + """ + Changes the root password of a specific TiDB Serverless cluster. + + :param api_url: The URL of the TiDB Cloud API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). 
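The create call above returns before the cluster is usable, so the code polls get_tidb_serverless_cluster until state is ACTIVE, at most 30 times with a 30-second sleep (up to roughly 15 minutes, matching the 900-second lock used by the caller). The same loop as a standalone helper; raising TimeoutError on exhaustion is an assumption on my part, the code above simply falls through and returns None:

import time

def wait_until_active(api_url, public_key, private_key, cluster_id, max_retries=30, interval=30):
    for _ in range(max_retries):
        cluster = TidbService.get_tidb_serverless_cluster(api_url, public_key, private_key, cluster_id)
        if cluster["state"] == "ACTIVE":
            return cluster
        time.sleep(interval)  # provisioning is slow; back off between polls
    raise TimeoutError(f"cluster {cluster_id} did not become ACTIVE")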
+ :param cluster_id: The ID of the cluster for which the password is to be changed (required). + :param account: The account for which the password is to be changed (required). + :param new_password: The new password for the root user (required). + :return: The response from the API. + """ + + body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []} + + response = requests.patch( + f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}", + json=body, + auth=HTTPDigestAuth(public_key, private_key), + ) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + @staticmethod + def batch_update_tidb_serverless_cluster_status( + tidb_serverless_list: list[TidbAuthBinding], + project_id: str, + api_url: str, + iam_url: str, + public_key: str, + private_key: str, + ) -> None: + """ + Updates the status of the given TiDB Serverless clusters. + :param tidb_serverless_list: The TidbAuthBinding records whose status should be refreshed (required). + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + + :return: None; matching records are updated and committed in place. + """ + clusters = [] + tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list} + cluster_ids = [item.cluster_id for item in tidb_serverless_list] + params = {"clusterIds": cluster_ids, "view": "FULL"} + response = requests.get( + f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + cluster_infos = [] + for item in response_data["clusters"]: + state = item["state"] + userPrefix = item["userPrefix"] + if state == "ACTIVE" and len(userPrefix) > 0: + cluster_info = tidb_serverless_list_map[item["clusterId"]] + cluster_info.status = "ACTIVE" + cluster_info.account = f"{userPrefix}.root" + db.session.add(cluster_info) + db.session.commit() + else: + response.raise_for_status() + + @staticmethod + def batch_create_tidb_serverless_cluster( + batch_size: int, project_id: str, api_url: str, iam_url: str, public_key: str, private_key: str, region: str + ) -> list[dict]: + """ + Creates a batch of new TiDB Serverless clusters. + :param batch_size: The number of clusters to create (required). + :param project_id: The project ID of the TiDB Cloud project (required). + :param api_url: The URL of the TiDB Cloud API (required). + :param iam_url: The URL of the TiDB Cloud IAM API (required). + :param public_key: The public key for the API (required). + :param private_key: The private key for the API (required). + :param region: The region where the clusters will be created (required). + + :return: A list of dicts describing the created clusters (cluster_id, cluster_name, account, password). 
+ """ + clusters = [] + for _ in range(batch_size): + region_object = { + "name": region, + } + + labels = { + "tidb.cloud/project": project_id, + } + + spending_limit = { + "monthly": dify_config.TIDB_SPEND_LIMIT, + } + password = str(uuid.uuid4()).replace("-", "")[:16] + display_name = str(uuid.uuid4()).replace("-", "") + cluster_data = { + "cluster": { + "displayName": display_name, + "region": region_object, + "labels": labels, + "spendingLimit": spending_limit, + "rootPassword": password, + } + } + cache_key = f"tidb_serverless_cluster_password:{display_name}" + redis_client.setex(cache_key, 3600, password) + clusters.append(cluster_data) + + request_body = {"requests": clusters} + response = requests.post( + f"{api_url}/clusters:batchCreate", json=request_body, auth=HTTPDigestAuth(public_key, private_key) + ) + + if response.status_code == 200: + response_data = response.json() + cluster_infos = [] + for item in response_data["clusters"]: + cache_key = f"tidb_serverless_cluster_password:{item['displayName']}" + password = redis_client.get(cache_key) + if not password: + continue + cluster_info = { + "cluster_id": item["clusterId"], + "cluster_name": item["displayName"], + "account": "root", + "password": password.decode("utf-8"), + } + cluster_infos.append(cluster_info) + return cluster_infos + else: + response.raise_for_status() diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index d3685c099162fd..1147e35ce8fa55 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -9,10 +9,10 @@ from sqlalchemy.orm import Session, declarative_base from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -28,47 +28,57 @@ class TiDBVectorConfig(BaseModel): database: str program_name: str - @model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['host']: + if not values["host"]: raise ValueError("config TIDB_VECTOR_HOST is required") - if not values['port']: + if not values["port"]: raise ValueError("config TIDB_VECTOR_PORT is required") - if not values['user']: + if not values["user"]: raise ValueError("config TIDB_VECTOR_USER is required") - if not values['password']: + if not values["password"]: raise ValueError("config TIDB_VECTOR_PASSWORD is required") - if not values['database']: + if not values["database"]: raise ValueError("config TIDB_VECTOR_DATABASE is required") - if not values['program_name']: + if not values["program_name"]: raise ValueError("config APPLICATION_NAME is required") return values class TiDBVector(BaseVector): - def get_type(self) -> str: return VectorType.TIDB_VECTOR def _table(self, dim: int) -> Table: from tidb_vector.sqlalchemy import VectorType + return Table( self._collection_name, self._orm_base.metadata, - Column('id', String(36), primary_key=True, nullable=False), - Column("vector", VectorType(dim), nullable=False, comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"), + Column("id", 
String(36), primary_key=True, nullable=False), + Column( + "vector", + VectorType(dim), + nullable=False, + comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})", + ), Column("text", TEXT, nullable=False), Column("meta", JSON, nullable=False), Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")), - Column("update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")), - extend_existing=True + Column( + "update_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ), + extend_existing=True, ) - def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'): + def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = "cosine"): super().__init__(collection_name) self._client_config = config - self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" - f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}") + self._url = ( + f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?" + f"ssl_verify_cert=true&ssl_verify_identity=true&program_name={config.program_name}" + ) self._distance_func = distance_func.lower() self._engine = create_engine(self._url) self._orm_base = declarative_base() @@ -83,9 +93,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) def _create_collection(self, dimension: int): logger.info("_create_collection, collection_name " + self._collection_name) - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return with Session(self._engine) as session: @@ -114,31 +124,28 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** texts = [d.page_content for d in documents] chunks_table_data = [] - with self._engine.connect() as conn: - with conn.begin(): - for id, text, meta, embedding in zip( - ids, texts, metas, embeddings - ): - chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) - - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == 500: - conn.execute(insert(table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: + with self._engine.connect() as conn, conn.begin(): + for id, text, meta, embedding in zip(ids, texts, metas, embeddings): + chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta}) + + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == 500: conn.execute(insert(table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(table).values(chunks_table_data)) return ids def text_exists(self, id: str) -> bool: - result = self.get_ids_by_metadata_field('doc_id', id) + result = 
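The rewritten add_texts flushes inserts in fixed batches of 500 rows and then drains the remainder, all inside a single transaction. The control flow reduced to its essentials; engine, table, and generate_rows are stand-ins, not names from the PR:

from sqlalchemy import insert

BATCH_SIZE = 500
rows = []
with engine.connect() as conn, conn.begin():
    for row in generate_rows():  # hypothetical producer of row dicts
        rows.append(row)
        if len(rows) == BATCH_SIZE:
            conn.execute(insert(table).values(rows))
            rows.clear()
    if rows:  # flush the final partial batch
        conn.execute(insert(table).values(rows))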
self.get_ids_by_metadata_field("doc_id", id) return bool(result) def delete_by_ids(self, ids: list[str]) -> None: with Session(self._engine) as session: - ids_str = ','.join(f"'{doc_id}'" for doc_id in ids) + ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """ ) @@ -152,11 +159,10 @@ def _delete_by_ids(self, ids: list[str]) -> bool: raise ValueError("No ids provided to delete.") table = self._table(self._dimension) try: - with self._engine.connect() as conn: - with conn.begin(): - delete_condition = table.c.id.in_(ids) - conn.execute(table.delete().where(delete_condition)) - return True + with self._engine.connect() as conn, conn.begin(): + delete_condition = table.c.id.in_(ids) + conn.execute(table.delete().where(delete_condition)) + return True except Exception as e: print("Delete operation failed:", str(e)) return False @@ -178,22 +184,24 @@ def delete_by_metadata_field(self, key: str, value: str) -> None: self._delete_by_ids(ids) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: - top_k = kwargs.get("top_k", 5) - score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0 - filter = kwargs.get('filter') + top_k = kwargs.get("top_k", 4) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + filter = kwargs.get("filter") distance = 1 - score_threshold query_vector_str = ", ".join(format(x) for x in query_vector) query_vector_str = "[" + query_vector_str + "]" - logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}") + logger.debug( + f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}" + ) docs = [] - if self._distance_func == 'l2': - tidb_func = 'Vec_l2_distance' - elif self._distance_func == 'cosine': - tidb_func = 'Vec_Cosine_distance' + if self._distance_func == "l2": + tidb_func = "Vec_l2_distance" + elif self._distance_func == "cosine": + tidb_func = "Vec_Cosine_distance" else: - tidb_func = 'Vec_Cosine_distance' + tidb_func = "Vec_Cosine_distance" with Session(self._engine) as session: select_statement = sql_text( @@ -208,7 +216,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc results = [(row[0], row[1], row[2]) for row in res] for meta, text, distance in results: metadata = json.loads(meta) - metadata['score'] = 1 - distance + metadata["score"] = 1 - distance docs.append(Document(page_content=text, metadata=metadata)) return docs @@ -224,15 +232,13 @@ def delete(self) -> None: class TiDBVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector: - if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix.lower() else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name)) return TiDBVector( collection_name=collection_name, diff --git a/api/core/rag/datasource/vdb/upstash/__init__.py b/api/core/rag/datasource/vdb/upstash/__init__.py new file mode 100644 index 
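search_by_vector converts between the caller's similarity threshold and TiDB's distance functions: with cosine distance d, the returned score is 1 - d, so a score_threshold of s is equivalent to keeping rows whose distance is below 1 - s. For example:

score_threshold = 0.35            # hypothetical caller value
distance = 1 - score_threshold    # 0.65: rows with Vec_Cosine_distance below this pass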
00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/core/rag/datasource/vdb/upstash/upstash_vector.py new file mode 100644 index 00000000000000..df1b550b40bd71 --- /dev/null +++ b/api/core/rag/datasource/vdb/upstash/upstash_vector.py @@ -0,0 +1,129 @@ +import json +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, model_validator +from upstash_vector import Index, Vector + +from configs import dify_config +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from models.dataset import Dataset + + +class UpstashVectorConfig(BaseModel): + url: str + token: str + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["url"]: + raise ValueError("Upstash URL is required") + if not values["token"]: + raise ValueError("Upstash Token is required") + return values + + +class UpstashVector(BaseVector): + def __init__(self, collection_name: str, config: UpstashVectorConfig): + super().__init__(collection_name) + self._table_name = collection_name + self.index = Index(url=config.url, token=config.token) + + def _get_index_dimension(self) -> int: + index_info = self.index.info() + if index_info and index_info.dimension: + return index_info.dimension + else: + return 1536 + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + vectors = [ + Vector( + id=str(uuid4()), + vector=embedding, + metadata=doc.metadata, + data=doc.page_content, + ) + for doc, embedding in zip(documents, embeddings) + ] + self.index.upsert(vectors=vectors) + + def text_exists(self, id: str) -> bool: + response = self.get_ids_by_metadata_field("doc_id", id) + return len(response) > 0 + + def delete_by_ids(self, ids: list[str]) -> None: + item_ids = [] + for doc_id in ids: + matched_ids = self.get_ids_by_metadata_field("doc_id", doc_id) + if matched_ids: + item_ids += matched_ids + self._delete_by_ids(ids=item_ids) + + def _delete_by_ids(self, ids: list[str]) -> None: + if ids: + self.index.delete(ids=ids) + + def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]: + query_result = self.index.query( + vector=[1.001 * i for i in range(self._get_index_dimension())], + include_metadata=True, + top_k=1000, + filter=f"{key} = '{value}'", + ) + return [result.id for result in query_result] + + def delete_by_metadata_field(self, key: str, value: str) -> None: + ids = self.get_ids_by_metadata_field(key, value) + if ids: + self._delete_by_ids(ids) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + top_k = kwargs.get("top_k", 4) + result = self.index.query(vector=query_vector, top_k=top_k, include_metadata=True, include_data=True) + docs = [] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + for record in result: + metadata = record.metadata + text = record.data + score = record.score + metadata["score"] = score + if score > score_threshold: + docs.append(Document(page_content=text, metadata=metadata)) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + return [] + + def delete(self) -> None: + 
self.index.reset() + + def get_type(self) -> str: + return VectorType.UPSTASH + + +class UpstashVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> UpstashVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.UPSTASH, collection_name)) + + return UpstashVector( + collection_name=collection_name, + config=UpstashVectorConfig( + url=dify_config.UPSTASH_VECTOR_URL, + token=dify_config.UPSTASH_VECTOR_TOKEN, + ), + ) diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py index 3f70e8b60814c1..22e191340d3a47 100644 --- a/api/core/rag/datasource/vdb/vector_base.py +++ b/api/core/rag/datasource/vdb/vector_base.py @@ -7,7 +7,6 @@ class BaseVector(ABC): - def __init__(self, collection_name: str): self._collection_name = collection_name @@ -39,26 +38,20 @@ def delete_by_metadata_field(self, key: str, value: str) -> None: raise NotImplementedError @abstractmethod - def search_by_vector( - self, - query_vector: list[float], - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: raise NotImplementedError @abstractmethod - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError + @abstractmethod def delete(self) -> None: raise NotImplementedError def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: - for text in texts[:]: - doc_id = text.metadata['doc_id'] + for text in texts.copy(): + doc_id = text.metadata["doc_id"] exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) @@ -66,7 +59,7 @@ def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: return texts def _get_uuids(self, texts: list[Document]) -> list[str]: - return [text.metadata['doc_id'] for text in texts] + return [text.metadata["doc_id"] for text in texts] @property def collection_name(self): diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 3e9ca8e1fe7f4a..6d2e04fc020ab5 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,16 +1,17 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Optional from configs import dify_config -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document +from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset +from models.dataset import Dataset, Whitelist class AbstractVectorFactory(ABC): @@ -20,17 +21,14 @@ def init_vector(self, dataset: Dataset, attributes: 
list, embeddings: Embeddings @staticmethod def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict: - index_struct_dict = { - "type": vector_type, - "vector_store": {"class_prefix": collection_name} - } + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} return index_struct_dict class Vector: - def __init__(self, dataset: Dataset, attributes: list = None): + def __init__(self, dataset: Dataset, attributes: Optional[list] = None): if attributes is None: - attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash'] + attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes @@ -38,8 +36,18 @@ def __init__(self, dataset: Dataset, attributes: list = None): def _init_vector(self) -> BaseVector: vector_type = dify_config.VECTOR_STORE + if self._dataset.index_struct_dict: - vector_type = self._dataset.index_struct_dict['type'] + vector_type = self._dataset.index_struct_dict["type"] + else: + if dify_config.VECTOR_STORE_WHITELIST_ENABLE: + whitelist = ( + db.session.query(Whitelist) + .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") + .one_or_none() + ) + if whitelist: + vector_type = VectorType.TIDB_ON_QDRANT if not vector_type: raise ValueError("Vector store must be specified.") @@ -52,67 +60,102 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: match vector_type: case VectorType.CHROMA: from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory + return ChromaVectorFactory case VectorType.MILVUS: from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory + return MilvusVectorFactory case VectorType.MYSCALE: from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory + return MyScaleVectorFactory case VectorType.PGVECTOR: from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory + return PGVectorFactory case VectorType.PGVECTO_RS: from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory + return PGVectoRSFactory case VectorType.QDRANT: from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory + return QdrantVectorFactory case VectorType.RELYT: from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory + return RelytVectorFactory case VectorType.ELASTICSEARCH: from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory + return ElasticSearchVectorFactory case VectorType.TIDB_VECTOR: from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory + return TiDBVectorFactory case VectorType.WEAVIATE: from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory + return WeaviateVectorFactory case VectorType.TENCENT: from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory + return TencentVectorFactory case VectorType.ORACLE: from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory + return OracleVectorFactory case VectorType.OPENSEARCH: from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory + return OpenSearchVectorFactory case VectorType.ANALYTICDB: from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory + return AnalyticdbVectorFactory + case VectorType.COUCHBASE: + from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory + + return 
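Each case arm imports its factory lazily, so only the configured store's client library needs to be installed. Resolving and using a factory looks like this (a usage sketch; dataset and embeddings are assumed to be in scope):

factory_cls = Vector.get_vector_factory(VectorType.UPSTASH)
vector_processor = factory_cls().init_vector(dataset, ["doc_id", "dataset_id"], embeddings)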
CouchbaseVectorFactory + case VectorType.BAIDU: + from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory + + return BaiduVectorFactory + case VectorType.VIKINGDB: + from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory + + return VikingDBVectorFactory + case VectorType.UPSTASH: + from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory + + return UpstashVectorFactory + case VectorType.TIDB_ON_QDRANT: + from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory + + return TidbOnQdrantVectorFactory + case VectorType.LINDORM: + from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory + + return LindormVectorStoreFactory + case VectorType.OCEANBASE: + from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory + + return OceanBaseVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") - def create(self, texts: list = None, **kwargs): + def create(self, texts: Optional[list] = None, **kwargs): if texts: embeddings = self._embeddings.embed_documents([document.page_content for document in texts]) - self._vector_processor.create( - texts=texts, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs) def add_texts(self, documents: list[Document], **kwargs): - if kwargs.get('duplicate_check', False): + if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) + embeddings = self._embeddings.embed_documents([document.page_content for document in documents]) - self._vector_processor.create( - texts=documents, - embeddings=embeddings, - **kwargs - ) + self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs) def text_exists(self, id: str) -> bool: return self._vector_processor.text_exists(id) @@ -123,24 +166,18 @@ def delete_by_ids(self, ids: list[str]) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None: self._vector_processor.delete_by_metadata_field(key, value) - def search_by_vector( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) - def search_by_full_text( - self, query: str, - **kwargs: Any - ) -> list[Document]: + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: return self._vector_processor.search_by_full_text(query, **kwargs) def delete(self) -> None: self._vector_processor.delete() # delete collection redis cache if self._vector_processor.collection_name: - collection_exist_cache_key = 'vector_indexing_{}'.format(self._vector_processor.collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name) redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: @@ -150,14 +187,13 @@ def _get_embeddings(self) -> Embeddings: tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model - + model=self._dataset.embedding_model, ) return CacheEmbedding(embedding_model) def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]: - for text in texts[:]: - doc_id = text.metadata['doc_id'] + for text in texts.copy(): + doc_id = text.metadata["doc_id"] 
exists_duplicate_node = self.text_exists(doc_id) if exists_duplicate_node: texts.remove(text) diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 317ca6abc8c89d..8e53e3ae8450d6 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -2,17 +2,24 @@ class VectorType(str, Enum): - ANALYTICDB = 'analyticdb' - CHROMA = 'chroma' - MILVUS = 'milvus' - MYSCALE = 'myscale' - PGVECTOR = 'pgvector' - PGVECTO_RS = 'pgvecto-rs' - QDRANT = 'qdrant' - RELYT = 'relyt' - TIDB_VECTOR = 'tidb_vector' - WEAVIATE = 'weaviate' - OPENSEARCH = 'opensearch' - TENCENT = 'tencent' - ORACLE = 'oracle' - ELASTICSEARCH = 'elasticsearch' + ANALYTICDB = "analyticdb" + CHROMA = "chroma" + MILVUS = "milvus" + MYSCALE = "myscale" + PGVECTOR = "pgvector" + PGVECTO_RS = "pgvecto-rs" + QDRANT = "qdrant" + RELYT = "relyt" + TIDB_VECTOR = "tidb_vector" + WEAVIATE = "weaviate" + OPENSEARCH = "opensearch" + TENCENT = "tencent" + ORACLE = "oracle" + ELASTICSEARCH = "elasticsearch" + LINDORM = "lindorm" + COUCHBASE = "couchbase" + BAIDU = "baidu" + VIKINGDB = "vikingdb" + UPSTASH = "upstash" + TIDB_ON_QDRANT = "tidb_on_qdrant" + OCEANBASE = "oceanbase" diff --git a/api/core/rag/datasource/vdb/vikingdb/__init__.py b/api/core/rag/datasource/vdb/vikingdb/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py new file mode 100644 index 00000000000000..4f927f28995613 --- /dev/null +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -0,0 +1,239 @@ +import json +from typing import Any + +from pydantic import BaseModel +from volcengine.viking_db import ( + Data, + DistanceType, + Field, + FieldType, + IndexType, + QuantType, + VectorIndexParams, + VikingDBService, +) + +from configs import dify_config +from core.rag.datasource.vdb.field import Field as vdb_Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + + +class VikingDBConfig(BaseModel): + access_key: str + secret_key: str + host: str + region: str + scheme: str + connection_timeout: int + socket_timeout: int + index_type: str = IndexType.HNSW + distance: str = DistanceType.L2 + quant: str = QuantType.Float + + +class VikingDBVector(BaseVector): + def __init__(self, collection_name: str, group_id: str, config: VikingDBConfig): + super().__init__(collection_name) + self._group_id = group_id + self._client_config = config + self._index_name = f"{self._collection_name}_idx" + self._client = VikingDBService( + host=config.host, + region=config.region, + scheme=config.scheme, + connection_timeout=config.connection_timeout, + socket_timeout=config.socket_timeout, + ak=config.access_key, + sk=config.secret_key, + ) + + def _has_collection(self) -> bool: + try: + self._client.get_collection(self._collection_name) + except Exception: + return False + return True + + def _has_index(self) -> bool: + try: + self._client.get_index(self._collection_name, self._index_name) + except Exception: + return False + return True + + def _create_collection(self, dimension: int): + lock_name = 
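Because VectorType subclasses str, the value read from the VECTOR_STORE config compares directly against members and round-trips through the constructor:

assert VectorType.UPSTASH == "upstash"  # plain string comparison works
assert VectorType("tidb_on_qdrant") is VectorType.TIDB_ON_QDRANT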
f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + return + + if not self._has_collection(): + fields = [ + Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True), + Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String), + Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String), + Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text), + Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension), + ] + + self._client.create_collection( + collection_name=self._collection_name, + fields=fields, + description="Collection For Dify", + ) + + if not self._has_index(): + vector_index = VectorIndexParams( + distance=self._client_config.distance, + index_type=self._client_config.index_type, + quant=self._client_config.quant, + ) + + self._client.create_index( + collection_name=self._collection_name, + index_name=self._index_name, + vector_index=vector_index, + partition_by=vdb_Field.GROUP_KEY.value, + description="Index For Dify", + ) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + def get_type(self) -> str: + return VectorType.VIKINGDB + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + self.add_texts(texts, embeddings, **kwargs) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + page_contents = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + docs = [] + + for i, page_content in enumerate(page_contents): + metadata = {} + if metadatas is not None: + for key, val in metadatas[i].items(): + metadata[key] = val + doc = Data( + { + vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], + vdb_Field.VECTOR.value: embeddings[i] if embeddings else None, + vdb_Field.CONTENT_KEY.value: page_content, + vdb_Field.METADATA_KEY.value: json.dumps(metadata), + vdb_Field.GROUP_KEY.value: self._group_id, + } + ) + docs.append(doc) + + self._client.get_collection(self._collection_name).upsert_data(docs) + + def text_exists(self, id: str) -> bool: + docs = self._client.get_collection(self._collection_name).fetch_data(id) + not_exists_str = "data does not exist" + if docs is not None and not_exists_str not in docs.fields.get("message", ""): + return True + return False + + def delete_by_ids(self, ids: list[str]) -> None: + self._client.get_collection(self._collection_name).delete_data(ids) + + def get_ids_by_metadata_field(self, key: str, value: str): + # Note: Metadata field value is an dict, but vikingdb field + # not support json type + results = self._client.get_index(self._collection_name, self._index_name).search( + filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]}, + # max value is 5000 + limit=5000, + ) + + if not results: + return [] + + ids = [] + for result in results: + metadata = result.fields.get(vdb_Field.METADATA_KEY.value) + if metadata is not None: + metadata = json.loads(metadata) + if metadata.get(key) == value: + ids.append(result.id) + return ids + + def delete_by_metadata_field(self, key: str, value: str) -> None: + ids = self.get_ids_by_metadata_field(key, value) + self.delete_by_ids(ids) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> 
list[Document]: + results = self._client.get_index(self._collection_name, self._index_name).search_by_vector( + query_vector, limit=kwargs.get("top_k", 4) + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._get_search_res(results, score_threshold) + + def _get_search_res(self, results, score_threshold): + if len(results) == 0: + return [] + + docs = [] + for result in results: + metadata = result.fields.get(vdb_Field.METADATA_KEY.value) + if metadata is not None: + metadata = json.loads(metadata) + if result.score > score_threshold: + metadata["score"] = result.score + doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata) + docs.append(doc) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + return [] + + def delete(self) -> None: + if self._has_index(): + self._client.drop_index(self._collection_name, self._index_name) + if self._has_collection(): + self._client.drop_collection(self._collection_name) + + +class VikingDBVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VIKINGDB, collection_name)) + + if dify_config.VIKINGDB_ACCESS_KEY is None: + raise ValueError("VIKINGDB_ACCESS_KEY should not be None") + if dify_config.VIKINGDB_SECRET_KEY is None: + raise ValueError("VIKINGDB_SECRET_KEY should not be None") + if dify_config.VIKINGDB_HOST is None: + raise ValueError("VIKINGDB_HOST should not be None") + if dify_config.VIKINGDB_REGION is None: + raise ValueError("VIKINGDB_REGION should not be None") + if dify_config.VIKINGDB_SCHEME is None: + raise ValueError("VIKINGDB_SCHEME should not be None") + return VikingDBVector( + collection_name=collection_name, + group_id=dataset.id, + config=VikingDBConfig( + access_key=dify_config.VIKINGDB_ACCESS_KEY, + secret_key=dify_config.VIKINGDB_SECRET_KEY, + host=dify_config.VIKINGDB_HOST, + region=dify_config.VIKINGDB_REGION, + scheme=dify_config.VIKINGDB_SCHEME, + connection_timeout=dify_config.VIKINGDB_CONNECTION_TIMEOUT, + socket_timeout=dify_config.VIKINGDB_SOCKET_TIMEOUT, + ), + ) diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 205fe850c35838..649cfbfea8253c 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -7,11 +7,11 @@ from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -22,15 +22,15 @@ class WeaviateConfig(BaseModel): api_key: Optional[str] = None batch_size: int = 100 - 
@model_validator(mode='before') + @model_validator(mode="before") + @classmethod def validate_config(cls, values: dict) -> dict: - if not values['endpoint']: + if not values["endpoint"]: raise ValueError("config WEAVIATE_ENDPOINT is required") return values class WeaviateVector(BaseVector): - def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list): super().__init__(collection_name) self._client = self._init_client(config) @@ -43,10 +43,7 @@ def _init_client(self, config: WeaviateConfig) -> weaviate.Client: try: client = weaviate.Client( - url=config.endpoint, - auth_client_secret=auth_config, - timeout_config=(5, 60), - startup_period=None + url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None ) except requests.exceptions.ConnectionError: raise ConnectionError("Vector database connection error") @@ -68,10 +65,10 @@ def get_type(self) -> str: def get_collection_name(self, dataset: Dataset) -> str: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + if not class_prefix.endswith("_Node"): # original class_prefix - class_prefix += '_Node' + class_prefix += "_Node" return class_prefix @@ -79,10 +76,7 @@ def get_collection_name(self, dataset: Dataset) -> str: return Dataset.gen_collection_name_by_id(dataset_id) def to_index_struct(self) -> dict: - return { - "type": self.get_type(), - "vector_store": {"class_prefix": self._collection_name} - } + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): # create collection @@ -91,9 +85,9 @@ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) self.add_texts(texts, embeddings) def _create_collection(self): - lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) + lock_name = "vector_indexing_lock_{}".format(self._collection_name) with redis_client.lock(lock_name, timeout=20): - collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) + collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(collection_exist_cache_key): return schema = self._default_schema(self._collection_name) @@ -129,17 +123,9 @@ def delete_by_metadata_field(self, key: str, value: str): # check whether the index already exists schema = self._default_schema(self._collection_name) if self._client.schema.contains(schema): - where_filter = { - "operator": "Equal", - "path": [key], - "valueText": value - } - - self._client.batch.delete_objects( - class_name=self._collection_name, - where=where_filter, - output='minimal' - ) + where_filter = {"operator": "Equal", "path": [key], "valueText": value} + + self._client.batch.delete_objects(class_name=self._collection_name, where=where_filter, output="minimal") def delete(self): # check whether the index already exists @@ -154,11 +140,19 @@ def text_exists(self, id: str) -> bool: # check whether the index already exists if not self._client.schema.contains(schema): return False - result = self._client.query.get(collection_name).with_additional(["id"]).with_where({ - "path": ["doc_id"], - "operator": "Equal", - "valueText": id, - }).with_limit(1).do() + result = ( + self._client.query.get(collection_name) + .with_additional(["id"]) + .with_where( + { + "path": ["doc_id"], + 
"operator": "Equal", + "valueText": id, + } + ) + .with_limit(1) + .do() + ) if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") @@ -211,13 +205,13 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc docs = [] for doc, score in docs_and_scores: - score_threshold = kwargs.get("score_threshold", .0) if kwargs.get('score_threshold', .0) else 0.0 + score_threshold = float(kwargs.get("score_threshold") or 0.0) # check score threshold if score > score_threshold: - doc.metadata['score'] = score + doc.metadata["score"] = score docs.append(doc) # Sort the documents by score in descending order - docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True) + docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -240,15 +234,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) query_obj = query_obj.with_additional(["vector"]) - properties = ['text'] - result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do() + properties = ["text"] + result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do() if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") docs = [] for res in result["data"]["Get"][collection_name]: text = res.pop(Field.TEXT_KEY.value) - additional = res.pop('_additional') - docs.append(Document(page_content=text, vector=additional['vector'], metadata=res)) + additional = res.pop("_additional") + docs.append(Document(page_content=text, vector=additional["vector"], metadata=res)) return docs def _default_schema(self, index_name: str) -> dict: @@ -271,20 +265,19 @@ def _json_serializable(self, value: Any) -> Any: class WeaviateVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> WeaviateVector: if dataset.index_struct_dict: - class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix'] + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] collection_name = class_prefix else: dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - dataset.index_struct = json.dumps( - self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( endpoint=dify_config.WEAVIATE_ENDPOINT, api_key=dify_config.WEAVIATE_API_KEY, - batch_size=dify_config.WEAVIATE_BATCH_SIZE + batch_size=dify_config.WEAVIATE_BATCH_SIZE, ), - attributes=attributes + attributes=attributes, ) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 96a15be7426c67..319a2612c7ecb8 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -12,10 +12,10 @@ class DatasetDocumentStore: def __init__( - self, - dataset: Dataset, - user_id: str, - document_id: Optional[str] = None, + self, + dataset: Dataset, + user_id: str, + document_id: Optional[str] = None, ): self._dataset = dataset self._user_id = user_id @@ -41,9 +41,9 @@ def user_id(self) -> Any: @property def docs(self) -> dict[str, Document]: - document_segments 
= db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id - ).all() + document_segments = ( + db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() + ) output = {} for document_segment in document_segments: @@ -55,48 +55,45 @@ def docs(self) -> dict[str, Document]: "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) return output - def add_documents( - self, docs: Sequence[Document], allow_update: bool = True - ) -> None: - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == self._document_id - ).scalar() + def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == self._document_id) + .scalar() + ) if max_position is None: max_position = 0 embedding_model = None - if self._dataset.indexing_technique == 'high_quality': + if self._dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=self._dataset.embedding_model + model=self._dataset.embedding_model, ) for doc in docs: if not isinstance(doc, Document): raise ValueError("doc must be a Document") - segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id']) + segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"]) # NOTE: doc could already exist in the store, but we overwrite it if not allow_update and segment_document: raise ValueError( - f"doc_id {doc.metadata['doc_id']} already exists. " - "Set allow_update to True to overwrite." + f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite." 
) # calc embedding use tokens if embedding_model: - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[doc.page_content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content]) else: tokens = 0 @@ -107,8 +104,8 @@ def add_documents( tenant_id=self._dataset.tenant_id, dataset_id=self._dataset.id, document_id=self._document_id, - index_node_id=doc.metadata['doc_id'], - index_node_hash=doc.metadata['doc_hash'], + index_node_id=doc.metadata["doc_id"], + index_node_hash=doc.metadata["doc_hash"], position=max_position, content=doc.page_content, word_count=len(doc.page_content), @@ -116,15 +113,15 @@ def add_documents( enabled=False, created_by=self._user_id, ) - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) else: segment_document.content = doc.page_content - if doc.metadata.get('answer'): - segment_document.answer = doc.metadata.pop('answer', '') - segment_document.index_node_hash = doc.metadata['doc_hash'] + if doc.metadata.get("answer"): + segment_document.answer = doc.metadata.pop("answer", "") + segment_document.index_node_hash = doc.metadata["doc_hash"] segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens @@ -135,9 +132,7 @@ def document_exists(self, doc_id: str) -> bool: result = self.get_document_segment(doc_id) return result is not None - def get_document( - self, doc_id: str, raise_error: bool = True - ) -> Optional[Document]: + def get_document(self, doc_id: str, raise_error: bool = True) -> Optional[Document]: document_segment = self.get_document_segment(doc_id) if document_segment is None: @@ -153,7 +148,7 @@ def get_document( "doc_hash": document_segment.index_node_hash, "document_id": document_segment.document_id, "dataset_id": document_segment.dataset_id, - } + }, ) def delete_document(self, doc_id: str, raise_error: bool = True) -> None: @@ -188,9 +183,10 @@ def get_document_hash(self, doc_id: str) -> Optional[str]: return document_segment.index_node_hash def get_document_segment(self, doc_id: str) -> DocumentSegment: - document_segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == self._dataset.id, - DocumentSegment.index_node_id == doc_id - ).first() + document_segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) + .first() + ) return document_segment diff --git a/api/core/rag/embedding/__init__.py b/api/core/rag/embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py new file mode 100644 index 00000000000000..3ac65b88bb70de --- /dev/null +++ b/api/core/rag/embedding/cached_embedding.py @@ -0,0 +1,132 @@ +import base64 +import logging +from typing import Optional, cast + +import numpy as np +from sqlalchemy.exc import IntegrityError + +from configs import dify_config +from core.entities.embedding_type import EmbeddingInputType +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelPropertyKey +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.rag.embedding.embedding_base import Embeddings +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs 
import helper +from models.dataset import Embedding + +logger = logging.getLogger(__name__) + + +class CacheEmbedding(Embeddings): + def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None: + self._model_instance = model_instance + self._user = user + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + """Embed search docs in batches of 10.""" + # use doc embedding cache or store if not exists + text_embeddings = [None for _ in range(len(texts))] + embedding_queue_indices = [] + for i, text in enumerate(texts): + hash = helper.generate_text_hash(text) + embedding = ( + db.session.query(Embedding) + .filter_by( + model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider + ) + .first() + ) + if embedding: + text_embeddings[i] = embedding.get_embedding() + else: + embedding_queue_indices.append(i) + if embedding_queue_indices: + embedding_queue_texts = [texts[i] for i in embedding_queue_indices] + embedding_queue_embeddings = [] + try: + model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) + model_schema = model_type_instance.get_model_schema( + self._model_instance.model, self._model_instance.credentials + ) + max_chunks = ( + model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] + if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties + else 1 + ) + for i in range(0, len(embedding_queue_texts), max_chunks): + batch_texts = embedding_queue_texts[i : i + max_chunks] + + embedding_result = self._model_instance.invoke_text_embedding( + texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT + ) + + for vector in embedding_result.embeddings: + try: + normalized_embedding = (vector / np.linalg.norm(vector)).tolist() + embedding_queue_embeddings.append(normalized_embedding) + except IntegrityError: + db.session.rollback() + except Exception as e: + logging.exception("Failed transform embedding: %s", e) + cache_embeddings = [] + try: + for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): + text_embeddings[i] = embedding + hash = helper.generate_text_hash(texts[i]) + if hash not in cache_embeddings: + embedding_cache = Embedding( + model_name=self._model_instance.model, + hash=hash, + provider_name=self._model_instance.provider, + ) + embedding_cache.set_embedding(embedding) + db.session.add(embedding_cache) + cache_embeddings.append(hash) + db.session.commit() + except IntegrityError: + db.session.rollback() + except Exception as ex: + db.session.rollback() + logger.exception("Failed to embed documents: %s", ex) + raise ex + + return text_embeddings + + def embed_query(self, text: str) -> list[float]: + """Embed query text.""" + # use doc embedding cache or store if not exists + hash = helper.generate_text_hash(text) + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}" + embedding = redis_client.get(embedding_cache_key) + if embedding: + redis_client.expire(embedding_cache_key, 600) + return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) + try: + embedding_result = self._model_instance.invoke_text_embedding( + texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY + ) + + embedding_results = embedding_result.embeddings[0] + embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() + except Exception as ex: + if dify_config.DEBUG: + logging.exception(f"Failed to embed query text: {ex}") + raise ex + + try: + # 
encode embedding to base64 + embedding_vector = np.array(embedding_results) + vector_bytes = embedding_vector.tobytes() + # Transform to Base64 + encoded_vector = base64.b64encode(vector_bytes) + # Transform to string + encoded_str = encoded_vector.decode("utf-8") + redis_client.setex(embedding_cache_key, 600, encoded_str) + except Exception as ex: + if dify_config.DEBUG: + logging.exception("Failed to add embedding to redis %s", ex) + raise ex + + return embedding_results diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py new file mode 100644 index 00000000000000..9f232ab91089fe --- /dev/null +++ b/api/core/rag/embedding/embedding_base.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod + + +class Embeddings(ABC): + """Interface for embedding models.""" + + @abstractmethod + def embed_documents(self, texts: list[str]) -> list[list[float]]: + """Embed search docs.""" + raise NotImplementedError + + @abstractmethod + def embed_query(self, text: str) -> list[float]: + """Embed query text.""" + raise NotImplementedError + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + """Asynchronous Embed search docs.""" + raise NotImplementedError + + async def aembed_query(self, text: str) -> list[float]: + """Asynchronous Embed query text.""" + raise NotImplementedError diff --git a/api/core/rag/entities/context_entities.py b/api/core/rag/entities/context_entities.py new file mode 100644 index 00000000000000..cd18ad081ff4fd --- /dev/null +++ b/api/core/rag/entities/context_entities.py @@ -0,0 +1,12 @@ +from typing import Optional + +from pydantic import BaseModel + + +class DocumentContext(BaseModel): + """ + Model class for document context. + """ + + content: str + score: Optional[float] = None diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py new file mode 100644 index 00000000000000..e46ab8b7fd0ac2 --- /dev/null +++ b/api/core/rag/extractor/blob/blob.py @@ -0,0 +1,163 @@ +"""Schema for Blobs and Blob Loaders. + +The goal is to facilitate decoupling of content loading from content parsing code. + +In addition, content loading code should provide a lazy loading interface by default. +""" + +from __future__ import annotations + +import contextlib +import mimetypes +from abc import ABC, abstractmethod +from collections.abc import Generator, Iterable, Mapping +from io import BufferedReader, BytesIO +from pathlib import Path, PurePath +from typing import Any, Optional, Union + +from pydantic import BaseModel, ConfigDict, model_validator + +PathLike = Union[str, PurePath] + + +class Blob(BaseModel): + """A blob is used to represent raw data by either reference or value. + + Provides an interface to materialize the blob in different representations, and + help to decouple the development of data loaders from the downstream parsing of + the raw data. + + Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob + """ + + data: Union[bytes, str, None] = None # Raw data + mimetype: Optional[str] = None # Not to be confused with a file extension + encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string + # Location where the original content was found + # Represent location on the local file system + # Useful for situations where downstream code assumes it must work with file paths + # rather than in-memory content. 
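A minimal sketch of the Redis round-trip used by `CacheEmbedding.embed_query` above: vectors are cached as base64-encoded raw bytes with a 600-second TTL, and read back with `np.frombuffer(..., dtype="float")`. This assumes only `numpy` and the standard library; the vector is illustrative:

```python
import base64

import numpy as np

# Illustrative vector; the real code caches a normalized embedding.
vector = [0.1, 0.2, 0.3]

# Write path: np.array(...).tobytes() -> base64 -> utf-8 string (stored via setex).
encoded_str = base64.b64encode(np.array(vector).tobytes()).decode("utf-8")

# Read path: base64 decode -> np.frombuffer with dtype="float" (i.e. float64).
decoded = list(np.frombuffer(base64.b64decode(encoded_str), dtype="float"))

assert decoded == vector  # float64 round-trips bit-exactly
```

The two paths agree because `np.array` on a list of Python floats and `dtype="float"` both mean float64; caching bytes produced with a narrower dtype would silently corrupt the read path.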
+ path: Optional[PathLike] = None + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) + + @property + def source(self) -> Optional[str]: + """The source location of the blob as string if known otherwise none.""" + return str(self.path) if self.path else None + + @model_validator(mode="before") + @classmethod + def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]: + """Verify that either data or path is provided.""" + if "data" not in values and "path" not in values: + raise ValueError("Either data or path must be provided") + return values + + def as_string(self) -> str: + """Read data as a string.""" + if self.data is None and self.path: + return Path(str(self.path)).read_text(encoding=self.encoding) + elif isinstance(self.data, bytes): + return self.data.decode(self.encoding) + elif isinstance(self.data, str): + return self.data + else: + raise ValueError(f"Unable to get string for blob {self}") + + def as_bytes(self) -> bytes: + """Read data as bytes.""" + if isinstance(self.data, bytes): + return self.data + elif isinstance(self.data, str): + return self.data.encode(self.encoding) + elif self.data is None and self.path: + return Path(str(self.path)).read_bytes() + else: + raise ValueError(f"Unable to get bytes for blob {self}") + + @contextlib.contextmanager + def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]: + """Read data as a byte stream.""" + if isinstance(self.data, bytes): + yield BytesIO(self.data) + elif self.data is None and self.path: + with open(str(self.path), "rb") as f: + yield f + else: + raise NotImplementedError(f"Unable to convert blob {self}") + + @classmethod + def from_path( + cls, + path: PathLike, + *, + encoding: str = "utf-8", + mime_type: Optional[str] = None, + guess_type: bool = True, + ) -> Blob: + """Load the blob from a path like object. + + Args: + path: path like object to file to be read + encoding: Encoding to use if decoding the bytes into a string + mime_type: if provided, will be set as the mime-type of the data + guess_type: If True, the mimetype will be guessed from the file extension, + if a mime-type was not provided + + Returns: + Blob instance + """ + if mime_type is None and guess_type: + _mimetype = mimetypes.guess_type(path)[0] if guess_type else None + else: + _mimetype = mime_type + # We do not load the data immediately, instead we treat the blob as a + # reference to the underlying data. + return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path) + + @classmethod + def from_data( + cls, + data: Union[str, bytes], + *, + encoding: str = "utf-8", + mime_type: Optional[str] = None, + path: Optional[str] = None, + ) -> Blob: + """Initialize the blob from in-memory data. + + Args: + data: the in-memory data associated with the blob + encoding: Encoding to use if decoding the bytes into a string + mime_type: if provided, will be set as the mime-type of the data + path: if provided, will be set as the source from which the data came + + Returns: + Blob instance + """ + return cls(data=data, mimetype=mime_type, encoding=encoding, path=path) + + def __repr__(self) -> str: + """Define the blob representation.""" + str_repr = f"Blob {id(self)}" + if self.source: + str_repr += f" {self.source}" + return str_repr + + +class BlobLoader(ABC): + """Abstract interface for blob loaders implementation. + + Implementer should be able to load raw content from a datasource system according + to some criteria and return the raw content lazily as a stream of blobs. 
+ """ + + @abstractmethod + def yield_blobs( + self, + ) -> Iterable[Blob]: + """A lazy loader for raw data represented by Blob object. + + Returns: + A generator over blobs + """ diff --git a/api/core/rag/extractor/blod/blod.py b/api/core/rag/extractor/blod/blod.py deleted file mode 100644 index abfdafcfa251a4..00000000000000 --- a/api/core/rag/extractor/blod/blod.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Schema for Blobs and Blob Loaders. - -The goal is to facilitate decoupling of content loading from content parsing code. - -In addition, content loading code should provide a lazy loading interface by default. -""" -from __future__ import annotations - -import contextlib -import mimetypes -from abc import ABC, abstractmethod -from collections.abc import Generator, Iterable, Mapping -from io import BufferedReader, BytesIO -from pathlib import PurePath -from typing import Any, Optional, Union - -from pydantic import BaseModel, ConfigDict, model_validator - -PathLike = Union[str, PurePath] - - -class Blob(BaseModel): - """A blob is used to represent raw data by either reference or value. - - Provides an interface to materialize the blob in different representations, and - help to decouple the development of data loaders from the downstream parsing of - the raw data. - - Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob - """ - - data: Union[bytes, str, None] = None # Raw data - mimetype: Optional[str] = None # Not to be confused with a file extension - encoding: str = "utf-8" # Use utf-8 as default encoding, if decoding to string - # Location where the original content was found - # Represent location on the local file system - # Useful for situations where downstream code assumes it must work with file paths - # rather than in-memory content. 
- path: Optional[PathLike] = None - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - - @property - def source(self) -> Optional[str]: - """The source location of the blob as string if known otherwise none.""" - return str(self.path) if self.path else None - - @model_validator(mode="before") - @classmethod - def check_blob_is_valid(cls, values: Mapping[str, Any]) -> Mapping[str, Any]: - """Verify that either data or path is provided.""" - if "data" not in values and "path" not in values: - raise ValueError("Either data or path must be provided") - return values - - def as_string(self) -> str: - """Read data as a string.""" - if self.data is None and self.path: - with open(str(self.path), encoding=self.encoding) as f: - return f.read() - elif isinstance(self.data, bytes): - return self.data.decode(self.encoding) - elif isinstance(self.data, str): - return self.data - else: - raise ValueError(f"Unable to get string for blob {self}") - - def as_bytes(self) -> bytes: - """Read data as bytes.""" - if isinstance(self.data, bytes): - return self.data - elif isinstance(self.data, str): - return self.data.encode(self.encoding) - elif self.data is None and self.path: - with open(str(self.path), "rb") as f: - return f.read() - else: - raise ValueError(f"Unable to get bytes for blob {self}") - - @contextlib.contextmanager - def as_bytes_io(self) -> Generator[Union[BytesIO, BufferedReader], None, None]: - """Read data as a byte stream.""" - if isinstance(self.data, bytes): - yield BytesIO(self.data) - elif self.data is None and self.path: - with open(str(self.path), "rb") as f: - yield f - else: - raise NotImplementedError(f"Unable to convert blob {self}") - - @classmethod - def from_path( - cls, - path: PathLike, - *, - encoding: str = "utf-8", - mime_type: Optional[str] = None, - guess_type: bool = True, - ) -> Blob: - """Load the blob from a path like object. - - Args: - path: path like object to file to be read - encoding: Encoding to use if decoding the bytes into a string - mime_type: if provided, will be set as the mime-type of the data - guess_type: If True, the mimetype will be guessed from the file extension, - if a mime-type was not provided - - Returns: - Blob instance - """ - if mime_type is None and guess_type: - _mimetype = mimetypes.guess_type(path)[0] if guess_type else None - else: - _mimetype = mime_type - # We do not load the data immediately, instead we treat the blob as a - # reference to the underlying data. - return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path) - - @classmethod - def from_data( - cls, - data: Union[str, bytes], - *, - encoding: str = "utf-8", - mime_type: Optional[str] = None, - path: Optional[str] = None, - ) -> Blob: - """Initialize the blob from in-memory data. - - Args: - data: the in-memory data associated with the blob - encoding: Encoding to use if decoding the bytes into a string - mime_type: if provided, will be set as the mime-type of the data - path: if provided, will be set as the source from which the data came - - Returns: - Blob instance - """ - return cls(data=data, mimetype=mime_type, encoding=encoding, path=path) - - def __repr__(self) -> str: - """Define the blob representation.""" - str_repr = f"Blob {id(self)}" - if self.source: - str_repr += f" {self.source}" - return str_repr - - -class BlobLoader(ABC): - """Abstract interface for blob loaders implementation. 
- - Implementer should be able to load raw content from a datasource system according - to some criteria and return the raw content lazily as a stream of blobs. - """ - - @abstractmethod - def yield_blobs( - self, - ) -> Iterable[Blob]: - """A lazy loader for raw data represented by Blob object. - - Returns: - A generator over blobs - """ diff --git a/api/core/rag/extractor/csv_extractor.py b/api/core/rag/extractor/csv_extractor.py index 0470569f393020..5b674039024189 100644 --- a/api/core/rag/extractor/csv_extractor.py +++ b/api/core/rag/extractor/csv_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import csv from typing import Optional @@ -18,12 +19,12 @@ class CSVExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + self, + file_path: str, + encoding: Optional[str] = None, + autodetect_encoding: bool = False, + source_column: Optional[str] = None, + csv_args: Optional[dict] = None, ): """Initialize with file path.""" self._file_path = file_path @@ -57,7 +58,7 @@ def _read_from_file(self, csvfile) -> list[Document]: docs = [] try: # load csv file into pandas dataframe - df = pd.read_csv(csvfile, on_bad_lines='skip', **self.csv_args) + df = pd.read_csv(csvfile, on_bad_lines="skip", **self.csv_args) # check source column exists if self.source_column and self.source_column not in df.columns: @@ -67,7 +68,7 @@ def _read_from_file(self, csvfile) -> list[Document]: for i, row in df.iterrows(): content = ";".join(f"{col.strip()}: {str(row[col]).strip()}" for col in df.columns) - source = row[self.source_column] if self.source_column else '' + source = row[self.source_column] if self.source_column else "" metadata = {"source": source, "row": i} doc = Document(page_content=content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 7479b1d97b8fdc..3692b5d19dfb65 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -10,6 +10,7 @@ class NotionInfo(BaseModel): """ Notion import info. """ + notion_workspace_id: str notion_obj_id: str notion_page_type: str @@ -25,6 +26,7 @@ class WebsiteInfo(BaseModel): """ website import info. """ + provider: str job_id: str url: str @@ -43,6 +45,7 @@ class ExtractSetting(BaseModel): """ Model class for provider response. """ + datasource_type: str upload_file: Optional[UploadFile] = None notion_info: Optional[NotionInfo] = None diff --git a/api/core/rag/extractor/excel_extractor.py b/api/core/rag/extractor/excel_extractor.py index f0c302a6197a64..fc331657195454 100644 --- a/api/core/rag/extractor/excel_extractor.py +++ b/api/core/rag/extractor/excel_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import os from typing import Optional @@ -17,59 +18,60 @@ class ExcelExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding self._autodetect_encoding = autodetect_encoding def extract(self) -> list[Document]: - """ Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" + """Load from Excel file in xls or xlsx format using Pandas and openpyxl.""" documents = [] file_extension = os.path.splitext(self._file_path)[-1].lower() - if file_extension == '.xlsx': + if file_extension == ".xlsx": wb = load_workbook(self._file_path, data_only=True) for sheet_name in wb.sheetnames: sheet = wb[sheet_name] data = sheet.values - cols = next(data) + try: + cols = next(data) + except StopIteration: + continue df = pd.DataFrame(data, columns=cols) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for index, row in df.iterrows(): page_content = [] for col_index, (k, v) in enumerate(row.items()): if pd.notna(v): - cell = sheet.cell(row=index + 2, - column=col_index + 1) # +2 to account for header and 1-based index + cell = sheet.cell( + row=index + 2, column=col_index + 1 + ) # +2 to account for header and 1-based index if cell.hyperlink: value = f"[{v}]({cell.hyperlink.target})" page_content.append(f'"{k}":"{value}"') else: page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) - elif file_extension == '.xls': - excel_file = pd.ExcelFile(self._file_path, engine='xlrd') + elif file_extension == ".xls": + excel_file = pd.ExcelFile(self._file_path, engine="xlrd") for sheet_name in excel_file.sheet_names: df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how='all', inplace=True) + df.dropna(how="all", inplace=True) for _, row in df.iterrows(): page_content = [] for k, v in row.items(): if pd.notna(v): page_content.append(f'"{k}":"{v}"') - documents.append(Document(page_content=';'.join(page_content), - metadata={'source': self._file_path})) + documents.append( + Document(page_content=";".join(page_content), metadata={"source": self._file_path}) + ) else: raise ValueError(f"Unsupported file extension: {file_extension}") diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index f7a08135f57518..a0b1aa4cefbd1f 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -1,7 +1,7 @@ import re import tempfile from pathlib import Path -from typing import Union +from typing import Optional, Union from urllib.parse import unquote from configs import dify_config @@ -12,6 +12,7 @@ from core.rag.extractor.excel_extractor import ExcelExtractor from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor from core.rag.extractor.html_extractor import HtmlExtractor +from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor from core.rag.extractor.markdown_extractor import MarkdownExtractor from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.pdf_extractor import PdfExtractor @@ -29,61 +30,62 @@ from extensions.ext_storage import storage from models.model import UploadFile -SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 
'text/plain', 'application/json'] -USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" +SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"] +USER_AGENT = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124" + " Safari/537.36" +) class ExtractProcessor: @classmethod - def load_from_upload_file(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) \ - -> Union[list[Document], str]: + def load_from_upload_file( + cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False + ) -> Union[list[Document], str]: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=upload_file, - document_model='text_model' + datasource_type="upload_file", upload_file=upload_file, document_model="text_model" ) if return_text: - delimiter = '\n' + delimiter = "\n" return delimiter.join([document.page_content for document in cls.extract(extract_setting, is_automatic)]) else: return cls.extract(extract_setting, is_automatic) @classmethod def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: - response = ssrf_proxy.get(url, headers={ - "User-Agent": USER_AGENT - }) + response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT}) with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(url).suffix - if not suffix and suffix != '.': + if not suffix and suffix != ".": # get content-type - if response.headers.get('Content-Type'): - suffix = '.' + response.headers.get('Content-Type').split('/')[-1] + if response.headers.get("Content-Type"): + suffix = "." + response.headers.get("Content-Type").split("/")[-1] else: - content_disposition = response.headers.get('Content-Disposition') + content_disposition = response.headers.get("Content-Disposition") filename_match = re.search(r'filename="([^"]+)"', content_disposition) if filename_match: filename = unquote(filename_match.group(1)) - suffix = '.' + re.search(r'\.(\w+)$', filename).group(1) + suffix = "." 
+ re.search(r"\.(\w+)$", filename).group(1) file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" - with open(file_path, 'wb') as file: - file.write(response.content) - extract_setting = ExtractSetting( - datasource_type="upload_file", - document_model='text_model' - ) + Path(file_path).write_bytes(response.content) + extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model") if return_text: - delimiter = '\n' - return delimiter.join([document.page_content for document in cls.extract( - extract_setting=extract_setting, file_path=file_path)]) + delimiter = "\n" + return delimiter.join( + [ + document.page_content + for document in cls.extract(extract_setting=extract_setting, file_path=file_path) + ] + ) else: return cls.extract(extract_setting=extract_setting, file_path=file_path) @classmethod - def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, - file_path: str = None) -> list[Document]: + def extract( + cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: Optional[str] = None + ) -> list[Document]: if extract_setting.datasource_type == DatasourceType.FILE.value: with tempfile.TemporaryDirectory() as temp_dir: if not file_path: @@ -96,50 +98,58 @@ def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, etl_type = dify_config.ETL_TYPE unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY - if etl_type == 'Unstructured': - if file_extension == '.xlsx' or file_extension == '.xls': + if etl_type == "Unstructured": + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: - extractor = UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic \ + elif file_extension in {".md", ".markdown"}: + extractor = ( + UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key) + if is_automatic else MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + ) + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == '.msg': - extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url) - elif file_extension == '.eml': - extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url) - elif file_extension == '.ppt': + elif file_extension == ".msg": + extractor = UnstructuredMsgExtractor(file_path, unstructured_api_url, unstructured_api_key) + elif file_extension == ".eml": + extractor = UnstructuredEmailExtractor(file_path, unstructured_api_url, unstructured_api_key) + elif file_extension == ".ppt": extractor = UnstructuredPPTExtractor(file_path, unstructured_api_url, unstructured_api_key) - elif file_extension == '.pptx': - extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url) - elif file_extension == '.xml': - extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url) - elif file_extension == 'epub': - extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url) + # You must first specify 
the API key + # because unstructured_api_key is necessary to parse .ppt documents + elif file_extension == ".pptx": + extractor = UnstructuredPPTXExtractor(file_path, unstructured_api_url, unstructured_api_key) + elif file_extension == ".xml": + extractor = UnstructuredXmlExtractor(file_path, unstructured_api_url, unstructured_api_key) + elif file_extension == ".epub": + extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) else: # txt - extractor = UnstructuredTextExtractor(file_path, unstructured_api_url) if is_automatic \ + extractor = ( + UnstructuredTextExtractor(file_path, unstructured_api_url) + if is_automatic else TextExtractor(file_path, autodetect_encoding=True) + ) else: - if file_extension == '.xlsx' or file_extension == '.xls': + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) - elif file_extension == '.pdf': + elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in ['.md', '.markdown']: + elif file_extension in {".md", ".markdown"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in ['.htm', '.html']: + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) - elif file_extension in ['.docx']: + elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) - elif file_extension == '.csv': + elif file_extension == ".csv": extractor = CSVExtractor(file_path, autodetect_encoding=True) - elif file_extension == 'epub': + elif file_extension == ".epub": extractor = UnstructuredEpubExtractor(file_path) else: # txt @@ -155,13 +165,22 @@ def extract(cls, extract_setting: ExtractSetting, is_automatic: bool = False, ) return extractor.extract() elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: - if extract_setting.website_info.provider == 'firecrawl': + if extract_setting.website_info.provider == "firecrawl": extractor = FirecrawlWebExtractor( url=extract_setting.website_info.url, job_id=extract_setting.website_info.job_id, tenant_id=extract_setting.website_info.tenant_id, mode=extract_setting.website_info.mode, - only_main_content=extract_setting.website_info.only_main_content + only_main_content=extract_setting.website_info.only_main_content, + ) + return extractor.extract() + elif extract_setting.website_info.provider == "jinareader": + extractor = JinaReaderWebExtractor( + url=extract_setting.website_info.url, + job_id=extract_setting.website_info.job_id, + tenant_id=extract_setting.website_info.tenant_id, + mode=extract_setting.website_info.mode, + only_main_content=extract_setting.website_info.only_main_content, ) return extractor.extract() else: diff --git a/api/core/rag/extractor/extractor_base.py b/api/core/rag/extractor/extractor_base.py index c490e59332d237..582eca94df71e1 100644 --- a/api/core/rag/extractor/extractor_base.py +++ b/api/core/rag/extractor/extractor_base.py @@ -1,12 +1,11 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod class BaseExtractor(ABC): - """Interface for extract files. 
- """ + """Interface for extract files.""" @abstractmethod def extract(self): raise NotImplementedError - diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 2b85ad9739881e..17c2087a0ab575 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -9,108 +9,98 @@ class FirecrawlApp: def __init__(self, api_key=None, base_url=None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' - if self.api_key is None and self.base_url == 'https://api.firecrawl.dev': - raise ValueError('No API key provided') + self.base_url = base_url or "https://api.firecrawl.dev" + if self.api_key is None and self.base_url == "https://api.firecrawl.dev": + raise ValueError("No API key provided") def scrape_url(self, url, params=None) -> dict: - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } - json_data = {'url': url} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + json_data = {"url": url} if params: json_data.update(params) - response = requests.post( - f'{self.base_url}/v0/scrape', - headers=headers, - json=json_data - ) + response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data) if response.status_code == 200: response = response.json() - if response['success'] == True: - data = response['data'] + if response["success"] == True: + data = response["data"] return { - 'title': data.get('metadata').get('title'), - 'description': data.get('metadata').get('description'), - 'source_url': data.get('metadata').get('sourceURL'), - 'markdown': data.get('markdown') + "title": data.get("metadata").get("title"), + "description": data.get("metadata").get("description"), + "source_url": data.get("metadata").get("sourceURL"), + "markdown": data.get("markdown"), } else: raise Exception(f'Failed to scrape URL. Error: {response["error"]}') - elif response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}') + elif response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception(f'Failed to scrape URL. Status code: {response.status_code}') + raise Exception(f"Failed to scrape URL. 
Status code: {response.status_code}") def crawl_url(self, url, params=None) -> str: headers = self._prepare_headers() - json_data = {'url': url} + json_data = {"url": url} if params: json_data.update(params) - response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers) if response.status_code == 200: - job_id = response.json().get('jobId') + job_id = response.json().get("jobId") return job_id else: - self._handle_error(response, 'start crawl job') + self._handle_error(response, "start crawl job") def check_crawl_status(self, job_id) -> dict: headers = self._prepare_headers() - response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers) + response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers) if response.status_code == 200: crawl_status_response = response.json() - if crawl_status_response.get('status') == 'completed': - total = crawl_status_response.get('total', 0) + if crawl_status_response.get("status") == "completed": + total = crawl_status_response.get("total", 0) if total == 0: - raise Exception('Failed to check crawl status. Error: No page found') - data = crawl_status_response.get('data', []) + raise Exception("Failed to check crawl status. Error: No page found") + data = crawl_status_response.get("data", []) url_data_list = [] for item in data: - if isinstance(item, dict) and 'metadata' in item and 'markdown' in item: + if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = { - 'title': item.get('metadata').get('title'), - 'description': item.get('metadata').get('description'), - 'source_url': item.get('metadata').get('sourceURL'), - 'markdown': item.get('markdown') + "title": item.get("metadata").get("title"), + "description": item.get("metadata").get("description"), + "source_url": item.get("metadata").get("sourceURL"), + "markdown": item.get("markdown"), } url_data_list.append(url_data) if url_data_list: - file_key = 'website_files/' + job_id + '.txt' + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): storage.delete(file_key) - storage.save(file_key, json.dumps(url_data_list).encode('utf-8')) + storage.save(file_key, json.dumps(url_data_list).encode("utf-8")) return { - 'status': 'completed', - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': url_data_list + "status": "completed", + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": url_data_list, } else: return { - 'status': crawl_status_response.get('status'), - 'total': crawl_status_response.get('total'), - 'current': crawl_status_response.get('current'), - 'data': [] + "status": crawl_status_response.get("status"), + "total": crawl_status_response.get("total"), + "current": crawl_status_response.get("current"), + "data": [], } else: - self._handle_error(response, 'check crawl status') + self._handle_error(response, "check crawl status") def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5): for attempt in range(retries): response = requests.post(url, headers=headers, json=data) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor 
* (2**attempt)) else: return response return response @@ -119,13 +109,11 @@ def _get_request(self, url, headers, retries=3, backoff_factor=0.5): for attempt in range(retries): response = requests.get(url, headers=headers) if response.status_code == 502: - time.sleep(backoff_factor * (2 ** attempt)) + time.sleep(backoff_factor * (2**attempt)) else: return response return response def _handle_error(self, response, action): - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}') - - + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") diff --git a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py index 8e2f107e5eb795..b33ce167c21c82 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py @@ -5,7 +5,7 @@ class FirecrawlWebExtractor(BaseExtractor): """ - Crawl and scrape websites and return content in clean llm-ready markdown. + Crawl and scrape websites and return content in clean llm-ready markdown. Args: @@ -15,14 +15,7 @@ class FirecrawlWebExtractor(BaseExtractor): mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'. """ - def __init__( - self, - url: str, - job_id: str, - tenant_id: str, - mode: str = 'crawl', - only_main_content: bool = False - ): + def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False): """Initialize with url, api_key, base_url and mode.""" self._url = url self.job_id = job_id @@ -33,28 +26,31 @@ def __init__( def extract(self) -> list[Document]: """Extract content from the URL.""" documents = [] - if self.mode == 'crawl': - crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id) + if self.mode == "crawl": + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id) if crawl_data is None: return [] - document = Document(page_content=crawl_data.get('markdown', ''), - metadata={ - 'source_url': crawl_data.get('source_url'), - 'description': crawl_data.get('description'), - 'title': crawl_data.get('title') - } - ) + document = Document( + page_content=crawl_data.get("markdown", ""), + metadata={ + "source_url": crawl_data.get("source_url"), + "description": crawl_data.get("description"), + "title": crawl_data.get("title"), + }, + ) documents.append(document) - elif self.mode == 'scrape': - scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id, - self.only_main_content) + elif self.mode == "scrape": + scrape_data = WebsiteService.get_scrape_url_data( + "firecrawl", self._url, self.tenant_id, self.only_main_content + ) - document = Document(page_content=scrape_data.get('markdown', ''), - metadata={ - 'source_url': scrape_data.get('source_url'), - 'description': scrape_data.get('description'), - 'title': scrape_data.get('title') - } - ) + document = Document( + page_content=scrape_data.get("markdown", ""), + metadata={ + "source_url": scrape_data.get("source_url"), + "description": scrape_data.get("description"), + "title": scrape_data.get("title"), + }, + ) documents.append(document) return documents diff --git a/api/core/rag/extractor/helpers.py 
b/api/core/rag/extractor/helpers.py index 0c17a47b329cab..69ca9d5d636888 100644 --- a/api/core/rag/extractor/helpers.py +++ b/api/core/rag/extractor/helpers.py @@ -1,6 +1,7 @@ """Document loader helpers.""" import concurrent.futures +from pathlib import Path from typing import NamedTuple, Optional, cast @@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding import chardet def read_and_detect(file_path: str) -> list[dict]: - with open(file_path, "rb") as f: - rawdata = f.read() + rawdata = Path(file_path).read_bytes() return cast(list[dict], chardet.detect_all(rawdata)) with concurrent.futures.ThreadPoolExecutor() as executor: @@ -37,9 +37,7 @@ def read_and_detect(file_path: str) -> list[dict]: try: encodings = future.result(timeout=timeout) except concurrent.futures.TimeoutError: - raise TimeoutError( - f"Timeout reached while detecting encoding for {file_path}" - ) + raise TimeoutError(f"Timeout reached while detecting encoding for {file_path}") if all(encoding["encoding"] is None for encoding in encodings): raise RuntimeError(f"Could not detect encoding for {file_path}") diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index ceb53062559a4d..560c2d1d84b04e 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from bs4 import BeautifulSoup from core.rag.extractor.extractor_base import BaseExtractor @@ -6,7 +7,6 @@ class HtmlExtractor(BaseExtractor): - """ Load html files. @@ -15,10 +15,7 @@ class HtmlExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str - ): + def __init__(self, file_path: str): """Initialize with file path.""" self._file_path = file_path @@ -27,8 +24,8 @@ def extract(self) -> list[Document]: def _load_as_text(self) -> str: with open(self._file_path, "rb") as fp: - soup = BeautifulSoup(fp, 'html.parser') + soup = BeautifulSoup(fp, "html.parser") text = soup.get_text() - text = text.strip() if text else '' + text = text.strip() if text else "" - return text \ No newline at end of file + return text diff --git a/api/core/rag/extractor/jina_reader_extractor.py b/api/core/rag/extractor/jina_reader_extractor.py new file mode 100644 index 00000000000000..5b780af126b309 --- /dev/null +++ b/api/core/rag/extractor/jina_reader_extractor.py @@ -0,0 +1,35 @@ +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document +from services.website_service import WebsiteService + + +class JinaReaderWebExtractor(BaseExtractor): + """ + Crawl and scrape websites and return content in clean llm-ready markdown. 
+ """ + + def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False): + """Initialize with url, api_key, base_url and mode.""" + self._url = url + self.job_id = job_id + self.tenant_id = tenant_id + self.mode = mode + self.only_main_content = only_main_content + + def extract(self) -> list[Document]: + """Extract content from the URL.""" + documents = [] + if self.mode == "crawl": + crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id) + if crawl_data is None: + return [] + document = Document( + page_content=crawl_data.get("content", ""), + metadata={ + "source_url": crawl_data.get("url"), + "description": crawl_data.get("description"), + "title": crawl_data.get("title"), + }, + ) + documents.append(document) + return documents diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index b24cf2e1707376..849852ac23819a 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -1,5 +1,7 @@ """Abstract interface for document loader implementations.""" + import re +from pathlib import Path from typing import Optional, cast from core.rag.extractor.extractor_base import BaseExtractor @@ -16,12 +18,12 @@ class MarkdownExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - remove_hyperlinks: bool = False, - remove_images: bool = False, - encoding: Optional[str] = None, - autodetect_encoding: bool = True, + self, + file_path: str, + remove_hyperlinks: bool = False, + remove_images: bool = False, + encoding: Optional[str] = None, + autodetect_encoding: bool = True, ): """Initialize with file path.""" self._file_path = file_path @@ -78,13 +80,10 @@ def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str] if current_header is not None: # pass linting, assert keys are defined markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) - for key, value in markdown_tups + (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups ] else: - markdown_tups = [ - (key, re.sub("\n", "", value)) for key, value in markdown_tups - ] + markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups] return markdown_tups @@ -104,15 +103,13 @@ def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]: """Parse file into tuples.""" content = "" try: - with open(filepath, encoding=self._encoding) as f: - content = f.read() + content = Path(filepath).read_text(encoding=self._encoding) except UnicodeDecodeError as e: if self._autodetect_encoding: detected_encodings = detect_file_encodings(filepath) for encoding in detected_encodings: try: - with open(filepath, encoding=encoding.encoding) as f: - content = f.read() + content = Path(filepath).read_text(encoding=encoding.encoding) break except UnicodeDecodeError: continue diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 9535455909dd9c..87a4ce08bf3f89 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -21,22 +21,21 @@ RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" # if user want split by headings, use the corresponding splitter HEADING_SPLITTER = { - 'heading_1': '# ', - 'heading_2': '## ', - 'heading_3': '### ', + "heading_1": "# ", + "heading_2": "## ", + "heading_3": "### ", } 
-class NotionExtractor(BaseExtractor): +class NotionExtractor(BaseExtractor): def __init__( - self, - notion_workspace_id: str, - notion_obj_id: str, - notion_page_type: str, - tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, - + self, + notion_workspace_id: str, + notion_obj_id: str, + notion_page_type: str, + tenant_id: str, + document_model: Optional[DocumentModel] = None, + notion_access_token: Optional[str] = None, ): self._notion_access_token = None self._document_model = document_model @@ -46,46 +45,38 @@ def __init__( if notion_access_token: self._notion_access_token = notion_access_token else: - self._notion_access_token = self._get_access_token(tenant_id, - self._notion_workspace_id) + self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) if not self._notion_access_token: integration_token = dify_config.NOTION_INTEGRATION_TOKEN if integration_token is None: raise ValueError( - "Must specify `integration_token` or set environment " - "variable `NOTION_INTEGRATION_TOKEN`." + "Must specify `integration_token` or set environment variable `NOTION_INTEGRATION_TOKEN`." ) self._notion_access_token = integration_token def extract(self) -> list[Document]: - self.update_last_edited_time( - self._document_model - ) + self.update_last_edited_time(self._document_model) text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type) return text_docs - def _load_data_as_documents( - self, notion_obj_id: str, notion_page_type: str - ) -> list[Document]: + def _load_data_as_documents(self, notion_obj_id: str, notion_page_type: str) -> list[Document]: docs = [] - if notion_page_type == 'database': + if notion_page_type == "database": # get all the pages in the database page_text_documents = self._get_notion_database_data(notion_obj_id) docs.extend(page_text_documents) - elif notion_page_type == 'page': + elif notion_page_type == "page": page_text_list = self._get_notion_block_data(notion_obj_id) - docs.append(Document(page_content='\n'.join(page_text_list))) + docs.append(Document(page_content="\n".join(page_text_list))) else: raise ValueError("notion page type not supported") return docs - def _get_notion_database_data( - self, database_id: str, query_dict: dict[str, Any] = {} - ) -> list[Document]: + def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), @@ -100,50 +91,50 @@ def _get_notion_database_data( data = res.json() database_content = [] - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: return [] for result in data["results"]: - properties = result['properties'] + properties = result["properties"] data = {} for property_name, property_value in properties.items(): - type = property_value['type'] - if type == 'multi_select': + type = property_value["type"] + if type == "multi_select": value = [] multi_select_list = property_value[type] for multi_select in multi_select_list: - value.append(multi_select['name']) - elif type == 'rich_text' or type == 'title': + value.append(multi_select["name"]) + elif type in {"rich_text", "title"}: if len(property_value[type]) > 0: - value = property_value[type][0]['plain_text'] + value = property_value[type][0]["plain_text"] else: - value = '' - elif type == 'select' or type == 'status': + value = "" + elif type in 
{"select", "status"}: if property_value[type]: - value = property_value[type]['name'] + value = property_value[type]["name"] else: - value = '' + value = "" else: value = property_value[type] data[property_name] = value row_dict = {k: v for k, v in data.items() if v} - row_content = '' + row_content = "" for key, value in row_dict.items(): if isinstance(value, dict): value_dict = {k: v for k, v in value.items() if v} - value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items()) - row_content = row_content + f'{key}:{value_content}\n' + value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) + row_content = row_content + f"{key}:{value_content}\n" else: - row_content = row_content + f'{key}:{value}\n' + row_content = row_content + f"{key}:{value}\n" database_content.append(row_content) - return [Document(page_content='\n'.join(database_content))] + return [Document(page_content="\n".join(database_content))] def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", block_url, @@ -152,14 +143,14 @@ def _get_notion_block_data(self, page_id: str) -> list[str]: "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = self._read_table_rows(result_block_id) text += "\n\n" @@ -175,17 +166,15 @@ def _get_notion_block_data(self, page_id: str) -> list[str]: result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -199,7 +188,7 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while True: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -209,16 +198,16 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() - if 'results' not in data or data["results"] is None: + if "results" not in data or data["results"] is None: break for result in data["results"]: result_type = result["type"] result_obj = result[result_type] cur_result_text_arr = [] - if result_type == 'table': + if result_type == "table": result_block_id = result["id"] text = 
self._read_table_rows(result_block_id) result_lines_arr.append(text) @@ -233,17 +222,15 @@ def _read_block(self, block_id: str, num_tabs: int = 0) -> str: result_block_id = result["id"] has_children = result["has_children"] block_type = result["type"] - if has_children and block_type != 'child_page': - children_text = self._read_block( - result_block_id, num_tabs=num_tabs + 1 - ) + if has_children and block_type != "child_page": + children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1) cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) if result_type in HEADING_SPLITTER: - result_lines_arr.append(f'{HEADING_SPLITTER[result_type]}{cur_result_text}') + result_lines_arr.append(f"{HEADING_SPLITTER[result_type]}{cur_result_text}") else: - result_lines_arr.append(cur_result_text + '\n\n') + result_lines_arr.append(cur_result_text + "\n\n") if data["next_cursor"] is None: break @@ -260,7 +247,7 @@ def _read_table_rows(self, block_id: str) -> str: start_cursor = None block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id) while not done: - query_dict: dict[str, Any] = {} if not start_cursor else {'start_cursor': start_cursor} + query_dict: dict[str, Any] = {} if not start_cursor else {"start_cursor": start_cursor} res = requests.request( "GET", @@ -270,31 +257,36 @@ def _read_table_rows(self, block_id: str) -> str: "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - params=query_dict + params=query_dict, ) data = res.json() # get table headers text table_header_cell_texts = [] - tabel_header_cells = data["results"][0]['table_row']['cells'] - for tabel_header_cell in tabel_header_cells: - if tabel_header_cell: - for table_header_cell_text in tabel_header_cell: + table_header_cells = data["results"][0]["table_row"]["cells"] + for table_header_cell in table_header_cells: + if table_header_cell: + for table_header_cell_text in table_header_cell: text = table_header_cell_text["text"]["content"] table_header_cell_texts.append(text) - # get table columns text and format + else: + table_header_cell_texts.append("") + # Initialize Markdown table with headers + markdown_table = "| " + " | ".join(table_header_cell_texts) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table_header_cell_texts)) + " |\n" + + # Process data to format each row in Markdown table format results = data["results"] for i in range(len(results) - 1): column_texts = [] - tabel_column_cells = data["results"][i + 1]['table_row']['cells'] - for j in range(len(tabel_column_cells)): - if tabel_column_cells[j]: - for table_column_cell_text in tabel_column_cells[j]: + table_column_cells = data["results"][i + 1]["table_row"]["cells"] + for j in range(len(table_column_cells)): + if table_column_cells[j]: + for table_column_cell_text in table_column_cells[j]: column_text = table_column_cell_text["text"]["content"] - column_texts.append(f'{table_header_cell_texts[j]}:{column_text}') - - cur_result_text = "\n".join(column_texts) - result_lines_arr.append(cur_result_text) - + column_texts.append(column_text) + # Add row to Markdown table + markdown_table += "| " + " | ".join(column_texts) + " |\n" + result_lines_arr.append(markdown_table) if data["next_cursor"] is None: done = True break @@ -310,10 +302,8 @@ def update_last_edited_time(self, document_model: DocumentModel): last_edited_time = self.get_notion_last_edited_time() data_source_info = document_model.data_source_info_dict - data_source_info['last_edited_time'] = last_edited_time - update_params = { 
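Note: the database, block, and table readers in this extractor all share one cursor-driven paging loop against the Notion API. A minimal standalone sketch of that pattern, assuming the same pinned API version (the helper below is illustrative, not part of this file):

import requests

NOTION_VERSION = "2022-06-28"  # the version pinned throughout this extractor

def iterate_notion_results(url: str, token: str):
    # Illustrative sketch: yield every result from a paginated Notion endpoint.
    start_cursor = None
    while True:
        params = {} if not start_cursor else {"start_cursor": start_cursor}
        res = requests.get(
            url,
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
                "Notion-Version": NOTION_VERSION,
            },
            params=params,
        )
        data = res.json()
        yield from data.get("results") or []
        if data.get("next_cursor") is None:
            break
        start_cursor = data["next_cursor"]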
- DocumentModel.data_source_info: json.dumps(data_source_info) - } + data_source_info["last_edited_time"] = last_edited_time + update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} DocumentModel.query.filter_by(id=document_model.id).update(update_params) db.session.commit() @@ -321,7 +311,7 @@ def update_last_edited_time(self, document_model: DocumentModel): def get_notion_last_edited_time(self) -> str: obj_id = self._notion_obj_id page_type = self._notion_page_type - if page_type == 'database': + if page_type == "database": retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id) else: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) @@ -336,7 +326,7 @@ def get_notion_last_edited_time(self) -> str: "Content-Type": "application/json", "Notion-Version": "2022-06-28", }, - json=query_dict + json=query_dict, ) data = res.json() @@ -347,14 +337,16 @@ def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', ) ).first() if not data_source_binding: - raise Exception(f'No notion data source binding found for tenant {tenant_id} ' - f'and notion workspace {notion_workspace_id}') + raise Exception( + f"No notion data source binding found for tenant {tenant_id} " + f"and notion workspace {notion_workspace_id}" + ) return data_source_binding.access_token diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index cbb2655390be2b..57cb9610ba267e 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -1,8 +1,9 @@ """Abstract interface for document loader implementations.""" + from collections.abc import Iterator from typing import Optional -from core.rag.extractor.blod.blod import Blob +from core.rag.extractor.blob.blob import Blob from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_storage import storage @@ -16,21 +17,17 @@ class PdfExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__( - self, - file_path: str, - file_cache_key: Optional[str] = None - ): + def __init__(self, file_path: str, file_cache_key: Optional[str] = None): """Initialize with file path.""" self._file_path = file_path self._file_cache_key = file_cache_key def extract(self) -> list[Document]: - plaintext_file_key = '' + plaintext_file_key = "" plaintext_file_exists = False if self._file_cache_key: try: - text = storage.load(self._file_cache_key).decode('utf-8') + text = storage.load(self._file_cache_key).decode("utf-8") plaintext_file_exists = True return [Document(page_content=text)] except FileNotFoundError: @@ -43,12 +40,12 @@ def extract(self) -> list[Document]: # save plaintext file for caching if not plaintext_file_exists and plaintext_file_key: - storage.save(plaintext_file_key, text.encode('utf-8')) + storage.save(plaintext_file_key, text.encode("utf-8")) return documents def load( - self, + self, ) -> Iterator[Document]: """Lazy load given path as pages.""" blob = Blob.from_path(self._file_path) diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py index ac5d0920cffb48..b2b51d71d73a16 100644 --- a/api/core/rag/extractor/text_extractor.py +++ b/api/core/rag/extractor/text_extractor.py @@ -1,4 +1,6 @@ """Abstract interface for document loader implementations.""" + +from pathlib import Path from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor @@ -14,12 +16,7 @@ class TextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False - ): + def __init__(self, file_path: str, encoding: Optional[str] = None, autodetect_encoding: bool = False): """Initialize with file path.""" self._file_path = file_path self._encoding = encoding @@ -29,15 +26,13 @@ def extract(self) -> list[Document]: """Load from file path.""" text = "" try: - with open(self._file_path, encoding=self._encoding) as f: - text = f.read() + text = Path(self._file_path).read_text(encoding=self._encoding) except UnicodeDecodeError as e: if self._autodetect_encoding: detected_encodings = detect_file_encodings(self._file_path) for encoding in detected_encodings: try: - with open(self._file_path, encoding=encoding.encoding) as f: - text = f.read() + text = Path(self._file_path).read_text(encoding=encoding.encoding) break except UnicodeDecodeError: continue diff --git a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py index 0323b14a4a34fd..a525c9e9e3c443 100644 --- a/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_doc_extractor.py @@ -8,13 +8,12 @@ class UnstructuredWordExtractor(BaseExtractor): - """Loader that uses unstructured to load word documents. 
- """ + """Loader that uses unstructured to load word documents.""" def __init__( - self, - file_path: str, - api_url: str, + self, + file_path: str, + api_url: str, ): """Initialize with file path.""" self._file_path = file_path @@ -24,9 +23,7 @@ def extract(self) -> list[Document]: from unstructured.__version__ import __version__ as __unstructured_version__ from unstructured.file_utils.filetype import FileType, detect_filetype - unstructured_version = tuple( - int(x) for x in __unstructured_version__.split(".") - ) + unstructured_version = tuple(int(x) for x in __unstructured_version__.split(".")) # check the file extension try: import magic # noqa: F401 @@ -53,6 +50,7 @@ def extract(self) -> list[Document]: elements = partition_docx(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py index 2e704f187d05d6..bd669bbad36873 100644 --- a/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_eml_extractor.py @@ -10,23 +10,26 @@ class UnstructuredEmailExtractor(BaseExtractor): - """Load msg files. + """Load eml files. Args: file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str, - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url + self._api_key = api_key def extract(self) -> list[Document]: - from unstructured.partition.email import partition_email - elements = partition_email(filename=self._file_path) + if self._api_url: + from unstructured.partition.api import partition_via_api + + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) + else: + from unstructured.partition.email import partition_email + + elements = partition_email(filename=self._file_path) # noinspection PyBroadException try: @@ -34,15 +37,16 @@ def extract(self) -> list[Document]: element_text = element.text.strip() padding_needed = 4 - len(element_text) % 4 - element_text += '=' * padding_needed + element_text += "=" * padding_needed element_decode = base64.b64decode(element_text) - soup = BeautifulSoup(element_decode.decode('utf-8'), 'html.parser') + soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser") element.text = soup.get_text() except Exception: pass from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py index 44cf958ea2b636..35220b558afab9 100644 --- a/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_epub_extractor.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -17,17 +18,26 @@ class UnstructuredEpubExtractor(BaseExtractor): def __init__( self, file_path: str, - api_url: str = None, + api_url: Optional[str] = None, + api_key: Optional[str] = None, ): """Initialize with file path.""" 
self._file_path = file_path self._api_url = api_url + self._api_key = api_key def extract(self) -> list[Document]: - from unstructured.partition.epub import partition_epub + if self._api_url: + from unstructured.partition.api import partition_via_api + + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) + else: + from unstructured.partition.epub import partition_epub + + elements = partition_epub(filename=self._file_path, xml_keep_tags=True) - elements = partition_epub(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py index 144b4e0c1d7a91..4173d4d122dc24 100644 --- a/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_markdown_extractor.py @@ -24,20 +24,23 @@ class UnstructuredMarkdownExtractor(BaseExtractor): if the specified encoding fails. """ - def __init__( - self, - file_path: str, - api_url: str, - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url + self._api_key = api_key def extract(self) -> list[Document]: - from unstructured.partition.md import partition_md + if self._api_url: + from unstructured.partition.api import partition_via_api - elements = partition_md(filename=self._file_path) + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) + else: + from unstructured.partition.md import partition_md + + elements = partition_md(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py index ad09b79eb00a07..57affb8d36b2bb 100644 --- a/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_msg_extractor.py @@ -14,20 +14,23 @@ class UnstructuredMsgExtractor(BaseExtractor): file_path: Path to the file to load. 
""" - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url + self._api_key = api_key def extract(self) -> list[Document]: - from unstructured.partition.msg import partition_msg + if self._api_url: + from unstructured.partition.api import partition_via_api - elements = partition_msg(filename=self._file_path) + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) + else: + from unstructured.partition.msg import partition_msg + + elements = partition_msg(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py new file mode 100644 index 00000000000000..dd8a979e709989 --- /dev/null +++ b/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py @@ -0,0 +1,47 @@ +import logging + +from core.rag.extractor.extractor_base import BaseExtractor +from core.rag.models.document import Document + +logger = logging.getLogger(__name__) + + +class UnstructuredPDFExtractor(BaseExtractor): + """Load pdf files. + + + Args: + file_path: Path to the file to load. + + api_url: Unstructured API URL + + api_key: Unstructured API Key + """ + + def __init__(self, file_path: str, api_url: str, api_key: str): + """Initialize with file path.""" + self._file_path = file_path + self._api_url = api_url + self._api_key = api_key + + def extract(self) -> list[Document]: + if self._api_url: + from unstructured.partition.api import partition_via_api + + elements = partition_via_api( + filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto" + ) + else: + from unstructured.partition.pdf import partition_pdf + + elements = partition_pdf(filename=self._file_path, strategy="auto") + + from unstructured.chunking.title import chunk_by_title + + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) + documents = [] + for chunk in chunks: + text = chunk.text.strip() + documents.append(Document(page_content=text)) + + return documents diff --git a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py index d354b593ed7af6..0fdcd58b2e569b 100644 --- a/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_ppt_extractor.py @@ -7,28 +7,26 @@ class UnstructuredPPTExtractor(BaseExtractor): - """Load msg files. + """Load ppt files. Args: file_path: Path to the file to load. 
""" - def __init__( - self, - file_path: str, - api_url: str, - api_key: str - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url self._api_key = api_key def extract(self) -> list[Document]: - from unstructured.partition.api import partition_via_api + if self._api_url: + from unstructured.partition.api import partition_via_api - elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) + else: + raise NotImplementedError("Unstructured API Url is not configured") text_by_page = {} for element in elements: page = element.metadata.page_number diff --git a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py index 6fcbb5feb991d0..ab41290fbc4537 100644 --- a/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_pptx_extractor.py @@ -7,26 +7,28 @@ class UnstructuredPPTXExtractor(BaseExtractor): - """Load msg files. + """Load pptx files. Args: file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url + self._api_key = api_key def extract(self) -> list[Document]: - from unstructured.partition.pptx import partition_pptx + if self._api_url: + from unstructured.partition.api import partition_via_api - elements = partition_pptx(filename=self._file_path) + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) + else: + from unstructured.partition.pptx import partition_pptx + + elements = partition_pptx(filename=self._file_path) text_by_page = {} for element in elements: page = element.metadata.page_number diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py index f4a4adbc1600fd..22dfdd20752cbf 100644 --- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py @@ -14,11 +14,7 @@ class UnstructuredTextExtractor(BaseExtractor): file_path: Path to the file to load. """ - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url @@ -28,6 +24,7 @@ def extract(self) -> list[Document]: elements = partition_text(filename=self._file_path) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py index 6aef8e0f7e2718..ef46ab0e70da6e 100644 --- a/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py +++ b/api/core/rag/extractor/unstructured/unstructured_xml_extractor.py @@ -7,27 +7,31 @@ class UnstructuredXmlExtractor(BaseExtractor): - """Load msg files. + """Load xml files. Args: file_path: Path to the file to load. 
""" - def __init__( - self, - file_path: str, - api_url: str - ): + def __init__(self, file_path: str, api_url: str, api_key: str): """Initialize with file path.""" self._file_path = file_path self._api_url = api_url + self._api_key = api_key def extract(self) -> list[Document]: - from unstructured.partition.xml import partition_xml + if self._api_url: + from unstructured.partition.api import partition_via_api + + elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) + else: + from unstructured.partition.xml import partition_xml + + elements = partition_xml(filename=self._file_path, xml_keep_tags=True) - elements = partition_xml(filename=self._file_path, xml_keep_tags=True) from unstructured.chunking.title import chunk_by_title + chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) documents = [] for chunk in chunks: diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index c3f0b75cfba5f1..b59e7f94fd5013 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + import datetime import logging import mimetypes @@ -6,25 +7,27 @@ import re import tempfile import uuid -import xml.etree.ElementTree as ET from urllib.parse import urlparse +from xml.etree import ElementTree import requests from docx import Document as DocxDocument from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage +from models.enums import CreatedByRole from models.model import UploadFile logger = logging.getLogger(__name__) + class WordExtractor(BaseExtractor): """Load docx files. - Args: file_path: Path to the file to load. 
""" @@ -43,14 +46,13 @@ def __init__(self, file_path: str, tenant_id: str, user_id: str): r = requests.get(self.file_path) if r.status_code != 200: - raise ValueError( - f"Check the url of your file; returned status code {r.status_code}" - ) + raise ValueError(f"Check the url of your file; returned status code {r.status_code}") self.web_path = self.file_path - self.temp_file = tempfile.NamedTemporaryFile() - self.temp_file.write(r.content) - self.file_path = self.temp_file.name + # TODO: use a better way to handle the file + with tempfile.NamedTemporaryFile(delete=False) as self.temp_file: + self.temp_file.write(r.content) + self.file_path = self.temp_file.name elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") @@ -60,11 +62,13 @@ def __del__(self) -> None: def extract(self) -> list[Document]: """Load given path as single page.""" - content = self.parse_docx(self.file_path, 'storage') - return [Document( - page_content=content, - metadata={"source": self.file_path}, - )] + content = self.parse_docx(self.file_path, "storage") + return [ + Document( + page_content=content, + metadata={"source": self.file_path}, + ) + ] @staticmethod def _is_valid_url(url: str) -> bool: @@ -82,20 +86,20 @@ def _extract_images_from_docx(self, doc, image_folder): image_count += 1 if rel.is_external: url = rel.reltype - response = requests.get(url, stream=True) + response = ssrf_proxy.get(url, stream=True) if response.status_code == 200: - image_ext = mimetypes.guess_extension(response.headers['Content-Type']) + image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, response.content) else: continue else: - image_ext = rel.target_ref.split('.')[-1] + image_ext = rel.target_ref.split(".")[-1] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'image_files/' + self.tenant_id + '/' + file_uuid + '.' + image_ext + file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." 
+ image_ext mime_type, _ = mimetypes.guess_type(file_key) storage.save(file_key, rel.target_part.blob) @@ -106,18 +110,21 @@ def _extract_images_from_docx(self, doc, image_folder): key=file_key, name=file_key, size=0, - extension=image_ext, - mime_type=mime_type, + extension=str(image_ext), + mime_type=mime_type or "", created_by=self.user_id, + created_by_role=CreatedByRole.ACCOUNT, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)" + image_map[rel.target_part] = ( + f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)" + ) return image_map @@ -148,7 +155,7 @@ def _parse_row(self, row, image_map, total_cols): if col_index >= total_cols: break cell_content = self._parse_cell(cell, image_map).strip() - cell_colspan = cell.grid_span if cell.grid_span else 1 + cell_colspan = cell.grid_span or 1 for i in range(cell_colspan): if col_index + i < total_cols: row_cells[col_index + i] = cell_content if i == 0 else "" @@ -167,9 +174,11 @@ def _parse_cell(self, cell, image_map): def _parse_cell_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") + if not image_id: + continue image_part = paragraph.part.rels[image_id].target_part if image_part in image_map: @@ -182,16 +191,16 @@ def _parse_cell_paragraph(self, paragraph, image_map): def _parse_paragraph(self, paragraph, image_map): paragraph_content = [] for run in paragraph.runs: - if run.element.xpath('.//a:blip'): - for blip in run.element.xpath('.//a:blip'): - embed_id = blip.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): + embed_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") if embed_id: rel_target = run.part.rels[embed_id].target_ref if rel_target in image_map: paragraph_content.append(image_map[rel_target]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ' '.join(paragraph_content) if paragraph_content else '' + return " ".join(paragraph_content) if paragraph_content else "" def parse_docx(self, docx_path, image_folder): doc = DocxDocument(docx_path) @@ -202,60 +211,59 @@ def parse_docx(self, docx_path, image_folder): image_map = self._extract_images_from_docx(doc, image_folder) hyperlinks_url = None - url_pattern = re.compile(r'http://[^\s+]+//|https://[^\s+]+') + url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") for para in doc.paragraphs: for run in para.runs: if run.text and hyperlinks_url: - result = f' [{run.text}]({hyperlinks_url}) ' + result = f" [{run.text}]({hyperlinks_url}) " run.text = result hyperlinks_url = None - if 'HYPERLINK' in run.element.xml: + if "HYPERLINK" in run.element.xml: try: - xml = ET.XML(run.element.xml) + xml = ElementTree.XML(run.element.xml) x_child = [c for c in xml.iter() if c is not None] for x in x_child: if x_child is None: 
continue - if x.tag.endswith('instrText'): + if x.tag.endswith("instrText"): for i in url_pattern.findall(x.text): hyperlinks_url = str(i) except Exception as e: - logger.error(e) - - - + logger.exception(e) def parse_paragraph(paragraph): paragraph_content = [] for run in paragraph.runs: - if hasattr(run.element, 'tag') and isinstance(element.tag, str) and run.element.tag.endswith('r'): + if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): drawing_elements = run.element.findall( - './/{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing') + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing" + ) for drawing in drawing_elements: blip_elements = drawing.findall( - './/{http://schemas.openxmlformats.org/drawingml/2006/main}blip') + ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip" + ) for blip in blip_elements: embed_id = blip.get( - '{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed') + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" + ) if embed_id: image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: paragraph_content.append(image_map[image_part]) if run.text.strip(): paragraph_content.append(run.text.strip()) - return ''.join(paragraph_content) if paragraph_content else '' + return "".join(paragraph_content) if paragraph_content else "" paragraphs = doc.paragraphs.copy() tables = doc.tables.copy() for element in doc.element.body: - if hasattr(element, 'tag'): - if isinstance(element.tag, str) and element.tag.endswith('p'): # paragraph + if hasattr(element, "tag"): + if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph para = paragraphs.pop(0) parsed_paragraph = parse_paragraph(para) if parsed_paragraph: content.append(parsed_paragraph) - elif isinstance(element.tag, str) and element.tag.endswith('tbl'): # table + elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table table = tables.pop(0) - content.append(self._table_to_markdown(table,image_map)) - return '\n'.join(content) - + content.append(self._table_to_markdown(table, image_map)) + return "\n".join(content) diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 33e78ce8c5ccb0..be857bd12215fd 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -1,4 +1,5 @@ """Abstract interface for document loader implementations.""" + from abc import ABC, abstractmethod from typing import Optional @@ -15,8 +16,7 @@ class BaseIndexProcessor(ABC): - """Interface for extract files. 
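Worth flagging in the word_extractor hunk above: the hyperlink regex treats http and https asymmetrically (the http alternative only matches URLs that contain a trailing //, and the [^\s+] class also stops at literal + characters). A quick demonstration of the current behavior:

import re

url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")

print(url_pattern.findall("see https://example.com/a"))  # ['https://example.com/a']
print(url_pattern.findall("see http://example.com/a"))   # [] -- plain http URLs never match
print(url_pattern.findall("see https://e.com/a+b"))      # ['https://e.com/a'] -- '+' truncates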
- """ + """Interface for extract files.""" @abstractmethod def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -34,18 +34,24 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: raise NotImplementedError @abstractmethod - def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: raise NotImplementedError - def _get_splitter(self, processing_rule: dict, - embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: """ Get the NodeParser object according to the processing rule. """ - if processing_rule['mode'] == "custom": + if processing_rule["mode"] == "custom": # The user-defined segmentation rule - rules = processing_rule['rules'] + rules = processing_rule["rules"] segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: @@ -53,22 +59,22 @@ def _get_splitter(self, processing_rule: dict, separator = segmentation["separator"] if separator: - separator = separator.replace('\\n', '\n') + separator = separator.replace("\\n", "\n") character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get('chunk_overlap', 0), + chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) else: # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( - chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], + chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], separators=["\n\n", "。", ". ", " ", ""], - embedding_model_instance=embedding_model_instance + embedding_model_instance=embedding_model_instance, ) return character_splitter diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index df43a6491074b8..9b855ece2c3512 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -7,8 +7,7 @@ class IndexProcessorFactory: - """IndexProcessorInit. 
- """ + """IndexProcessorInit.""" def __init__(self, index_type: str): self._index_type = index_type @@ -22,7 +21,6 @@ def init_index_processor(self) -> BaseIndexProcessor: if self._index_type == IndexType.PARAGRAPH_INDEX.value: return ParagraphIndexProcessor() elif self._index_type == IndexType.QA_INDEX.value: - return QAIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5fbc319fd633d2..ed5712220f072e 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import uuid from typing import Optional @@ -15,34 +16,33 @@ class ParagraphIndexProcessor(BaseIndexProcessor): - def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: # Split the text documents into nodes. - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) all_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash - # delete Spliter character + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): page_content = page_content[1:].strip() @@ -55,7 +55,7 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: return all_documents def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if with_keywords: @@ -63,7 +63,7 @@ def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool keyword.create(documents) def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) @@ -76,17 +76,29 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: else: keyword.delete() - def retrieve(self, retrival_method: 
str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict) -> list[Document]: + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: # Set search parameters. - results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 139bfe15f328d6..1dbc473281daf6 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -1,4 +1,5 @@ """Paragraph index processor.""" + import logging import re import threading @@ -23,34 +24,34 @@ class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: - - text_docs = ExtractProcessor.extract(extract_setting=extract_setting, - is_automatic=kwargs.get('process_rule_mode') == "automatic") + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: - splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'), - embedding_model_instance=kwargs.get('embedding_model_instance')) + splitter = self._get_splitter( + processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance") + ) # Split the text documents into nodes. 
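Both index processors stamp every non-empty split node with a fresh UUID and a content hash before indexing, then strip a leading separator character. A minimal sketch of that bookkeeping (the helper import path is assumed from this repo's conventions):

import uuid

from core.rag.models.document import Document
from libs import helper

def stamp_nodes(nodes: list[Document]) -> list[Document]:
    stamped = []
    for node in nodes:
        if not node.page_content.strip():
            continue  # drop empty splits
        node.metadata["doc_id"] = str(uuid.uuid4())
        node.metadata["doc_hash"] = helper.generate_text_hash(node.page_content)
        if node.page_content.startswith((".", "。")):
            # remove the splitter character left at the chunk boundary
            node.page_content = node.page_content[1:].strip()
        stamped.append(node)
    return stamped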
all_documents = [] all_qa_documents = [] for document in documents: # document clean - document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule')) + document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule")) document.page_content = document_text # parse document to nodes document_nodes = splitter.split_documents([document]) split_documents = [] for document_node in document_nodes: - if document_node.page_content.strip(): doc_id = str(uuid.uuid4()) hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata['doc_id'] = doc_id - document_node.metadata['doc_hash'] = hash - # delete Spliter character + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character page_content = document_node.page_content if page_content.startswith(".") or page_content.startswith("。"): page_content = page_content[1:] @@ -61,14 +62,18 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: all_documents.extend(split_documents) for i in range(0, len(all_documents), 10): threads = [] - sub_documents = all_documents[i:i + 10] + sub_documents = all_documents[i : i + 10] for doc in sub_documents: - document_format_thread = threading.Thread(target=self._format_qa_document, kwargs={ - 'flask_app': current_app._get_current_object(), - 'tenant_id': kwargs.get('tenant_id'), - 'document_node': doc, - 'all_qa_documents': all_qa_documents, - 'document_language': kwargs.get('doc_language', 'English')}) + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": kwargs.get("tenant_id"), + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": kwargs.get("doc_language", "English"), + }, + ) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -76,9 +81,8 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: return all_qa_documents def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: - # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -86,7 +90,7 @@ def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: df = pd.read_csv(file) text_docs = [] for index, row in df.iterrows(): - data = Document(page_content=row[0], metadata={'answer': row[1]}) + data = Document(page_content=row[0], metadata={"answer": row[1]}) text_docs.append(data) if len(text_docs) == 0: raise ValueError("The CSV file is empty.") @@ -96,7 +100,7 @@ def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: return text_docs def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) @@ -107,17 +111,29 @@ def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: else: vector.delete() - def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int, - score_threshold: float, reranking_model: dict): + def retrieve( + self, + retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ): # Set search parameters. 
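The QA transform fans its LLM calls out in fixed batches of ten threads and joins each batch before starting the next; a stripped-down sketch of that shape (format_one stands in for _format_qa_document and is assumed to append into results):

import threading

def run_in_batches(items: list, format_one, results: list, batch_size: int = 10) -> list:
    for i in range(0, len(items), batch_size):
        threads = [
            threading.Thread(target=format_one, args=(item, results))
            for item in items[i : i + batch_size]
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()  # wait for the whole batch before starting the next one
    return results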
- results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model) + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) # Organize results. docs = [] for result in results: metadata = result.metadata - metadata['score'] = result.score + metadata["score"] = result.score if result.score > score_threshold: doc = Document(page_content=result.page_content, metadata=metadata) docs.append(doc) @@ -134,12 +150,12 @@ def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, a document_qa_list = self._format_split_text(response) qa_documents = [] for result in document_qa_list: - qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy()) + qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy()) doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result['question']) - qa_document.metadata['answer'] = result['answer'] - qa_document.metadata['doc_id'] = doc_id - qa_document.metadata['doc_hash'] = hash + hash = helper.generate_text_hash(result["question"]) + qa_document.metadata["answer"] = result["answer"] + qa_document.metadata["doc_id"] = doc_id + qa_document.metadata["doc_hash"] = hash qa_documents.append(qa_document) format_documents.extend(qa_documents) except Exception as e: @@ -151,10 +167,4 @@ def _format_split_text(self, text): regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [ - { - "question": q, - "answer": re.sub(r"\n\s*", "\n", a.strip()) - } - for q, a in matches if q and a - ] + return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 6f3c1c5d343977..1e9aaa24f04c98 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -17,6 +17,8 @@ class Document(BaseModel): """ metadata: Optional[dict] = Field(default_factory=dict) + provider: Optional[str] = "dify" + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. @@ -55,9 +57,7 @@ async def atransform_documents( """ @abstractmethod - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform a list of documents. Args: @@ -68,9 +68,7 @@ def transform_documents( """ @abstractmethod - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a list of documents. 
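The new provider field on Document (default "dify") is what the rerank and retrieval changes below use to separate internal chunks from external knowledge-base hits; external documents carry no doc_id, so they are deduplicated by object identity instead. For example:

from core.rag.models.document import Document

internal = Document(page_content="chunk", metadata={"doc_id": "abc"})
external = Document(page_content="chunk", provider="external")

assert internal.provider == "dify"      # implicit default
assert external.provider == "external"  # external knowledge hit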
Args: diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/constants/rerank_mode.py deleted file mode 100644 index afbb9fd89d406d..00000000000000 --- a/api/core/rag/rerank/constants/rerank_mode.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import Enum - - -class RerankMode(Enum): - - RERANKING_MODEL = 'reranking_model' - WEIGHTED_SCORE = 'weighted_score' - diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py new file mode 100644 index 00000000000000..818b04b2ffc196 --- /dev/null +++ b/api/core/rag/rerank/rerank_base.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from core.rag.models.document import Document + + +class BaseRerankRunner(ABC): + @abstractmethod + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: + """ + Run rerank model + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ + raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_factory.py b/api/core/rag/rerank/rerank_factory.py new file mode 100644 index 00000000000000..1a3cf8573631f2 --- /dev/null +++ b/api/core/rag/rerank/rerank_factory.py @@ -0,0 +1,16 @@ +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.rerank.rerank_type import RerankMode +from core.rag.rerank.weight_rerank import WeightRerankRunner + + +class RerankRunnerFactory: + @staticmethod + def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner: + match runner_type: + case RerankMode.RERANKING_MODEL.value: + return RerankModelRunner(*args, **kwargs) + case RerankMode.WEIGHTED_SCORE.value: + return WeightRerankRunner(*args, **kwargs) + case _: + raise ValueError(f"Unknown runner type: {runner_type}") diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index d9067da2880fec..fc82b2080b2b3c 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -2,14 +2,21 @@ from core.model_manager import ModelInstance from core.rag.models.document import Document +from core.rag.rerank.rerank_base import BaseRerankRunner -class RerankModelRunner: +class RerankModelRunner(BaseRerankRunner): def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -20,22 +27,22 @@ def run(self, query: str, documents: list[Document], score_threshold: Optional[f :return: """ docs = [] - doc_id = [] + doc_id = set() unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.provider == "dify" and document.metadata["doc_id"] not in doc_id: + doc_id.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) + elif document.provider == 
"external": + if document not in unique_documents: + docs.append(document.page_content) + unique_documents.append(document) documents = unique_documents rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) rerank_documents = [] @@ -44,14 +51,10 @@ def run(self, query: str, documents: list[Document], score_threshold: Optional[f # format document rerank_document = Document( page_content=result.text, - metadata={ - "doc_id": documents[result.index].metadata['doc_id'], - "doc_hash": documents[result.index].metadata['doc_hash'], - "document_id": documents[result.index].metadata['document_id'], - "dataset_id": documents[result.index].metadata['dataset_id'], - 'score': result.score - } + metadata=documents[result.index].metadata, + provider=documents[result.index].provider, ) + rerank_document.metadata["score"] = result.score rerank_documents.append(rerank_document) return rerank_documents diff --git a/api/core/rag/rerank/rerank_type.py b/api/core/rag/rerank/rerank_type.py new file mode 100644 index 00000000000000..d71eb2daa8f920 --- /dev/null +++ b/api/core/rag/rerank/rerank_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class RerankMode(str, Enum): + RERANKING_MODEL = "reranking_model" + WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index d8a78739826a31..2e3fbe04e27452 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -4,22 +4,28 @@ import numpy as np -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner -class WeightRerankRunner: - +class WeightRerankRunner(BaseRerankRunner): def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights - def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: """ Run rerank model :param query: search query @@ -34,8 +40,8 @@ def run(self, query: str, documents: list[Document], score_threshold: Optional[f doc_id = [] unique_documents = [] for document in documents: - if document.metadata['doc_id'] not in doc_id: - doc_id.append(document.metadata['doc_id']) + if document.metadata["doc_id"] not in doc_id: + doc_id.append(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) @@ -47,13 +53,15 @@ def run(self, query: str, documents: list[Document], score_threshold: Optional[f query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting) for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores): # format document - score = 
self.weights.vector_setting.vector_weight * query_vector_score + \ - self.weights.keyword_setting.keyword_weight * query_score + score = ( + self.weights.vector_setting.vector_weight * query_vector_score + + self.weights.keyword_setting.keyword_weight * query_score + ) if score_threshold and score < score_threshold: continue - document.metadata['score'] = score + document.metadata["score"] = score rerank_documents.append(document) - rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True) + rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True) return rerank_documents[:top_n] if top_n else rerank_documents def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]: @@ -70,7 +78,7 @@ def _calculate_keyword_score(self, query: str, documents: list[Document]) -> lis for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -113,8 +121,8 @@ def cosine_similarity(vec1, vec2): intersection = set(vec1.keys()) & set(vec2.keys()) numerator = sum(vec1[x] * vec2[x] for x in intersection) - sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) - sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + sum1 = sum(vec1[x] ** 2 for x in vec1) + sum2 = sum(vec2[x] ** 2 for x in vec2) denominator = math.sqrt(sum1) * math.sqrt(sum2) if not denominator: @@ -132,8 +140,9 @@ def cosine_similarity(vec1, vec2): return similarities - def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document], - vector_setting: VectorSetting) -> list[float]: + def _calculate_cosine( + self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting + ) -> list[float]: """ Calculate Cosine scores :param query: search query @@ -149,15 +158,14 @@ def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document tenant_id=tenant_id, provider=vector_setting.embedding_provider_name, model_type=ModelType.TEXT_EMBEDDING, - model=vector_setting.embedding_model_name - + model=vector_setting.embedding_model_name, ) cache_embedding = CacheEmbedding(embedding_model) query_vector = cache_embedding.embed_query(query) for document in documents: # calculate cosine similarity - if 'score' in document.metadata: - query_vector_scores.append(document.metadata['score']) + if "score" in document.metadata: + query_vector_scores.append(document.metadata["score"]) else: # transform to NumPy vec1 = np.array(query_vector) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index e9453647969a97..7a5bf39fa63f48 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -20,8 +20,10 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.context_entities import DocumentContext from core.rag.models.document import Document -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.rerank.rerank_type import RerankMode +from core.rag.retrieval.retrieval_methods import RetrievalMethod from 
core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool @@ -30,16 +32,14 @@ from extensions.ext_database import db from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument +from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -48,15 +48,18 @@ def __init__(self, application_generate_entity=None): self.application_generate_entity = application_generate_entity def retrieve( - self, app_id: str, user_id: str, tenant_id: str, - model_config: ModelConfigWithCredentialsEntity, - config: DatasetEntity, - query: str, - invoke_from: InvokeFrom, - show_retrieve_source: bool, - hit_callback: DatasetIndexToolCallbackHandler, - message_id: str, - memory: Optional[TokenBufferMemory] = None, + self, + app_id: str, + user_id: str, + tenant_id: str, + model_config: ModelConfigWithCredentialsEntity, + config: DatasetEntity, + query: str, + invoke_from: InvokeFrom, + show_retrieve_source: bool, + hit_callback: DatasetIndexToolCallbackHandler, + message_id: str, + memory: Optional[TokenBufferMemory] = None, ) -> Optional[str]: """ Retrieve dataset. 
@@ -84,16 +87,12 @@ def retrieve( model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - provider=model_config.provider, - model=model_config.model + tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if not model_schema: @@ -102,39 +101,46 @@ def retrieve( planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: continue # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): + if dataset and dataset.available_document_count == 0 and dataset.provider != "external": continue available_datasets.append(dataset) all_documents = [] - user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user' + user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( - app_id, tenant_id, user_id, user_from, available_datasets, query, + app_id, + tenant_id, + user_id, + user_from, + available_datasets, + query, model_instance, - model_config, planning_strategy, message_id + model_config, + planning_strategy, + message_id, ) elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: all_documents = self.multiple_retrieve( - app_id, tenant_id, user_id, user_from, - available_datasets, query, retrieve_config.top_k, + app_id, + tenant_id, + user_id, + user_from, + available_datasets, + query, + retrieve_config.top_k, retrieve_config.score_threshold, retrieve_config.rerank_mode, retrieve_config.reranking_model, @@ -143,91 +149,118 @@ def retrieve( message_id, ) - document_score_list = {} - for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] - + dify_documents = [item for item in all_documents if item.provider == "dify"] + external_documents = [item for item in all_documents if item.provider == "external"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: 
index_node_id_to_position.get(segment.index_node_id, - float('inf'))) - for segment in sorted_segments: - if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') - else: - document_context_list.append(segment.get_sign_content()) - if show_retrieve_source: - context_list = [] - resource_number = 1 + retrieval_resource_list = [] + # deal with external documents + for item in external_documents: + document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) + source = { + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": invoke_from.to_source(), + "score": item.metadata.get("score"), + "content": item.page_content, + } + retrieval_resource_list.append(source) + document_score_list = {} + # deal with dify documents + if dify_documents: + for item in dify_documents: + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] + + index_node_ids = [document.metadata["doc_id"] for document in dify_documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id.in_(dataset_ids), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id - ).first() - document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() - if dataset and document: - source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': invoke_from.to_source(), - 'score': document_score_list.get(segment.index_node_id, None) - } - - if invoke_from.to_source() == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash - if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' - else: - source['content'] = segment.content - context_list.append(source) - resource_number += 1 - if hit_callback: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join(document_context_list)) - return '' + if segment.answer: + document_context_list.append( + DocumentContext( + content=f"question:{segment.get_sign_content()} answer:{segment.answer}", + score=document_score_list.get(segment.index_node_id, None), + ) + ) + else: + document_context_list.append( + DocumentContext( + content=segment.get_sign_content(), + score=document_score_list.get(segment.index_node_id, None), + ) + ) + if show_retrieve_source: + for segment in sorted_segments: + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = DatasetDocument.query.filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + 
DatasetDocument.archived == False, + ).first() + if dataset and document: + source = { + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": invoke_from.to_source(), + "score": document_score_list.get(segment.index_node_id, 0.0), + } + + if invoke_from.to_source() == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + retrieval_resource_list.append(source) + if hit_callback and retrieval_resource_list: + retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) + for position, item in enumerate(retrieval_resource_list, start=1): + item["position"] = position + hit_callback.return_retriever_resource_info(retrieval_resource_list) + if document_context_list: + document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) + return str("\n".join([document_context.content for document_context in document_context_list])) + return "" def single_retrieve( - self, app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - model_instance: ModelInstance, - model_config: ModelConfigWithCredentialsEntity, - planning_strategy: PlanningStrategy, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + model_instance: ModelInstance, + model_config: ModelConfigWithCredentialsEntity, + planning_strategy: PlanningStrategy, + message_id: Optional[str] = None, ): tools = [] for dataset in available_datasets: description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") message_tool = PromptMessageTool( name=dataset.id, description=description, @@ -235,14 +268,15 @@ def single_retrieve( "type": "object", "properties": {}, "required": [], - } + }, ) tools.append(message_tool) dataset_id = None if planning_strategy == PlanningStrategy.REACT_ROUTER: react_multi_dataset_router = ReactMultiDatasetRouter() - dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance, - user_id, tenant_id) + dataset_id = react_multi_dataset_router.invoke( + query, tools, model_config, model_instance, user_id, tenant_id + ) elif planning_strategy == PlanningStrategy.ROUTER: function_call_router = FunctionCallMultiDatasetRouter() @@ -250,75 +284,129 @@ def single_retrieve( if dataset_id: # get retrieval model config - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model - - # get top k - top_k = retrieval_model_config['top_k'] - # get retrieval method - if dataset.indexing_technique == "economy": - 
retrival_method = 'keyword_search' - else: - retrival_method = retrieval_model_config['search_method'] - # get reranking model - reranking_model = retrieval_model_config['reranking_model'] \ - if retrieval_model_config['reranking_enable'] else None - # get score threshold - score_threshold = .0 - score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") - if score_threshold_enabled: - score_threshold = retrieval_model_config.get("score_threshold") - - with measure_time() as timer: - results = RetrievalService.retrieve( - retrival_method=retrival_method, dataset_id=dataset.id, + results = [] + if dataset.provider == "external": + external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, query=query, - top_k=top_k, score_threshold=score_threshold, - reranking_model=reranking_model, - reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'), - weights=retrieval_model_config.get('weights', None), + external_retrieval_parameters=dataset.retrieval_model, ) + for external_document in external_documents: + document = Document( + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", + ) + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name + results.append(document) + else: + retrieval_model_config = dataset.retrieval_model or default_retrieval_model + + # get top k + top_k = retrieval_model_config["top_k"] + # get retrieval method + if dataset.indexing_technique == "economy": + retrieval_method = "keyword_search" + else: + retrieval_method = retrieval_model_config["search_method"] + # get reranking model + reranking_model = ( + retrieval_model_config["reranking_model"] + if retrieval_model_config["reranking_enable"] + else None + ) + # get score threshold + score_threshold = 0.0 + score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") + if score_threshold_enabled: + score_threshold = retrieval_model_config.get("score_threshold") + + with measure_time() as timer: + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"), + weights=retrieval_model_config.get("weights", None), + ) self._on_query(query, [dataset_id], app_id, user_from, user_id) if results: - self._on_retrival_end(results, message_id, timer) + self._on_retrieval_end(results, message_id, timer) return results return [] def multiple_retrieve( - self, - app_id: str, - tenant_id: str, - user_id: str, - user_from: str, - available_datasets: list, - query: str, - top_k: int, - score_threshold: float, - reranking_mode: str, - reranking_model: Optional[dict] = None, - weights: Optional[dict] = None, - reranking_enable: bool = True, - message_id: Optional[str] = None, + self, + app_id: str, + tenant_id: str, + user_id: str, + user_from: str, + available_datasets: list, + query: str, + top_k: int, + score_threshold: float, + reranking_mode: str, + reranking_model: Optional[dict] = None, + weights: Optional[dict] = None, + reranking_enable: bool = True, + message_id: Optional[str] = None, ): + if not available_datasets: + return [] threads = [] 
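The weighted_score mode that multiple_retrieve validates just below combines two signals that live on different scales. A self-contained sketch of the combination used by WeightRerankRunner.run above; ScoredDoc and weighted_rerank are hypothetical names for illustration, not Dify classes:

```python
from typing import Optional


class ScoredDoc:
    """Hypothetical stand-in for a retrieved Document carrying both signals."""

    def __init__(self, content: str, keyword_score: float, vector_score: float):
        self.content = content
        self.keyword_score = keyword_score  # TF-based keyword similarity
        self.vector_score = vector_score  # embedding cosine similarity


def weighted_rerank(
    docs: list[ScoredDoc],
    vector_weight: float,
    keyword_weight: float,
    score_threshold: Optional[float] = None,
    top_n: Optional[int] = None,
) -> list[ScoredDoc]:
    scored = []
    for doc in docs:
        # Same combination as WeightRerankRunner.run:
        # score = vector_weight * vector_score + keyword_weight * keyword_score
        score = vector_weight * doc.vector_score + keyword_weight * doc.keyword_score
        if score_threshold and score < score_threshold:
            continue  # mirrors the runner: a falsy threshold keeps every document
        scored.append((score, doc))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    ranked = [doc for _, doc in scored]
    return ranked[:top_n] if top_n else ranked
```

This is also why the checks below refuse datasets with mixed embedding models under weighted_score: cosine scores from different embedding spaces are not comparable, so the weighted sum would be meaningless.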
all_documents = [] dataset_ids = [dataset.id for dataset in available_datasets] - index_type = None + index_type_check = all( + item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets + ) + if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL): + raise ValueError( + "The configured knowledge base list have different indexing technique, please set reranking model." + ) + index_type = available_datasets[0].indexing_technique + if index_type == "high_quality": + embedding_model_check = all( + item.embedding_model == available_datasets[0].embedding_model for item in available_datasets + ) + embedding_model_provider_check = all( + item.embedding_model_provider == available_datasets[0].embedding_model_provider + for item in available_datasets + ) + if ( + reranking_enable + and reranking_mode == "weighted_score" + and (not embedding_model_check or not embedding_model_provider_check) + ): + raise ValueError( + "The configured knowledge base list have different embedding model, please set reranking model." + ) + if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE: + weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider + weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model + for dataset in available_datasets: index_type = dataset.indexing_technique - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset.id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -327,16 +415,10 @@ def multiple_retrieve( with measure_time() as timer: if reranking_enable: # do rerank for searched documents - data_post_processor = DataPostProcessor( - tenant_id, reranking_mode, - reranking_model, weights, False - ) + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) else: if index_type == "economy": @@ -347,40 +429,37 @@ def multiple_retrieve( self._on_query(query, dataset_ids, app_id, user_from, user_id) if all_documents: - self._on_retrival_end(all_documents, message_id, timer) + self._on_retrieval_end(all_documents, message_id, timer) return all_documents - def _on_retrival_end( + def _on_retrieval_end( self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None ) -> None: - """Handle retrival end.""" - for document in documents: + """Handle retrieval end.""" + dify_documents = [document for document in documents if document.provider == "dify"] + for document in dify_documents: query = db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == document.metadata['doc_id'] + DocumentSegment.index_node_id == document.metadata["doc_id"] ) # if 'dataset_id' in document.metadata: - if 'dataset_id' in document.metadata: - query = query.filter(DocumentSegment.dataset_id == 
document.metadata['dataset_id']) + if "dataset_id" in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False - ) + query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) db.session.commit() # get tracing instance - trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None + trace_manager: TraceQueueManager = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) if trace_manager: trace_manager.add_trace_task( TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, - message_id=message_id, - documents=documents, - timer=timer + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer ) ) @@ -395,10 +474,10 @@ def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source='app', + source="app", source_app_id=app_id, created_by_role=user_from, - created_by=user_id + created_by=user_id, ) dataset_queries.append(dataset_query) if dataset_queries: @@ -407,50 +486,69 @@ def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve(retrival_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=top_k - ) - if documents: - all_documents.extend(documents) + if dataset.provider == "external": + external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, + query=query, + external_retrieval_parameters=dataset.retrieval_model, + ) + for external_document in external_documents: + document = Document( + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", + ) + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset_id + document.metadata["dataset_name"] = dataset.name + all_documents.append(document) else: - if top_k > 0: - # retrieval source - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) - - all_documents.extend(documents) - - def 
to_dataset_retriever_tool(self, tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[list[DatasetRetrieverBaseTool]]: + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model or default_retrieval_model + + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k + ) + if documents: + all_documents.extend(documents) + else: + if top_k > 0: + # retrieval source + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) + + all_documents.extend(documents) + + def to_dataset_retriever_tool( + self, + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> Optional[list[DatasetRetrieverBaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id @@ -464,18 +562,14 @@ def to_dataset_retriever_tool(self, tenant_id: str, available_datasets = [] for dataset_id in dataset_ids: # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() # pass if dataset is not available if not dataset: continue # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): + if dataset and dataset.provider != "external" and dataset.available_document_count == 0: continue available_datasets.append(dataset) @@ -483,22 +577,18 @@ def to_dataset_retriever_tool(self, tenant_id: str, if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } for dataset in available_datasets: - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model + retrieval_model_config = dataset.retrieval_model or default_retrieval_model # get top k - top_k = retrieval_model_config['top_k'] + top_k = retrieval_model_config["top_k"] # get score threshold score_threshold = None @@ -512,7 +602,7 @@ def to_dataset_retriever_tool(self, tenant_id: str, score_threshold=score_threshold, 
hit_callbacks=[hit_callback], return_resource=return_resource, - retriever_from=invoke_from.to_source() + retriever_from=invoke_from.to_source(), ) tools.append(tool) @@ -525,8 +615,8 @@ def to_dataset_retriever_tool(self, tenant_id: str, hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), - reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') + reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), + reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), ) tools.append(tool) @@ -547,7 +637,7 @@ def calculate_keyword_score(self, query: str, documents: list[Document], top_k: for document in documents: # get the document keywords document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) - document.metadata['keywords'] = document_keywords + document.metadata["keywords"] = document_keywords documents_keywords.append(document_keywords) # Counter query keywords(TF) @@ -590,8 +680,8 @@ def cosine_similarity(vec1, vec2): intersection = set(vec1.keys()) & set(vec2.keys()) numerator = sum(vec1[x] * vec2[x] for x in intersection) - sum1 = sum(vec1[x] ** 2 for x in vec1.keys()) - sum2 = sum(vec2[x] ** 2 for x in vec2.keys()) + sum1 = sum(vec1[x] ** 2 for x in vec1) + sum2 = sum(vec2[x] ** 2 for x in vec2) denominator = math.sqrt(sum1) * math.sqrt(sum2) if not denominator: @@ -606,20 +696,19 @@ def cosine_similarity(vec1, vec2): for document, score in zip(documents, similarities): # format document - document.metadata['score'] = score - documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True) + document.metadata["score"] = score + documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True) return documents[:top_k] if top_k else documents - def calculate_vector_score(self, all_documents: list[Document], - top_k: int, score_threshold: float) -> list[Document]: + def calculate_vector_score( + self, all_documents: list[Document], top_k: int, score_threshold: float + ) -> list[Document]: filter_documents = [] for document in all_documents: - if score_threshold and document.metadata['score'] >= score_threshold: + if score_threshold is None or document.metadata["score"] >= score_threshold: filter_documents.append(document) + if not filter_documents: return [] - filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True) + filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True) return filter_documents[:top_k] if top_k else filter_documents - - - diff --git a/api/core/rag/retrieval/output_parser/structured_chat.py b/api/core/rag/retrieval/output_parser/structured_chat.py index 60770bd4c6e06a..7fc78bce8357da 100644 --- a/api/core/rag/retrieval/output_parser/structured_chat.py +++ b/api/core/rag/retrieval/output_parser/structured_chat.py @@ -16,9 +16,7 @@ def parse(self, text: str) -> Union[ReactAction, ReactFinish]: if response["action"] == "Final Answer": return ReactFinish({"output": response["action_input"]}, text) else: - return ReactAction( - response["action"], response.get("action_input", {}), text - ) + return ReactAction(response["action"], response.get("action_input", {}), text) else: return ReactFinish({"output": text}, text) except Exception as e: diff --git a/api/core/rag/retrieval/retrieval_methods.py 
b/api/core/rag/retrieval/retrieval_methods.py new file mode 100644 index 00000000000000..eaa00bca884a7c --- /dev/null +++ b/api/core/rag/retrieval/retrieval_methods.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class RetrievalMethod(Enum): + SEMANTIC_SEARCH = "semantic_search" + FULL_TEXT_SEARCH = "full_text_search" + HYBRID_SEARCH = "hybrid_search" + + @staticmethod + def is_support_semantic_search(retrieval_method: str) -> bool: + return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} + + @staticmethod + def is_support_fulltext_search(retrieval_method: str) -> bool: + return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} diff --git a/api/core/rag/retrieval/retrival_methods.py b/api/core/rag/retrieval/retrival_methods.py deleted file mode 100644 index 12aa28a51c1d98..00000000000000 --- a/api/core/rag/retrieval/retrival_methods.py +++ /dev/null @@ -1,15 +0,0 @@ -from enum import Enum - - -class RetrievalMethod(Enum): - SEMANTIC_SEARCH = 'semantic_search' - FULL_TEXT_SEARCH = 'full_text_search' - HYBRID_SEARCH = 'hybrid_search' - - @staticmethod - def is_support_semantic_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} - - @staticmethod - def is_support_fulltext_search(retrieval_method: str) -> bool: - return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value} diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 84e53952acbf12..06147fe7b56544 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -6,14 +6,12 @@ class FunctionCallMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, ) -> Union[str, None]: """Given input, decided what to do. 
Returns: @@ -26,22 +24,18 @@ def invoke( try: prompt_messages = [ - SystemPromptMessage(content='You are a helpful AI assistant.'), - UserPromptMessage(content=query) + SystemPromptMessage(content="You are a helpful AI assistant."), + UserPromptMessage(content=query), ] result = model_instance.invoke_llm( prompt_messages=prompt_messages, tools=dataset_tools, stream=False, - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } + model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500}, ) if result.message.tool_calls: # get retrieval model config return result.message.tool_calls[0].function.name return None except Exception as e: - return None \ No newline at end of file + return None diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index 92f24277c1a3cc..68fab0c127a253 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -9,12 +9,12 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.llm import LLMNode PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" SUFFIX = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. -Thought:""" +Thought:""" # noqa: E501 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. @@ -46,20 +46,18 @@ "action": "Final Answer", "action_input": "Final response to human" }} -```""" +```""" # noqa: E501 class ReactMultiDatasetRouter: - def invoke( - self, - query: str, - dataset_tools: list[PromptMessageTool], - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - user_id: str, - tenant_id: str - + self, + query: str, + dataset_tools: list[PromptMessageTool], + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + user_id: str, + tenant_id: str, ) -> Union[str, None]: """Given input, decided what to do. 
Returns: @@ -71,23 +69,28 @@ def invoke( return dataset_tools[0].name try: - return self._react_invoke(query=query, model_config=model_config, - model_instance=model_instance, - tools=dataset_tools, user_id=user_id, tenant_id=tenant_id) + return self._react_invoke( + query=query, + model_config=model_config, + model_instance=model_instance, + tools=dataset_tools, + user_id=user_id, + tenant_id=tenant_id, + ) except Exception as e: return None def _react_invoke( - self, - query: str, - model_config: ModelConfigWithCredentialsEntity, - model_instance: ModelInstance, - tools: Sequence[PromptMessageTool], - user_id: str, - tenant_id: str, - prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + tools: Sequence[PromptMessageTool], + user_id: str, + tenant_id: str, + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> Union[str, None]: if model_config.mode == "chat": prompt = self.create_chat_prompt( @@ -103,18 +106,18 @@ def _react_invoke( prefix=prefix, format_instructions=format_instructions, ) - stop = ['Observation:'] + stop = ["Observation:"] # handle invoke result prompt_transform = AdvancedPromptTransform() prompt_messages = prompt_transform.get_prompt( prompt_template=prompt, inputs={}, - query='', + query="", files=[], - context='', + context="", memory_config=None, memory=None, - model_config=model_config + model_config=model_config, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -122,7 +125,7 @@ def _react_invoke( prompt_messages=prompt_messages, stop=stop, user_id=user_id, - tenant_id=tenant_id + tenant_id=tenant_id, ) output_parser = StructuredChatOutputParser() react_decision = output_parser.parse(result_text) @@ -130,17 +133,21 @@ def _react_invoke( return react_decision.tool return None - def _invoke_llm(self, completion_param: dict, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: list[str], user_id: str, tenant_id: str - ) -> tuple[str, LLMUsage]: + def _invoke_llm( + self, + completion_param: dict, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: list[str], + user_id: str, + tenant_id: str, + ) -> tuple[str, LLMUsage]: """ - Invoke large language model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: + Invoke large language model + :param model_instance: model instance + :param prompt_messages: prompt messages + :param stop: stop + :return: """ invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, @@ -151,9 +158,7 @@ def _invoke_llm(self, completion_param: dict, ) # handle invoke result - text, usage = self._handle_invoke_result( - invoke_result=invoke_result - ) + text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) @@ -168,7 +173,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage """ model = None prompt_messages = [] - full_text = '' + full_text = "" usage = None for result in invoke_result: text = result.delta.message.content @@ -189,40 +194,36 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage return full_text, usage def create_chat_prompt( - self, - query: str, - tools: Sequence[PromptMessageTool], - 
prefix: str = PREFIX, - suffix: str = SUFFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + query: str, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + suffix: str = SUFFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> list[ChatModelMessage]: tool_strings = [] for tool in tools: tool_strings.append( - f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query', 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}") + f"{tool.name}: {tool.description}, args: {{'query': {{'title': 'Query'," + f" 'description': 'Query for the dataset to be used to retrieve the dataset.', 'type': 'string'}}}}" + ) formatted_tools = "\n".join(tool_strings) unique_tool_names = {tool.name for tool in tools} tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) format_instructions = format_instructions.format(tool_names=tool_names) template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) prompt_messages = [] - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=template - ) + system_prompt_messages = ChatModelMessage(role=PromptMessageRole.SYSTEM, text=template) prompt_messages.append(system_prompt_messages) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=query - ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=query) prompt_messages.append(user_prompt_message) return prompt_messages def create_completion_prompt( - self, - tools: Sequence[PromptMessageTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, + self, + tools: Sequence[PromptMessageTool], + prefix: str = PREFIX, + format_instructions: str = FORMAT_INSTRUCTIONS, ) -> CompletionModelPromptTemplate: """Create prompt in the style of the zero shot agent. @@ -236,7 +237,7 @@ def create_completion_prompt( suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
Question: {input} Thought: {agent_scratchpad} -""" +""" # noqa: E501 tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) tool_names = ", ".join([tool.name for tool in tools]) diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 6a0804f890db39..53032b34d570c7 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -1,4 +1,5 @@ """Functionality for splitting text.""" + from __future__ import annotations from typing import Any, Optional @@ -18,31 +19,29 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ - This class is used to implement from_gpt2_encoder, to prevent using of tiktoken + This class is used to implement from_gpt2_encoder, to prevent using of tiktoken """ @classmethod def from_encoder( - cls: type[TS], - embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal[all], Set[str]] = set(), - disallowed_special: Union[Literal[all], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + embedding_model_instance: Optional[ModelInstance], + allowed_special: Union[Literal[all], Set[str]] = set(), + disallowed_special: Union[Literal[all], Collection[str]] = "all", + **kwargs: Any, ): def _token_encoder(text: str) -> int: if not text: return 0 if embedding_model_instance: - return embedding_model_instance.get_text_embedding_num_tokens( - texts=[text] - ) + return embedding_model_instance.get_text_embedding_num_tokens(texts=[text]) else: return GPT2Tokenizer.get_num_tokens(text) if issubclass(cls, TokenTextSplitter): extra_kwargs = { - "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2', + "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", "allowed_special": allowed_special, "disallowed_special": disallowed_special, } @@ -93,17 +92,21 @@ def recursive_split_text(self, text: str) -> list[str]: splits = list(text) # Now go merging things, recursively splitting longer texts. _good_splits = [] + _good_splits_lengths = [] # cache the lengths of the splits for s in splits: - if self._length_function(s) < self._chunk_size: + s_len = self._length_function(s) + if s_len < self._chunk_size: _good_splits.append(s) + _good_splits_lengths.append(s_len) else: if _good_splits: - merged_text = self._merge_splits(_good_splits, separator) + merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths) final_chunks.extend(merged_text) _good_splits = [] + _good_splits_lengths = [] other_info = self.recursive_split_text(s) final_chunks.extend(other_info) if _good_splits: - merged_text = self._merge_splits(_good_splits, separator) + merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths) final_chunks.extend(merged_text) return final_chunks diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index b3adcedc76c9f9..7dd62f8de18a15 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -22,35 +22,32 @@ TS = TypeVar("TS", bound="TextSplitter") -def _split_text_with_regex( - text: str, separator: str, keep_separator: bool -) -> list[str]: +def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> list[str]: # Now that we have the separator, split the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. 
_splits = re.split(f"({re.escape(separator)})", text) - splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] - if len(_splits) % 2 == 0: + splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)] + if len(_splits) % 2 != 0: splits += _splits[-1:] - splits = [_splits[0]] + splits else: splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if s != ""] + return [s for s in splits if (s not in {"", "\n"})] class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" def __init__( - self, - chunk_size: int = 4000, - chunk_overlap: int = 200, - length_function: Callable[[str], int] = len, - keep_separator: bool = False, - add_start_index: bool = False, + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, ) -> None: """Create a new TextSplitter. @@ -63,8 +60,7 @@ def __init__( """ if chunk_overlap > chunk_size: raise ValueError( - f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " - f"({chunk_size}), should be smaller." + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size ({chunk_size}), should be smaller." ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap @@ -76,9 +72,7 @@ def __init__( def split_text(self, text: str) -> list[str]: """Split text into multiple components.""" - def create_documents( - self, texts: list[str], metadatas: Optional[list[dict]] = None - ) -> list[Document]: + def create_documents(self, texts: list[str], metadatas: Optional[list[dict]] = None) -> list[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] @@ -109,7 +103,7 @@ def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: else: return text - def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]: + def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. 
separator_len = self._length_function(separator) @@ -117,16 +111,13 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]: docs = [] current_doc: list[str] = [] total = 0 + index = 0 for d in splits: - _len = self._length_function(d) - if ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - ): + _len = lengths[index] + if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: if total > self._chunk_size: logger.warning( - f"Created a chunk of size {total}, " - f"which is longer than the specified {self._chunk_size}" + f"Created a chunk of size {total}, which is longer than the specified {self._chunk_size}" ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) @@ -136,16 +127,13 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]: # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( - total + _len + (separator_len if len(current_doc) > 0 else 0) - > self._chunk_size - and total > 0 + total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 ): - total -= self._length_function(current_doc[0]) + ( - separator_len if len(current_doc) > 1 else 0 - ) + total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0) current_doc = current_doc[1:] current_doc.append(d) total += _len + (separator_len if len(current_doc) > 1 else 0) + index += 1 doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) @@ -158,28 +146,25 @@ def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitt from transformers import PreTrainedTokenizerBase if not isinstance(tokenizer, PreTrainedTokenizerBase): - raise ValueError( - "Tokenizer received was not an instance of PreTrainedTokenizerBase" - ) + raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase") def _huggingface_tokenizer_length(text: str) -> int: return len(tokenizer.encode(text)) except ImportError: raise ValueError( - "Could not import transformers python package. " - "Please install it with `pip install transformers`." + "Could not import transformers python package. Please install it with `pip install transformers`." 
) return cls(length_function=_huggingface_tokenizer_length, **kwargs) @classmethod def from_tiktoken_encoder( - cls: type[TS], - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + cls: type[TS], + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> TS: """Text splitter that uses tiktoken encoder to count length.""" try: @@ -216,15 +201,11 @@ def _tiktoken_encoder(text: str) -> int: return cls(length_function=_tiktoken_encoder, **kwargs) - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: + async def atransform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Asynchronously transform a sequence of documents by splitting them.""" raise NotImplementedError @@ -242,7 +223,10 @@ def split_text(self, text: str) -> list[str]: # First we naively split the large input into a bunch of smaller ones. splits = _split_text_with_regex(text, self._separator, self._keep_separator) _separator = "" if self._keep_separator else self._separator - return self._merge_splits(splits, _separator) + _good_splits_lengths = [] # cache the lengths of the splits + for split in splits: + _good_splits_lengths.append(self._length_function(split)) + return self._merge_splits(splits, _separator, _good_splits_lengths) class LineType(TypedDict): @@ -263,9 +247,7 @@ class HeaderType(TypedDict): class MarkdownHeaderTextSplitter: """Splitting markdown files based on specified headers.""" - def __init__( - self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False - ): + def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): """Create a new MarkdownHeaderTextSplitter. 
Args: @@ -276,9 +258,7 @@ def __init__( self.return_each_line = return_each_line # Given the headers we want to split on, # (e.g., "#, ##, etc") order by length - self.headers_to_split_on = sorted( - headers_to_split_on, key=lambda split: len(split[0]), reverse=True - ) + self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: """Combine lines with common metadata into chunks @@ -288,10 +268,7 @@ def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: aggregated_chunks: list[LineType] = [] for line in lines: - if ( - aggregated_chunks - and aggregated_chunks[-1]["metadata"] == line["metadata"] - ): + if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: # If the last line in the aggregated list # has the same metadata as the current line, # append the current content to the last lines's content @@ -300,10 +277,7 @@ def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: # Otherwise, append the current line to the aggregated list aggregated_chunks.append(line) - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for chunk in aggregated_chunks - ] + return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] def split_text(self, text: str) -> list[Document]: """Split markdown file @@ -328,10 +302,9 @@ def split_text(self, text: str) -> list[Document]: for sep, name in self.headers_to_split_on: # Check if line starts with a header that we intend to split on if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) - or stripped_line[len(sep)] == " " + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " ): # Ensure we are tracking the header as metadata if name is not None: @@ -339,10 +312,7 @@ def split_text(self, text: str) -> list[Document]: current_header_level = sep.count("#") # Pop out headers of lower or same level from the stack - while ( - header_stack - and header_stack[-1]["level"] >= current_header_level - ): + while header_stack and header_stack[-1]["level"] >= current_header_level: # We have encountered a new header # at the same or higher level popped_header = header_stack.pop() @@ -355,7 +325,7 @@ def split_text(self, text: str) -> list[Document]: header: HeaderType = { "level": current_header_level, "name": name, - "data": stripped_line[len(sep):].strip(), + "data": stripped_line[len(sep) :].strip(), } header_stack.append(header) # Update initial_metadata with the current header @@ -388,9 +358,7 @@ def split_text(self, text: str) -> list[Document]: current_metadata = initial_metadata.copy() if current_content: - lines_with_metadata.append( - {"content": "\n".join(current_content), "metadata": current_metadata} - ) + lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) # lines_with_metadata has each line with associated header metadata # aggregate these into chunks based on common metadata @@ -398,8 +366,7 @@ def split_text(self, text: str) -> list[Document]: return self.aggregate_lines_to_chunks(lines_with_metadata) else: return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) - for 
chunk in lines_with_metadata + Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata ] @@ -432,12 +399,12 @@ class TokenTextSplitter(TextSplitter): """Splitting text to tokens using model tokenizer.""" def __init__( - self, - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, + self, + encoding_name: str = "gpt2", + model_name: Optional[str] = None, + allowed_special: Union[Literal["all"], Set[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) @@ -484,50 +451,55 @@ class RecursiveCharacterTextSplitter(TextSplitter): """ def __init__( - self, - separators: Optional[list[str]] = None, - keep_separator: bool = True, - **kwargs: Any, + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] def _split_text(self, text: str, separators: list[str]) -> list[str]: - """Split incoming text and return chunks.""" final_chunks = [] - # Get appropriate separator to use separator = separators[-1] new_separators = [] + for i, _s in enumerate(separators): if _s == "": separator = _s break if re.search(_s, text): separator = _s - new_separators = separators[i + 1:] + new_separators = separators[i + 1 :] break splits = _split_text_with_regex(text, separator, self._keep_separator) - # Now go merging things, recursively splitting longer texts. _good_splits = [] + _good_splits_lengths = [] # cache the lengths of the splits _separator = "" if self._keep_separator else separator + for s in splits: - if self._length_function(s) < self._chunk_size: + s_len = self._length_function(s) + if s_len < self._chunk_size: _good_splits.append(s) + _good_splits_lengths.append(s_len) else: if _good_splits: - merged_text = self._merge_splits(_good_splits, _separator) + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) _good_splits = [] + _good_splits_lengths = [] if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) final_chunks.extend(other_info) + if _good_splits: - merged_text = self._merge_splits(_good_splits, _separator) + merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) + return final_chunks def split_text(self, text: str) -> list[str]: diff --git a/api/core/tools/README.md b/api/core/tools/README.md index c7ee81422efbd1..b5d0a30d348a9d 100644 --- a/api/core/tools/README.md +++ b/api/core/tools/README.md @@ -9,10 +9,10 @@ The tools provided for Agents and Workflows are currently divided into two categ - `Api-Based Tools` leverage third-party APIs for implementation. You don't need to code to integrate these -- simply provide interface definitions in formats like `OpenAPI` , `Swagger`, or the `OpenAI-plugin` on the front-end. 
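To make the two categories concrete, here is a minimal sketch of the built-in side. The base class, import paths, and `_invoke` signature follow the quick-integration guide referenced below, but treat the tool itself (EchoTool) as a hypothetical example, not a shipped provider:

```python
from typing import Any, Union

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class EchoTool(BuiltinTool):
    """Hypothetical built-in tool: returns its input back as a text message."""

    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        query = tool_parameters.get("query", "")
        # create_text_message is one of the Tool helper methods documented
        # in docs/ja_JP/advanced_scale_out.md further down.
        return self.create_text_message(f"echo: {query}")
```

An API-based tool needs none of this: its provider is generated from the uploaded `OpenAPI`/`Swagger` definition.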
### Built-in Tool Providers -![Alt text](docs/zh_Hans/images/index/image.png) +![Alt text](docs/images/index/image.png) ### API Tool Providers -![Alt text](docs/zh_Hans/images/index/image-1.png) +![Alt text](docs/images/index/image-1.png) ## Tool Integration diff --git a/api/core/tools/README_CN.md b/api/core/tools/README_CN.md index fda5d0630ca8cf..7e18441131fd75 100644 --- a/api/core/tools/README_CN.md +++ b/api/core/tools/README_CN.md @@ -12,10 +12,10 @@ - `Api-Based Tools` 基于API的工具,即通过调用第三方API实现的工具,`Api-Based Tool`不需要再额外定义,只需提供`OpenAPI` `Swagger` `OpenAI plugin`等接口文档即可。 ### 内置工具供应商 -![Alt text](docs/zh_Hans/images/index/image.png) +![Alt text](docs/images/index/image.png) ### API工具供应商 -![Alt text](docs/zh_Hans/images/index/image-1.png) +![Alt text](docs/images/index/image-1.png) ## 工具接入 为了实现更灵活更强大的功能,Tools提供了一系列的接口,帮助开发者快速构建想要的工具,本文作为开发者的入门指南,将会以[快速接入](./docs/zh_Hans/tool_scale_out.md)和[高级接入](./docs/zh_Hans/advanced_scale_out.md)两部分介绍如何接入工具。 diff --git a/api/core/tools/README_JP.md b/api/core/tools/README_JP.md new file mode 100644 index 00000000000000..39d0bf1762ad0e --- /dev/null +++ b/api/core/tools/README_JP.md @@ -0,0 +1,31 @@ +# Tools + +このモジュールは、Difyのエージェントアシスタントやワークフローで使用される組み込みツールを実装しています。このモジュールでは、フロントエンドのロジックを変更することなく、独自のツールを定義し表示することができます。この分離により、Difyの機能を容易に水平方向にスケールアウトできます。 + +## 機能紹介 + +エージェントとワークフロー向けに提供されるツールは、現在2つのカテゴリーに分類されています。 + +- `Built-in Tools`はDify内部で実装され、エージェントとワークフローで使用するためにハードコードされています。 +- `Api-Based Tools`はサードパーティのAPIを利用して実装されています。これらを統合するためのコーディングは不要で、フロントエンドで + `OpenAPI`, `Swagger`または`OpenAI-plugin`などの形式でインターフェース定義を提供するだけです。 + +### 組み込みツールプロバイダー + +![Alt text](docs/images/index/image.png) + +### APIツールプロバイダー + +![Alt text](docs/images/index/image-1.png) + +## ツールの統合 + +開発者が柔軟で強力なツールを構築できるよう、2つのガイドを提供しています。 + +### [クイック統合 👈🏻](./docs/ja_JP/tool_scale_out.md) + +クイック統合は、Google検索ツールの例を通じて、ツール統合の基本をすばやく理解できるようにすることを目的としています。 + +### [高度な統合 👈🏻](./docs/ja_JP/advanced_scale_out.md) + +高度な統合では、モジュールインターフェースについてより深く掘り下げ、画像生成、複数ツールの組み合わせ、異なるツール間でのパラメーター、画像、ファイルのフロー管理など、より複雑な機能の実装方法を説明します。 \ No newline at end of file diff --git a/api/core/tools/docs/en_US/tool_scale_out.md b/api/core/tools/docs/en_US/tool_scale_out.md index 121b7a5a76d221..1deaf04a47539b 100644 --- a/api/core/tools/docs/en_US/tool_scale_out.md +++ b/api/core/tools/docs/en_US/tool_scale_out.md @@ -245,4 +245,4 @@ After the above steps are completed, we can see this tool on the frontend, and i Of course, because google_search needs a credential, before using it, you also need to input your credentials on the frontend. 
-![Alt text](../zh_Hans/images/index/image-2.png) +![Alt text](../images/index/image-2.png) diff --git a/api/core/tools/docs/zh_Hans/images/index/image-1.png b/api/core/tools/docs/images/index/image-1.png similarity index 100% rename from api/core/tools/docs/zh_Hans/images/index/image-1.png rename to api/core/tools/docs/images/index/image-1.png diff --git a/api/core/tools/docs/zh_Hans/images/index/image-2.png b/api/core/tools/docs/images/index/image-2.png similarity index 100% rename from api/core/tools/docs/zh_Hans/images/index/image-2.png rename to api/core/tools/docs/images/index/image-2.png diff --git a/api/core/tools/docs/zh_Hans/images/index/image.png b/api/core/tools/docs/images/index/image.png similarity index 100% rename from api/core/tools/docs/zh_Hans/images/index/image.png rename to api/core/tools/docs/images/index/image.png diff --git a/api/core/tools/docs/ja_JP/advanced_scale_out.md b/api/core/tools/docs/ja_JP/advanced_scale_out.md new file mode 100644 index 00000000000000..96f843354f91b5 --- /dev/null +++ b/api/core/tools/docs/ja_JP/advanced_scale_out.md @@ -0,0 +1,283 @@ +# 高度なツール統合 + +このガイドを始める前に、Difyのツール統合プロセスの基本を理解していることを確認してください。簡単な概要については[クイック統合](./tool_scale_out.md)をご覧ください。 + +## ツールインターフェース + +より複雑なツールを迅速に構築するのを支援するため、`Tool`クラスに一連のヘルパーメソッドを定義しています。 + +### メッセージの返却 + +Difyは`テキスト`、`リンク`、`画像`、`ファイルBLOB`、`JSON`などの様々なメッセージタイプをサポートしています。以下のインターフェースを通じて、異なるタイプのメッセージをLLMとユーザーに返すことができます。 + +注意:以下のインターフェースの一部のパラメータについては、後のセクションで説明します。 + +#### 画像URL +画像のURLを渡すだけで、Difyが自動的に画像をダウンロードしてユーザーに返します。 + +```python + def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: + """ + create an image message + + :param image: the url of the image + :param save_as: save as + :return: the image message + """ +``` + +#### リンク +リンクを返す必要がある場合は、以下のインターフェースを使用できます。 + +```python + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: + """ + create a link message + + :param link: the url of the link + :param save_as: save as + :return: the link message + """ +``` + +#### テキスト +テキストメッセージを返す必要がある場合は、以下のインターフェースを使用できます。 + +```python + def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: + """ + create a text message + + :param text: the text of the message + :param save_as: save as + :return: the text message + """ +``` + +#### ファイルBLOB +画像、音声、動画、PPT、Word、Excelなどのファイルの生データを返す必要がある場合は、以下のインターフェースを使用できます。 + +- `blob` ファイルの生データ(bytes型) +- `meta` ファイルのメタデータ。ファイルの種類が分かっている場合は、`mime_type`を渡すことをお勧めします。そうでない場合、Difyはデフォルトタイプとして`octet/stream`を使用します。 + +```python + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: + """ + create a blob message + + :param blob: the blob + :param meta: meta + :param save_as: save as + :return: the blob message + """ +``` + +#### JSON +フォーマットされたJSONを返す必要がある場合は、以下のインターフェースを使用できます。これは通常、ワークフロー内のノード間のデータ伝送に使用されますが、エージェントモードでは、ほとんどの大規模言語モデルもJSONを読み取り、理解することができます。 + +- `object` Pythonの辞書オブジェクトで、自動的にJSONにシリアライズされます。 + +```python + def create_json_message(self, object: dict) -> ToolInvokeMessage: + """ + create a json message + """ +``` + +### ショートカットツール + +大規模モデルアプリケーションでは、以下の2つの一般的なニーズがあります: +- まず長いテキストを事前に要約し、その要約内容をLLMに渡すことで、元のテキストが長すぎてLLMが処理できない問題を防ぐ +- ツールが取得したコンテンツがリンクである場合、Webページ情報をクロールしてからLLMに返す必要がある + +開発者がこれら2つのニーズを迅速に実装できるよう、以下の2つのショートカットツールを提供しています。 + +#### テキスト要約ツール + +このツールはuser_idと要約するテキストを入力として受け取り、要約されたテキストを返します。Difyは現在のワークスペースのデフォルトモデルを使用して長文を要約します。 + +```python + def summary(self, user_id: str, content: str) -> str: + """ + summary 
the content + + :param user_id: the user id + :param content: the content + :return: the summary + """ +``` + +#### Webページクローリングツール + +このツールはクロールするWebページのリンクとユーザーエージェント(空でも可)を入力として受け取り、そのWebページの情報を含む文字列を返します。`user_agent`はオプションのパラメータで、ツールを識別するために使用できます。渡さない場合、Difyはデフォルトの`user_agent`を使用します。 + +```python + def get_url(self, url: str, user_agent: str = None) -> str: + """ + get url from the crawled result + """ +``` + +### 変数プール + +`Tool`内に変数プールを導入し、ツールの実行中に生成された変数やファイルなどを保存します。これらの変数は、ツールの実行中に他のツールが使用することができます。 + +次に、`DallE3`と`Vectorizer.AI`を例に、変数プールの使用方法を紹介します。 + +- `DallE3`は画像生成ツールで、テキストに基づいて画像を生成できます。ここでは、`DallE3`にカフェのロゴを生成させます。 +- `Vectorizer.AI`はベクター画像変換ツールで、画像をベクター画像に変換できるため、画像を無限に拡大しても品質が損なわれません。ここでは、`DallE3`が生成したPNGアイコンをベクター画像に変換し、デザイナーが実際に使用できるようにします。 + +#### DallE3 +まず、DallE3を使用します。画像を作成した後、その画像を変数プールに保存します。コードは以下の通りです: + +```python +from typing import Any, Dict, List, Union +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +from base64 import b64decode + +from openai import OpenAI + +class DallE3Tool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + client = OpenAI( + api_key=self.runtime.credentials['openai_api_key'], + ) + + # prompt + prompt = tool_parameters.get('prompt', '') + if not prompt: + return self.create_text_message('Please input prompt') + + # call openapi dalle3 + response = client.images.generate( + prompt=prompt, model='dall-e-3', + size='1024x1024', n=1, style='vivid', quality='standard', + response_format='b64_json' + ) + + result = [] + for image in response.data: + # Save all images to the variable pool through the save_as parameter. The variable name is self.VARIABLE_KEY.IMAGE.value. If new images are generated later, they will overwrite the previous images. 
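+            # Note: save_as places this blob in the tool's variable pool under
+            # self.VARIABLE_KEY.IMAGE.value; the Vectorizer tool shown below
+            # reads it back with self.get_variable_file(self.VARIABLE_KEY.IMAGE).
+            # Each iteration reuses the same key, so only the last generated
+            # image is kept in the pool.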
+ result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={ 'mime_type': 'image/png' }, + save_as=self.VARIABLE_KEY.IMAGE.value)) + + return result +``` + +ここでは画像の変数名として`self.VARIABLE_KEY.IMAGE.value`を使用していることに注意してください。開発者のツールが互いに連携できるよう、この`KEY`を定義しました。自由に使用することも、この`KEY`を使用しないこともできます。カスタムのKEYを渡すこともできます。 + +#### Vectorizer.AI +次に、Vectorizer.AIを使用して、DallE3が生成したPNGアイコンをベクター画像に変換します。ここで定義した関数を見てみましょう。コードは以下の通りです: + +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict, List, Union +from httpx import post +from base64 import b64decode + +class VectorizerTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + Tool invocation, the image variable name needs to be passed in from here, so that we can get the image from the variable pool + """ + + + def get_runtime_parameters(self) -> List[ToolParameter]: + """ + Override the tool parameter list, we can dynamically generate the parameter list based on the actual situation in the current variable pool, so that the LLM can generate the form based on the parameter list + """ + + + def is_tool_available(self) -> bool: + """ + Whether the current tool is available, if there is no image in the current variable pool, then we don't need to display this tool, just return False here + """ +``` + +次に、これら3つの関数を実装します: + +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.errors import ToolProviderCredentialValidationError + +from typing import Any, Dict, List, Union +from httpx import post +from base64 import b64decode + +class VectorizerTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) + -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + invoke tools + """ + api_key_name = self.runtime.credentials.get('api_key_name', None) + api_key_value = self.runtime.credentials.get('api_key_value', None) + + if not api_key_name or not api_key_value: + raise ToolProviderCredentialValidationError('Please input api key name and value') + + # Get image_id, the definition of image_id can be found in get_runtime_parameters + image_id = tool_parameters.get('image_id', '') + if not image_id: + return self.create_text_message('Please input image id') + + # Get the image generated by DallE from the variable pool + image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + if not image_binary: + return self.create_text_message('Image not found, please request user to generate image firstly.') + + # Generate vector image + response = post( + 'https://vectorizer.ai/api/v1/vectorize', + files={ 'image': image_binary }, + data={ 'mode': 'test' }, + auth=(api_key_name, api_key_value), + timeout=30 + ) + + if response.status_code != 200: + raise Exception(response.text) + + return [ + self.create_text_message('the vectorized svg is saved as an image.'), + self.create_blob_message(blob=response.content, + meta={'mime_type': 'image/svg+xml'}) + ] + + def get_runtime_parameters(self) -> List[ToolParameter]: + """ + override the runtime parameters + """ + # Here, we override the tool parameter list, define the image_id, and set its option list to all images in the current variable pool. 
The configuration here is consistent with the configuration in yaml. + return [ + ToolParameter.get_simple_instance( + name='image_id', + llm_description=f'the image id that you want to vectorize, \ + and the image id should be specified in \ + {[i.name for i in self.list_default_image_variables()]}', + type=ToolParameter.ToolParameterType.SELECT, + required=True, + options=[i.name for i in self.list_default_image_variables()] + ) + ] + + def is_tool_available(self) -> bool: + # Only when there are images in the variable pool, the LLM needs to use this tool + return len(self.list_default_image_variables()) > 0 +``` + +ここで注目すべきは、実際には`image_id`を使用していないことです。このツールを呼び出す際には、デフォルトの変数プールに必ず画像があると仮定し、直接`image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)`を使用して画像を取得しています。モデルの能力が弱い場合、開発者にもこの方法を推奨します。これにより、エラー許容度を効果的に向上させ、モデルが誤ったパラメータを渡すのを防ぐことができます。 \ No newline at end of file diff --git a/api/core/tools/docs/ja_JP/tool_scale_out.md b/api/core/tools/docs/ja_JP/tool_scale_out.md new file mode 100644 index 00000000000000..a721023d00bdda --- /dev/null +++ b/api/core/tools/docs/ja_JP/tool_scale_out.md @@ -0,0 +1,240 @@ +# ツールの迅速な統合 + +ここでは、GoogleSearchを例にツールを迅速に統合する方法を紹介します。 + +## 1. ツールプロバイダーのyamlを準備する + +### 概要 + +このyamlファイルには、プロバイダー名、アイコン、作者などの詳細情報が含まれ、フロントエンドでの柔軟な表示を可能にします。 + +### 例 + +`core/tools/provider/builtin`の下に`google`モジュール(フォルダ)を作成し、`google.yaml`を作成します。名前はモジュール名と一致している必要があります。 + +以降、このツールに関するすべての操作はこのモジュール内で行います。 + +```yaml +identity: # ツールプロバイダーの基本情報 + author: Dify # 作者 + name: google # 名前(一意、他のプロバイダーと重複不可) + label: # フロントエンド表示用のラベル + en_US: Google # 英語ラベル + zh_Hans: Google # 中国語ラベル + description: # フロントエンド表示用の説明 + en_US: Google # 英語説明 + zh_Hans: Google # 中国語説明 + icon: icon.svg # アイコン(現在のモジュールの_assetsフォルダに配置) + tags: # タグ(フロントエンド表示用) + - search +``` + +- `identity`フィールドは必須で、ツールプロバイダーの基本情報(作者、名前、ラベル、説明、アイコンなど)が含まれます。 + - アイコンは現在のモジュールの`_assets`フォルダに配置する必要があります。[こちら](../../provider/builtin/google/_assets/icon.svg)を参照してください。 + - タグはフロントエンドでの表示に使用され、ユーザーがこのツールプロバイダーを素早く見つけるのに役立ちます。現在サポートされているすべてのタグは以下の通りです: + ```python + class ToolLabelEnum(Enum): + SEARCH = 'search' + IMAGE = 'image' + VIDEOS = 'videos' + WEATHER = 'weather' + FINANCE = 'finance' + DESIGN = 'design' + TRAVEL = 'travel' + SOCIAL = 'social' + NEWS = 'news' + MEDICAL = 'medical' + PRODUCTIVITY = 'productivity' + EDUCATION = 'education' + BUSINESS = 'business' + ENTERTAINMENT = 'entertainment' + UTILITIES = 'utilities' + OTHER = 'other' + ``` + +## 2. 
プロバイダーの認証情報を準備する + +GoogleはSerpApiが提供するAPIを使用するサードパーティツールであり、SerpApiを使用するにはAPI Keyが必要です。つまり、このツールを使用するには認証情報が必要です。一方、`wikipedia`のようなツールでは認証情報フィールドを記入する必要はありません。[こちら](../../provider/builtin/wikipedia/wikipedia.yaml)を参照してください。 + +認証情報フィールドを設定すると、以下のようになります: + +```yaml +identity: + author: Dify + name: google + label: + en_US: Google + zh_Hans: Google + description: + en_US: Google + zh_Hans: Google + icon: icon.svg +credentials_for_provider: # 認証情報フィールド + serpapi_api_key: # 認証情報フィールド名 + type: secret-input # 認証情報フィールドタイプ + required: true # 必須かどうか + label: # 認証情報フィールドラベル + en_US: SerpApi API key # 英語ラベル + zh_Hans: SerpApi API key # 中国語ラベル + placeholder: # 認証情報フィールドプレースホルダー + en_US: Please input your SerpApi API key # 英語プレースホルダー + zh_Hans: 请输入你的 SerpApi API key # 中国語プレースホルダー + help: # 認証情報フィールドヘルプテキスト + en_US: Get your SerpApi API key from SerpApi # 英語ヘルプテキスト + zh_Hans: 从 SerpApi 获取您的 SerpApi API key # 中国語ヘルプテキスト + url: https://serpapi.com/manage-api-key # 認証情報フィールドヘルプリンク +``` + +- `type`:認証情報フィールドタイプ。現在、`secret-input`、`text-input`、`select`の3種類をサポートしており、それぞれパスワード入力ボックス、テキスト入力ボックス、ドロップダウンボックスに対応します。`secret-input`の場合、フロントエンドで入力内容が隠され、バックエンドで入力内容が暗号化されます。 + +## 3. ツールのyamlを準備する + +1つのプロバイダーの下に複数のツールを持つことができ、各ツールにはyamlファイルが必要です。このファイルにはツールの基本情報、パラメータ、出力などが含まれます。 + +引き続きGoogleSearchを例に、`google`モジュールの下に`tools`モジュールを作成し、`tools/google_search.yaml`を作成します。内容は以下の通りです: + +```yaml +identity: # ツールの基本情報 + name: google_search # ツール名(一意、他のツールと重複不可) + author: Dify # 作者 + label: # フロントエンド表示用のラベル + en_US: GoogleSearch # 英語ラベル + zh_Hans: 谷歌搜索 # 中国語ラベル +description: # フロントエンド表示用の説明 + human: # フロントエンド表示用の紹介(多言語対応) + en_US: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query. + zh_Hans: 一个用于执行 Google SERP 搜索并提取片段和网页的工具。输入应该是一个搜索查询。 + llm: A tool for performing a Google SERP search and extracting snippets and webpages. Input should be a search query. 
# LLMに渡す紹介文。LLMがこのツールをより理解できるよう、できるだけ詳細な情報を記述することをお勧めします。 +parameters: # パラメータリスト + - name: query # パラメータ名 + type: string # パラメータタイプ + required: true # 必須かどうか + label: # パラメータラベル + en_US: Query string # 英語ラベル + zh_Hans: 查询语句 # 中国語ラベル + human_description: # フロントエンド表示用の紹介(多言語対応) + en_US: used for searching + zh_Hans: 用于搜索网页内容 + llm_description: key words for searching # LLMに渡す紹介文。LLMがこのパラメータをより理解できるよう、できるだけ詳細な情報を記述することをお勧めします。 + form: llm # フォームタイプ。llmはこのパラメータがAgentによって推論される必要があることを意味し、フロントエンドではこのパラメータは表示されません。 + - name: result_type + type: select # パラメータタイプ + required: true + options: # ドロップダウンボックスのオプション + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: link + label: + en_US: link + zh_Hans: 链接 + default: link + label: + en_US: Result type + zh_Hans: 结果类型 + human_description: + en_US: used for selecting the result type, text or link + zh_Hans: 用于选择结果类型,使用文本还是链接进行展示 + form: form # フォームタイプ。formはこのパラメータが対話開始前にフロントエンドでユーザーによって入力される必要があることを意味します。 +``` + +- `identity`フィールドは必須で、ツールの基本情報(名前、作者、ラベル、説明など)が含まれます。 +- `parameters` パラメータリスト + - `name`(必須)パラメータ名。一意で、他のパラメータと重複しないようにしてください。 + - `type`(必須)パラメータタイプ。現在、`string`、`number`、`boolean`、`select`、`secret-input`の5種類をサポートしており、それぞれ文字列、数値、ブール値、ドロップダウンボックス、暗号化入力ボックスに対応します。機密情報には`secret-input`タイプの使用をお勧めします。 + - `label`(必須)パラメータラベル。フロントエンド表示用です。 + - `form`(必須)フォームタイプ。現在、`llm`と`form`の2種類をサポートしています。 + - エージェントアプリケーションでは、`llm`はこのパラメータがLLM自身によって推論されることを示し、`form`はこのツールを使用するために事前に設定できるパラメータであることを示します。 + - ワークフローアプリケーションでは、`llm`と`form`の両方がフロントエンドで入力する必要がありますが、`llm`のパラメータはツールノードの入力変数として使用されます。 + - `required` パラメータが必須かどうかを示します。 + - `llm`モードでは、パラメータが必須の場合、Agentはこのパラメータを推論する必要があります。 + - `form`モードでは、パラメータが必須の場合、ユーザーは対話開始前にフロントエンドでこのパラメータを入力する必要があります。 + - `options` パラメータオプション + - `llm`モードでは、DifyはすべてのオプションをLLMに渡し、LLMはこれらのオプションに基づいて推論できます。 + - `form`モードで、`type`が`select`の場合、フロントエンドはこれらのオプションを表示します。 + - `default` デフォルト値 + - `min` 最小値。パラメータタイプが`number`の場合に設定できます。 + - `max` 最大値。パラメータタイプが`number`の場合に設定できます。 + - `human_description` フロントエンド表示用の紹介。多言語対応です。 + - `placeholder` 入力ボックスのプロンプトテキスト。フォームタイプが`form`で、パラメータタイプが`string`、`number`、`secret-input`の場合に設定できます。多言語対応です。 + - `llm_description` LLMに渡す紹介文。LLMがこのパラメータをより理解できるよう、できるだけ詳細な情報を記述することをお勧めします。 + +## 4. ツールコードを準備する + +ツールの設定が完了したら、ツールのロジックを実装するコードを作成します。 + +`google/tools`モジュールの下に`google_search.py`を作成し、内容は以下の通りです: + +```python +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage + +from typing import Any, Dict, List, Union + +class GoogleSearchTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: Dict[str, Any], + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + ツールを呼び出す + """ + query = tool_parameters['query'] + result_type = tool_parameters['result_type'] + api_key = self.runtime.credentials['serpapi_api_key'] + result = SerpAPI(api_key).run(query, result_type=result_type) + + if result_type == 'text': + return self.create_text_message(text=result) + return self.create_link_message(link=result) +``` + +### パラメータ +ツールの全体的なロジックは`_invoke`メソッドにあります。このメソッドは2つのパラメータ(`user_id`とtool_parameters`)を受け取り、それぞれユーザーIDとツールパラメータを表します。 + +### 戻り値 +ツールの戻り値として、1つのメッセージまたは複数のメッセージを選択できます。ここでは1つのメッセージを返しています。`create_text_message`と`create_link_message`を使用して、テキストメッセージまたはリンクメッセージを作成できます。複数のメッセージを返す場合は、リストを構築できます(例:`[self.create_text_message('msg1'), self.create_text_message('msg2')]`)。 + +## 5. 
プロバイダーコードを準備する + +最後に、プロバイダーモジュールの下にプロバイダークラスを作成し、プロバイダーの認証情報検証ロジックを実装する必要があります。認証情報の検証が失敗した場合、`ToolProviderCredentialValidationError`例外が発生します。 + +`google`モジュールの下に`google.py`を作成し、内容は以下の通りです: + +```python +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.errors import ToolProviderCredentialValidationError + +from core.tools.provider.builtin.google.tools.google_search import GoogleSearchTool + +from typing import Any, Dict + +class GoogleProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: Dict[str, Any]) -> None: + try: + # 1. ここでGoogleSearchTool()を使ってGoogleSearchToolをインスタンス化する必要があります。これによりGoogleSearchToolのyaml設定が自動的に読み込まれますが、この時点では認証情報は含まれていません + # 2. 次に、fork_tool_runtimeメソッドを使用して、現在の認証情報をGoogleSearchToolに渡す必要があります + # 3. 最後に、invokeを呼び出します。パラメータはGoogleSearchToolのyamlで設定されたパラメータルールに従って渡す必要があります + GoogleSearchTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "query": "test", + "result_type": "link" + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) +``` + +## 完了 + +以上のステップが完了すると、このツールをフロントエンドで確認し、Agentで使用することができるようになります。 + +もちろん、google_searchには認証情報が必要なため、使用する前にフロントエンドで認証情報を入力する必要があります。 + +![Alt text](../images/index/image-2.png) \ No newline at end of file diff --git a/api/core/tools/docs/zh_Hans/tool_scale_out.md b/api/core/tools/docs/zh_Hans/tool_scale_out.md index 06a8d9a4f9a9d8..ec61e4677bae76 100644 --- a/api/core/tools/docs/zh_Hans/tool_scale_out.md +++ b/api/core/tools/docs/zh_Hans/tool_scale_out.md @@ -234,4 +234,4 @@ class GoogleProvider(BuiltinToolProviderController): 当然,因为google_search需要一个凭据,在使用之前,还需要在前端配置它的凭据。 -![Alt text](images/index/image-2.png) +![Alt text](../images/index/image-2.png) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2b01b8fd8e89c9..b1db5594414470 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -10,56 +10,57 @@ class UserTool(BaseModel): author: str - name: str # identifier - label: I18nObject # label + name: str # identifier + label: I18nObject # label description: I18nObject parameters: Optional[list[ToolParameter]] = None labels: list[str] = None -UserToolProviderTypeLiteral = Optional[Literal[ - 'builtin', 'api', 'workflow' -]] + +UserToolProviderTypeLiteral = Optional[Literal["builtin", "api", "workflow"]] + class UserToolProvider(BaseModel): id: str author: str - name: str # identifier + name: str # identifier description: I18nObject icon: str - label: I18nObject # label + label: I18nObject # label type: ToolProviderType masked_credentials: Optional[dict] = None original_credentials: Optional[dict] = None is_team_authorization: bool = False allow_delete: bool = True - tools: list[UserTool] = None - labels: list[str] = None + tools: list[UserTool] | None = None + labels: list[str] | None = None def to_dict(self) -> dict: # ------------- # overwrite tool parameter types for temp fix tools = jsonable_encoder(self.tools) for tool in tools: - if tool.get('parameters'): - for parameter in tool.get('parameters'): - if parameter.get('type') == ToolParameter.ToolParameterType.FILE.value: - parameter['type'] = 'files' + if tool.get("parameters"): + for parameter in tool.get("parameters"): + if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: + parameter["type"] = "files" # ------------- return { - 'id': self.id, - 'author': self.author, - 
'name': self.name, - 'description': self.description.to_dict(), - 'icon': self.icon, - 'label': self.label.to_dict(), - 'type': self.type.value, - 'team_credentials': self.masked_credentials, - 'is_team_authorization': self.is_team_authorization, - 'allow_delete': self.allow_delete, - 'tools': tools, - 'labels': self.labels, + "id": self.id, + "author": self.author, + "name": self.name, + "description": self.description.to_dict(), + "icon": self.icon, + "label": self.label.to_dict(), + "type": self.type.value, + "team_credentials": self.masked_credentials, + "is_team_authorization": self.is_team_authorization, + "allow_delete": self.allow_delete, + "tools": tools, + "labels": self.labels, } + class UserToolProviderCredentials(BaseModel): - credentials: dict[str, ToolProviderCredentials] \ No newline at end of file + credentials: dict[str, ToolProviderCredentials] diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 55e31e8c35e5b7..924e6fc0cf9f17 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -1,26 +1,23 @@ from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field class I18nObject(BaseModel): """ Model class for i18n object. """ - zh_Hans: Optional[str] = None - pt_BR: Optional[str] = None + en_US: str + zh_Hans: Optional[str] = Field(default=None) + pt_BR: Optional[str] = Field(default=None) + ja_JP: Optional[str] = Field(default=None) def __init__(self, **data): super().__init__(**data) - if not self.zh_Hans: - self.zh_Hans = self.en_US - if not self.pt_BR: - self.pt_BR = self.en_US + self.zh_Hans = self.zh_Hans or self.en_US + self.pt_BR = self.pt_BR or self.en_US + self.ja_JP = self.ja_JP or self.en_US def to_dict(self) -> dict: - return { - 'zh_Hans': self.zh_Hans, - 'en_US': self.en_US, - 'pt_BR': self.pt_BR - } + return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index d18d27fb02beee..0c15b2a3711f11 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -7,8 +7,10 @@ class ApiToolBundle(BaseModel): """ - This class is used to store the schema information of an api based tool. such as the url, the method, the parameters, etc. + This class is used to store the schema information of an api based tool. + such as the url, the method, the parameters, etc. 
""" + # server_url server_url: str # method diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 2e4433d9f6d2b9..d8637fd2cb0c52 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -7,27 +7,29 @@ class ToolLabelEnum(Enum): - SEARCH = 'search' - IMAGE = 'image' - VIDEOS = 'videos' - WEATHER = 'weather' - FINANCE = 'finance' - DESIGN = 'design' - TRAVEL = 'travel' - SOCIAL = 'social' - NEWS = 'news' - MEDICAL = 'medical' - PRODUCTIVITY = 'productivity' - EDUCATION = 'education' - BUSINESS = 'business' - ENTERTAINMENT = 'entertainment' - UTILITIES = 'utilities' - OTHER = 'other' + SEARCH = "search" + IMAGE = "image" + VIDEOS = "videos" + WEATHER = "weather" + FINANCE = "finance" + DESIGN = "design" + TRAVEL = "travel" + SOCIAL = "social" + NEWS = "news" + MEDICAL = "medical" + PRODUCTIVITY = "productivity" + EDUCATION = "education" + BUSINESS = "business" + ENTERTAINMENT = "entertainment" + UTILITIES = "utilities" + OTHER = "other" + class ToolProviderType(Enum): """ - Enum class for tool provider + Enum class for tool provider """ + BUILT_IN = "builtin" WORKFLOW = "workflow" API = "api" @@ -35,7 +37,7 @@ class ToolProviderType(Enum): DATASET_RETRIEVAL = "dataset-retrieval" @classmethod - def value_of(cls, value: str) -> 'ToolProviderType': + def value_of(cls, value: str) -> "ToolProviderType": """ Get value of given mode. @@ -45,19 +47,21 @@ def value_of(cls, value: str) -> 'ToolProviderType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. """ + OPENAPI = "openapi" SWAGGER = "swagger" OPENAI_PLUGIN = "openai_plugin" OPENAI_ACTIONS = "openai_actions" @classmethod - def value_of(cls, value: str) -> 'ApiProviderSchemaType': + def value_of(cls, value: str) -> "ApiProviderSchemaType": """ Get value of given mode. @@ -67,17 +71,19 @@ def value_of(cls, value: str) -> 'ApiProviderSchemaType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. """ + NONE = "none" API_KEY = "api_key" @classmethod - def value_of(cls, value: str) -> 'ApiProviderAuthType': + def value_of(cls, value: str) -> "ApiProviderAuthType": """ Get value of given mode. 
@@ -87,7 +93,8 @@ def value_of(cls, value: str) -> 'ApiProviderAuthType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + class ToolInvokeMessage(BaseModel): class MessageType(Enum): @@ -97,27 +104,30 @@ class MessageType(Enum): BLOB = "blob" JSON = "json" IMAGE_LINK = "image_link" - FILE_VAR = "file_var" + FILE = "file" type: MessageType = MessageType.TEXT """ plain text, image url or link url """ message: str | bytes | dict | None = None - meta: dict[str, Any] | None = None - save_as: str = '' + # TODO: Use a BaseModel for meta + meta: dict[str, Any] = Field(default_factory=dict) + save_as: str = "" + class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") - save_as: str = '' + save_as: str = "" file_var: Optional[dict[str, Any]] = None + class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") - @field_validator('value', mode='before') + @field_validator("value", mode="before") @classmethod def transform_id_to_str(cls, value) -> str: if not isinstance(value, str): @@ -134,11 +144,72 @@ class ToolParameterType(str, Enum): SELECT = "select" SECRET_INPUT = "secret-input" FILE = "file" + FILES = "files" + + # deprecated, should not use. + SYSTEM_FILES = "systme-files" + + def as_normal_type(self): + if self in { + ToolParameter.ToolParameterType.SECRET_INPUT, + ToolParameter.ToolParameterType.SELECT, + }: + return "string" + return self.value + + def cast_value(self, value: Any, /): + try: + match self: + case ( + ToolParameter.ToolParameterType.STRING + | ToolParameter.ToolParameterType.SECRET_INPUT + | ToolParameter.ToolParameterType.SELECT + ): + if value is None: + return "" + else: + return value if isinstance(value, str) else str(value) + + case ToolParameter.ToolParameterType.BOOLEAN: + if value is None: + return False + elif isinstance(value, str): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case "true" | "yes" | "y" | "1": + return True + case "false" | "no" | "n" | "0": + return False + case _: + return bool(value) + else: + return value if isinstance(value, bool) else bool(value) + + case ToolParameter.ToolParameterType.NUMBER: + if isinstance(value, int | float): + return value + elif isinstance(value, str) and value: + if "." 
in value: + return float(value) + else: + return int(value) + case ( + ToolParameter.ToolParameterType.SYSTEM_FILES + | ToolParameter.ToolParameterType.FILE + | ToolParameter.ToolParameterType.FILES + ): + return value + case _: + return str(value) + + except Exception: + raise ValueError(f"The tool parameter value {value} is not in correct type.") class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool - FORM = "form" # should be set before invoking tool - LLM = "llm" # will be set by LLM + SCHEMA = "schema" # should be set while adding tool + FORM = "form" # should be set before invoking tool + LLM = "llm" # will be set by LLM name: str = Field(..., description="The name of the parameter") label: I18nObject = Field(..., description="The label presented to the user") @@ -148,31 +219,38 @@ class ToolParameterForm(Enum): form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None required: Optional[bool] = False - default: Optional[Union[int, str]] = None + default: Optional[Union[float, int, str]] = None min: Optional[Union[float, int]] = None max: Optional[Union[float, int]] = None options: Optional[list[ToolParameterOption]] = None @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, - required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': + def get_simple_instance( + cls, + name: str, + llm_description: str, + type: ToolParameterType, + required: bool, + options: Optional[list[str]] = None, + ) -> "ToolParameter": """ - get a simple tool parameter + get a simple tool parameter - :param name: the name of the parameter - :param llm_description: the description presented to the LLM - :param type: the type of the parameter - :param required: if the parameter is required - :param options: the options of the parameter + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param type: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter """ # convert options to ToolParameterOption if options: - options = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options] + options = [ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + ] return cls( name=name, - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), type=type, form=cls.ToolParameterForm.LLM, llm_description=llm_description, @@ -180,18 +258,24 @@ def get_simple_instance(cls, options=options, ) + class ToolProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field(default=[], description="The tags of the tool", ) + tags: Optional[list[ToolLabelEnum]] = Field( + default=[], + description="The tags of the tool", + ) + class ToolDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The 
description presented to the LLM") + class ToolIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -199,10 +283,12 @@ class ToolIdentity(BaseModel): provider: str = Field(..., description="The provider of the tool") icon: Optional[str] = None + class ToolCredentialsOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + class ToolProviderCredentials(BaseModel): class CredentialsType(Enum): SECRET_INPUT = "secret-input" @@ -221,7 +307,7 @@ def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType": for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") @staticmethod def default(value: str) -> str: @@ -239,33 +325,38 @@ def default(value: str) -> str: def to_dict(self) -> dict: return { - 'name': self.name, - 'type': self.type.value, - 'required': self.required, - 'default': self.default, - 'options': self.options, - 'help': self.help.to_dict() if self.help else None, - 'label': self.label.to_dict(), - 'url': self.url, - 'placeholder': self.placeholder.to_dict() if self.placeholder else None, + "name": self.name, + "type": self.type.value, + "required": self.required, + "default": self.default, + "options": self.options, + "help": self.help.to_dict() if self.help else None, + "label": self.label.to_dict(), + "url": self.url, + "placeholder": self.placeholder.to_dict() if self.placeholder else None, } + class ToolRuntimeVariableType(Enum): TEXT = "text" IMAGE = "image" + class ToolRuntimeVariable(BaseModel): type: ToolRuntimeVariableType = Field(..., description="The type of the variable") name: str = Field(..., description="The name of the variable") position: int = Field(..., description="The position of the variable") tool_name: str = Field(..., description="The name of the tool") + class ToolRuntimeTextVariable(ToolRuntimeVariable): value: str = Field(..., description="The value of the variable") + class ToolRuntimeImageVariable(ToolRuntimeVariable): value: str = Field(..., description="The path of the image") + class ToolRuntimeVariablePool(BaseModel): conversation_id: str = Field(..., description="The conversation id") user_id: str = Field(..., description="The user id") @@ -274,26 +365,26 @@ class ToolRuntimeVariablePool(BaseModel): pool: list[ToolRuntimeVariable] = Field(..., description="The pool of variables") def __init__(self, **data: Any): - pool = data.get('pool', []) + pool = data.get("pool", []) # convert pool into correct type for index, variable in enumerate(pool): - if variable['type'] == ToolRuntimeVariableType.TEXT.value: + if variable["type"] == ToolRuntimeVariableType.TEXT.value: pool[index] = ToolRuntimeTextVariable(**variable) - elif variable['type'] == ToolRuntimeVariableType.IMAGE.value: + elif variable["type"] == ToolRuntimeVariableType.IMAGE.value: pool[index] = ToolRuntimeImageVariable(**variable) super().__init__(**data) def dict(self) -> dict: return { - 'conversation_id': self.conversation_id, - 'user_id': self.user_id, - 'tenant_id': self.tenant_id, - 'pool': [variable.model_dump() for variable in self.pool], + "conversation_id": self.conversation_id, + "user_id": self.user_id, + "tenant_id": self.tenant_id, + "pool": [variable.model_dump() for variable in self.pool], } def set_text(self, tool_name: str, name: str, value: str) -> None: """ - set a text 
variable + set a text variable """ for variable in self.pool: if variable.name == name: @@ -312,12 +403,12 @@ def set_text(self, tool_name: str, name: str, value: str) -> None: self.pool.append(variable) - def set_file(self, tool_name: str, value: str, name: str = None) -> None: + def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None: """ - set an image variable + set an image variable - :param tool_name: the name of the tool - :param value: the id of the file + :param tool_name: the name of the tool + :param value: the id of the file """ # check how many image variables are there image_variable_count = 0 @@ -345,22 +436,27 @@ def set_file(self, tool_name: str, value: str, name: str = None) -> None: self.pool.append(variable) + class ModelToolPropertyKey(Enum): IMAGE_PARAMETER_NAME = "image_parameter_name" + class ModelToolConfiguration(BaseModel): """ Model tool configuration """ + type: str = Field(..., description="The type of the model tool") model: str = Field(..., description="The model") label: I18nObject = Field(..., description="The label of the model tool") properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") + class ModelToolProviderConfiguration(BaseModel): """ Model tool provider configuration """ + provider: str = Field(..., description="The provider of the model tool") models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") label: I18nObject = Field(..., description="The label of the model tool") @@ -370,27 +466,30 @@ class WorkflowToolParameterConfiguration(BaseModel): """ Workflow tool configuration """ + name: str = Field(..., description="The name of the parameter") description: str = Field(..., description="The description of the parameter") form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter") + class ToolInvokeMeta(BaseModel): """ Tool invoke meta """ + time_cost: float = Field(..., description="The time cost of the tool invoke") error: Optional[str] = None tool_config: Optional[dict] = None @classmethod - def empty(cls) -> 'ToolInvokeMeta': + def empty(cls) -> "ToolInvokeMeta": """ Get an empty instance of ToolInvokeMeta """ return cls(time_cost=0.0, error=None, tool_config={}) @classmethod - def error_instance(cls, error: str) -> 'ToolInvokeMeta': + def error_instance(cls, error: str) -> "ToolInvokeMeta": """ Get an instance of ToolInvokeMeta with error """ @@ -398,22 +497,26 @@ def error_instance(cls, error: str) -> 'ToolInvokeMeta': def to_dict(self) -> dict: return { - 'time_cost': self.time_cost, - 'error': self.error, - 'tool_config': self.tool_config, + "time_cost": self.time_cost, + "error": self.error, + "tool_config": self.tool_config, } + class ToolLabel(BaseModel): """ Tool label """ + name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") icon: str = Field(..., description="The icon of the tool") + class ToolInvokeFrom(Enum): """ Enum class for tool invoke """ + WORKFLOW = "workflow" AGENT = "agent" diff --git a/api/core/tools/entities/values.py b/api/core/tools/entities/values.py index d0be5e93557fe6..f460df7e25c916 100644 --- a/api/core/tools/entities/values.py +++ b/api/core/tools/entities/values.py @@ -2,73 +2,109 @@ from core.tools.entities.tool_entities import ToolLabel, ToolLabelEnum ICONS = { - ToolLabelEnum.SEARCH: ''' + ToolLabelEnum.SEARCH: """ -''', - ToolLabelEnum.IMAGE: ''' +""", # noqa: E501 + ToolLabelEnum.IMAGE: 
""" -''', - ToolLabelEnum.VIDEOS: ''' +""", # noqa: E501 + ToolLabelEnum.VIDEOS: """ -''', - ToolLabelEnum.WEATHER: ''' +""", # noqa: E501 + ToolLabelEnum.WEATHER: """ -''', - ToolLabelEnum.FINANCE: ''' +""", # noqa: E501 + ToolLabelEnum.FINANCE: """ -''', - ToolLabelEnum.DESIGN: ''' +""", # noqa: E501 + ToolLabelEnum.DESIGN: """ -''', - ToolLabelEnum.TRAVEL: ''' +""", # noqa: E501 + ToolLabelEnum.TRAVEL: """ -''', - ToolLabelEnum.SOCIAL: ''' +""", # noqa: E501 + ToolLabelEnum.SOCIAL: """ -''', - ToolLabelEnum.NEWS: ''' +""", # noqa: E501 + ToolLabelEnum.NEWS: """ -''', - ToolLabelEnum.MEDICAL: ''' +""", # noqa: E501 + ToolLabelEnum.MEDICAL: """ -''', - ToolLabelEnum.PRODUCTIVITY: ''' +""", # noqa: E501 + ToolLabelEnum.PRODUCTIVITY: """ -''', - ToolLabelEnum.EDUCATION: ''' +""", # noqa: E501 + ToolLabelEnum.EDUCATION: """ -''', - ToolLabelEnum.BUSINESS: ''' +""", # noqa: E501 + ToolLabelEnum.BUSINESS: """ -''', - ToolLabelEnum.ENTERTAINMENT: ''' +""", # noqa: E501 + ToolLabelEnum.ENTERTAINMENT: """ -''', - ToolLabelEnum.UTILITIES: ''' +""", # noqa: E501 + ToolLabelEnum.UTILITIES: """ -''', - ToolLabelEnum.OTHER: ''' +""", # noqa: E501 + ToolLabelEnum.OTHER: """ -''' +""", # noqa: E501 } default_tool_label_dict = { - ToolLabelEnum.SEARCH: ToolLabel(name='search', label=I18nObject(en_US='Search', zh_Hans='搜索'), icon=ICONS[ToolLabelEnum.SEARCH]), - ToolLabelEnum.IMAGE: ToolLabel(name='image', label=I18nObject(en_US='Image', zh_Hans='图片'), icon=ICONS[ToolLabelEnum.IMAGE]), - ToolLabelEnum.VIDEOS: ToolLabel(name='videos', label=I18nObject(en_US='Videos', zh_Hans='视频'), icon=ICONS[ToolLabelEnum.VIDEOS]), - ToolLabelEnum.WEATHER: ToolLabel(name='weather', label=I18nObject(en_US='Weather', zh_Hans='天气'), icon=ICONS[ToolLabelEnum.WEATHER]), - ToolLabelEnum.FINANCE: ToolLabel(name='finance', label=I18nObject(en_US='Finance', zh_Hans='金融'), icon=ICONS[ToolLabelEnum.FINANCE]), - ToolLabelEnum.DESIGN: ToolLabel(name='design', label=I18nObject(en_US='Design', zh_Hans='设计'), icon=ICONS[ToolLabelEnum.DESIGN]), - ToolLabelEnum.TRAVEL: ToolLabel(name='travel', label=I18nObject(en_US='Travel', zh_Hans='旅行'), icon=ICONS[ToolLabelEnum.TRAVEL]), - ToolLabelEnum.SOCIAL: ToolLabel(name='social', label=I18nObject(en_US='Social', zh_Hans='社交'), icon=ICONS[ToolLabelEnum.SOCIAL]), - ToolLabelEnum.NEWS: ToolLabel(name='news', label=I18nObject(en_US='News', zh_Hans='新闻'), icon=ICONS[ToolLabelEnum.NEWS]), - ToolLabelEnum.MEDICAL: ToolLabel(name='medical', label=I18nObject(en_US='Medical', zh_Hans='医疗'), icon=ICONS[ToolLabelEnum.MEDICAL]), - ToolLabelEnum.PRODUCTIVITY: ToolLabel(name='productivity', label=I18nObject(en_US='Productivity', zh_Hans='生产力'), icon=ICONS[ToolLabelEnum.PRODUCTIVITY]), - ToolLabelEnum.EDUCATION: ToolLabel(name='education', label=I18nObject(en_US='Education', zh_Hans='教育'), icon=ICONS[ToolLabelEnum.EDUCATION]), - ToolLabelEnum.BUSINESS: ToolLabel(name='business', label=I18nObject(en_US='Business', zh_Hans='商业'), icon=ICONS[ToolLabelEnum.BUSINESS]), - ToolLabelEnum.ENTERTAINMENT: ToolLabel(name='entertainment', label=I18nObject(en_US='Entertainment', zh_Hans='娱乐'), icon=ICONS[ToolLabelEnum.ENTERTAINMENT]), - ToolLabelEnum.UTILITIES: ToolLabel(name='utilities', label=I18nObject(en_US='Utilities', zh_Hans='工具'), icon=ICONS[ToolLabelEnum.UTILITIES]), - ToolLabelEnum.OTHER: ToolLabel(name='other', label=I18nObject(en_US='Other', zh_Hans='其他'), icon=ICONS[ToolLabelEnum.OTHER]), + ToolLabelEnum.SEARCH: ToolLabel( + name="search", label=I18nObject(en_US="Search", zh_Hans="搜索"), 
icon=ICONS[ToolLabelEnum.SEARCH] + ), + ToolLabelEnum.IMAGE: ToolLabel( + name="image", label=I18nObject(en_US="Image", zh_Hans="图片"), icon=ICONS[ToolLabelEnum.IMAGE] + ), + ToolLabelEnum.VIDEOS: ToolLabel( + name="videos", label=I18nObject(en_US="Videos", zh_Hans="视频"), icon=ICONS[ToolLabelEnum.VIDEOS] + ), + ToolLabelEnum.WEATHER: ToolLabel( + name="weather", label=I18nObject(en_US="Weather", zh_Hans="天气"), icon=ICONS[ToolLabelEnum.WEATHER] + ), + ToolLabelEnum.FINANCE: ToolLabel( + name="finance", label=I18nObject(en_US="Finance", zh_Hans="金融"), icon=ICONS[ToolLabelEnum.FINANCE] + ), + ToolLabelEnum.DESIGN: ToolLabel( + name="design", label=I18nObject(en_US="Design", zh_Hans="设计"), icon=ICONS[ToolLabelEnum.DESIGN] + ), + ToolLabelEnum.TRAVEL: ToolLabel( + name="travel", label=I18nObject(en_US="Travel", zh_Hans="旅行"), icon=ICONS[ToolLabelEnum.TRAVEL] + ), + ToolLabelEnum.SOCIAL: ToolLabel( + name="social", label=I18nObject(en_US="Social", zh_Hans="社交"), icon=ICONS[ToolLabelEnum.SOCIAL] + ), + ToolLabelEnum.NEWS: ToolLabel( + name="news", label=I18nObject(en_US="News", zh_Hans="新闻"), icon=ICONS[ToolLabelEnum.NEWS] + ), + ToolLabelEnum.MEDICAL: ToolLabel( + name="medical", label=I18nObject(en_US="Medical", zh_Hans="医疗"), icon=ICONS[ToolLabelEnum.MEDICAL] + ), + ToolLabelEnum.PRODUCTIVITY: ToolLabel( + name="productivity", + label=I18nObject(en_US="Productivity", zh_Hans="生产力"), + icon=ICONS[ToolLabelEnum.PRODUCTIVITY], + ), + ToolLabelEnum.EDUCATION: ToolLabel( + name="education", label=I18nObject(en_US="Education", zh_Hans="教育"), icon=ICONS[ToolLabelEnum.EDUCATION] + ), + ToolLabelEnum.BUSINESS: ToolLabel( + name="business", label=I18nObject(en_US="Business", zh_Hans="商业"), icon=ICONS[ToolLabelEnum.BUSINESS] + ), + ToolLabelEnum.ENTERTAINMENT: ToolLabel( + name="entertainment", + label=I18nObject(en_US="Entertainment", zh_Hans="娱乐"), + icon=ICONS[ToolLabelEnum.ENTERTAINMENT], + ), + ToolLabelEnum.UTILITIES: ToolLabel( + name="utilities", label=I18nObject(en_US="Utilities", zh_Hans="工具"), icon=ICONS[ToolLabelEnum.UTILITIES] + ), + ToolLabelEnum.OTHER: ToolLabel( + name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER] + ), } default_tool_labels = [v for k, v in default_tool_label_dict.items()] diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index 9fd8322db13741..6febf137b000f9 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -4,23 +4,30 @@ class ToolProviderNotFoundError(ValueError): pass + class ToolNotFoundError(ValueError): pass + class ToolParameterValidationError(ValueError): pass + class ToolProviderCredentialValidationError(ValueError): pass + class ToolNotSupportedError(ValueError): pass + class ToolInvokeError(ValueError): pass + class ToolApiSchemaError(ValueError): pass + class ToolEngineInvokeError(Exception): - meta: ToolInvokeMeta \ No newline at end of file + meta: ToolInvokeMeta diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 25d9f403a0fbbe..d80974486dfcfd 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -1,34 +1,80 @@ - google - bing +- perplexity - duckduckgo - searchapi - serper - searxng +- websearch +- tavily +- stackexchange +- pubmed +- arxiv +- aws +- nominatim +- devdocs +- spider +- firecrawl +- brave +- crossref +- jina +- webscraper - dalle - azuredalle - stability -- wikipedia -- nominatim -- yahoo -- arxiv -- pubmed - stablediffusion -- webscraper -- jina +- cogview +- comfyui +- 
getimgai +- siliconflow +- spark +- stepfun +- xinference +- alphavantage +- yahoo +- openweather +- gaode - aippt +- chart - youtube +- did +- dingtalk +- discord +- feishu +- feishu_base +- feishu_document +- feishu_message +- feishu_wiki +- feishu_task +- feishu_calendar +- feishu_spreadsheet +- lark_base +- lark_document +- lark_message_and_group +- lark_wiki +- lark_task +- lark_calendar +- lark_spreadsheet +- slack +- twilio +- wecom +- wikipedia - code - wolframalpha - maths - github -- chart +- gitlab - time - vectorizer -- gaode -- wecom - qrcode -- dingtalk -- feishu -- feishu_base -- slack - tianditu +- aliyuque +- google_translate +- hap +- json_process +- judge0ce +- novitaai +- onebot +- regex +- trello +- vanna diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index ae80ad2114cce0..d99314e33a3204 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -1,4 +1,3 @@ - from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( @@ -18,85 +17,69 @@ class ApiToolProviderController(ToolProviderController): provider_id: str @staticmethod - def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController': + def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": credentials_schema = { - 'auth_type': ToolProviderCredentials( - name='auth_type', + "auth_type": ToolProviderCredentials( + name="auth_type", required=True, type=ToolProviderCredentials.CredentialsType.SELECT, options=[ - ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')), - ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key')) + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")), ], - default='none', - help=I18nObject( - en_US='The auth type of the api provider', - zh_Hans='api provider 的认证类型' - ) + default="none", + help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"), ) } if auth_type == ApiProviderAuthType.API_KEY: credentials_schema = { **credentials_schema, - 'api_key_header': ToolProviderCredentials( - name='api_key_header', + "api_key_header": ToolProviderCredentials( + name="api_key_header", required=False, - default='api_key', + default="api_key", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, - help=I18nObject( - en_US='The header name of the api key', - zh_Hans='携带 api key 的 header 名称' - ) + help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"), ), - 'api_key_value': ToolProviderCredentials( - name='api_key_value', + "api_key_value": ToolProviderCredentials( + name="api_key_value", required=True, type=ToolProviderCredentials.CredentialsType.SECRET_INPUT, - help=I18nObject( - en_US='The api key', - zh_Hans='api key的值' - ) + help=I18nObject(en_US="The api key", zh_Hans="api key的值"), ), - 'api_key_header_prefix': ToolProviderCredentials( - name='api_key_header_prefix', + "api_key_header_prefix": ToolProviderCredentials( + name="api_key_header_prefix", required=False, - default='basic', + default="basic", type=ToolProviderCredentials.CredentialsType.SELECT, - help=I18nObject( - en_US='The prefix of the api key header', - zh_Hans='api key header 
的前缀' - ), + help=I18nObject(en_US="The prefix of the api key header", zh_Hans="api key header 的前缀"), options=[ - ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), - ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), - ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) - ] - ) + ToolCredentialsOption(value="basic", label=I18nObject(en_US="Basic", zh_Hans="Basic")), + ToolCredentialsOption(value="bearer", label=I18nObject(en_US="Bearer", zh_Hans="Bearer")), + ToolCredentialsOption(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")), + ], + ), } elif auth_type == ApiProviderAuthType.NONE: pass else: - raise ValueError(f'invalid auth type {auth_type}') - - user_name = db_provider.user.name if db_provider.user_id else '' - - return ApiToolProviderController(**{ - 'identity': { - 'author': user_name, - 'name': db_provider.name, - 'label': { - 'en_US': db_provider.name, - 'zh_Hans': db_provider.name - }, - 'description': { - 'en_US': db_provider.description, - 'zh_Hans': db_provider.description + raise ValueError(f"invalid auth type {auth_type}") + + user_name = db_provider.user.name if db_provider.user_id else "" + + return ApiToolProviderController( + **{ + "identity": { + "author": user_name, + "name": db_provider.name, + "label": {"en_US": db_provider.name, "zh_Hans": db_provider.name}, + "description": {"en_US": db_provider.description, "zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'icon': db_provider.icon, - }, - 'credentials_schema': credentials_schema, - 'provider_id': db_provider.id or '', - }) + "credentials_schema": credentials_schema, + "provider_id": db_provider.id or "", + } + ) @property def provider_type(self) -> ToolProviderType: @@ -104,39 +87,35 @@ def provider_type(self) -> ToolProviderType: def _parse_tool_bundle(self, tool_bundle: ApiToolBundle) -> ApiTool: """ - parse tool bundle to tool + parse tool bundle to tool - :param tool_bundle: the tool bundle - :return: the tool + :param tool_bundle: the tool bundle + :return: the tool """ - return ApiTool(**{ - 'api_bundle': tool_bundle, - 'identity' : { - 'author': tool_bundle.author, - 'name': tool_bundle.operation_id, - 'label': { - 'en_US': tool_bundle.operation_id, - 'zh_Hans': tool_bundle.operation_id + return ApiTool( + **{ + "api_bundle": tool_bundle, + "identity": { + "author": tool_bundle.author, + "name": tool_bundle.operation_id, + "label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id}, + "icon": self.identity.icon, + "provider": self.provider_id, }, - 'icon': self.identity.icon, - 'provider': self.provider_id, - }, - 'description': { - 'human': { - 'en_US': tool_bundle.summary or '', - 'zh_Hans': tool_bundle.summary or '' + "description": { + "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, + "llm": tool_bundle.summary or "", }, - 'llm': tool_bundle.summary or '' - }, - 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], - }) + "parameters": tool_bundle.parameters or [], + } + ) def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: """ - load bundled tools + load bundled tools - :param tools: the bundled tools - :return: the tools + :param tools: the bundled tools + :return: the tools """ self.tools = [self._parse_tool_bundle(tool) for tool in tools] @@ -144,22 +123,23 @@ def load_bundled_tools(self, tools: list[ApiToolBundle]) -> list[ApiTool]: def 
get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools - + tools: list[Tool] = [] # get tenant api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == self.identity.name - ).all() + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name) + .all() + ) if db_providers and len(db_providers) != 0: for db_provider in db_providers: @@ -167,16 +147,16 @@ def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]: assistant_tool = self._parse_tool_bundle(tool) assistant_tool.is_team_authorization = True tools.append(assistant_tool) - + self.tools = tools return tools - + def get_tool(self, tool_name: str) -> ApiTool: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: self.get_tools() @@ -185,4 +165,4 @@ def get_tool(self, tool_name: str) -> ApiTool: if tool.identity.name == tool_name: return tool - raise ValueError(f'tool {tool_name} not found') \ No newline at end of file + raise ValueError(f"tool {tool_name} not found") diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 2d472e0a93c866..09f328cd1fe65f 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -11,11 +11,12 @@ logger = logging.getLogger(__name__) + class AppToolProviderEntity(ToolProviderController): @property def provider_type(self) -> ToolProviderType: return ToolProviderType.APP - + def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None: pass @@ -23,9 +24,13 @@ def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) - pass def get_tools(self, user_id: str) -> list[Tool]: - db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter( - PublishedAppTool.user_id == user_id, - ).all() + db_tools: list[PublishedAppTool] = ( + db.session.query(PublishedAppTool) + .filter( + PublishedAppTool.user_id == user_id, + ) + .all() + ) if not db_tools or len(db_tools) == 0: return [] @@ -34,23 +39,17 @@ def get_tools(self, user_id: str) -> list[Tool]: for db_tool in db_tools: tool = { - 'identity': { - 'author': db_tool.author, - 'name': db_tool.tool_name, - 'label': { - 'en_US': db_tool.tool_name, - 'zh_Hans': db_tool.tool_name - }, - 'icon': '' + "identity": { + "author": db_tool.author, + "name": db_tool.tool_name, + "label": {"en_US": db_tool.tool_name, "zh_Hans": db_tool.tool_name}, + "icon": "", }, - 'description': { - 'human': { - 'en_US': db_tool.description_i18n.en_US, - 'zh_Hans': db_tool.description_i18n.zh_Hans - }, - 'llm': db_tool.llm_description + "description": { + "human": {"en_US": db_tool.description_i18n.en_US, "zh_Hans": db_tool.description_i18n.zh_Hans}, + "llm": db_tool.llm_description, }, - 'parameters': [] + "parameters": [], } # get app from db app: App = db_tool.app @@ -64,52 +63,41 @@ def get_tools(self, user_id: str) -> list[Tool]: for input_form in user_input_form_list: # get type 
form_type = input_form.keys()[0] - default = input_form[form_type]['default'] - required = input_form[form_type]['required'] - label = input_form[form_type]['label'] - variable_name = input_form[form_type]['variable_name'] - options = input_form[form_type].get('options', []) - if form_type == 'paragraph' or form_type == 'text-input': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.STRING, - required=required, - default=default - )) - elif form_type == 'select': - tool['parameters'].append(ToolParameter( - name=variable_name, - label=I18nObject( - en_US=label, - zh_Hans=label - ), - human_description=I18nObject( - en_US=label, - zh_Hans=label - ), - llm_description=label, - form=ToolParameter.ToolParameterForm.FORM, - type=ToolParameter.ToolParameterType.SELECT, - required=required, - default=default, - options=[ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in options] - )) + default = input_form[form_type]["default"] + required = input_form[form_type]["required"] + label = input_form[form_type]["label"] + variable_name = input_form[form_type]["variable_name"] + options = input_form[form_type].get("options", []) + if form_type in {"paragraph", "text-input"}: + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.STRING, + required=required, + default=default, + ) + ) + elif form_type == "select": + tool["parameters"].append( + ToolParameter( + name=variable_name, + label=I18nObject(en_US=label, zh_Hans=label), + human_description=I18nObject(en_US=label, zh_Hans=label), + llm_description=label, + form=ToolParameter.ToolParameterForm.FORM, + type=ToolParameter.ToolParameterType.SELECT, + required=required, + default=default, + options=[ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in options + ], + ) + ) tools.append(Tool(**tool)) - return tools \ No newline at end of file + return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index ae806eaff4a032..5c10f72fdaed01 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,6 +1,6 @@ import os.path -from core.helper.position_helper import get_position_map, sort_by_position_map +from core.helper.position_helper import get_tool_position_map, sort_by_position_map from core.tools.entities.api_entities import UserToolProvider @@ -10,11 +10,11 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) + cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), "..")) def name_func(provider: UserToolProvider) -> str: return provider.name sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers \ No newline at end of file + return sorted_providers diff --git a/api/core/tools/provider/builtin/aippt/aippt.py 
b/api/core/tools/provider/builtin/aippt/aippt.py index 25133c51df4ff3..e0cbbd2992a515 100644 --- a/api/core/tools/provider/builtin/aippt/aippt.py +++ b/api/core/tools/provider/builtin/aippt/aippt.py @@ -6,6 +6,6 @@ class AIPPTProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - AIPPTGenerateTool._get_api_token(credentials, user_id='__dify_system__') + AIPPTGenerateTool._get_api_token(credentials, user_id="__dify_system__") except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 8d6883a3b114ac..38123f125ae974 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -4,7 +4,7 @@ from json import loads as json_loads from threading import Lock from time import sleep, time -from typing import Any, Optional +from typing import Any from httpx import get, post from requests import get as requests_get @@ -15,27 +15,27 @@ from core.tools.tool.builtin_tool import BuiltinTool -class AIPPTGenerateTool(BuiltinTool): +class AIPPTGenerateToolAdapter: """ A tool for generating a ppt """ - _api_base_url = URL('https://co.aippt.cn/api') + _api_base_url = URL("https://co.aippt.cn/api") _api_token_cache = {} - _api_token_cache_lock:Optional[Lock] = None _style_cache = {} - _style_cache_lock:Optional[Lock] = None + + _api_token_cache_lock = Lock() + _style_cache_lock = Lock() _task = {} _task_type_map = { - 'auto': 1, - 'markdown': 7, + "auto": 1, + "markdown": 7, } + _tool: BuiltinTool - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - self._api_token_cache_lock = Lock() - self._style_cache_lock = Lock() + def __init__(self, tool: BuiltinTool = None): + self._tool = tool def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ @@ -46,67 +46,58 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe tool_parameters (dict[str, Any]): The parameters for the tool Returns: - ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. + ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, + which can be a single message or a list of messages. 
""" - title = tool_parameters.get('title', '') + title = tool_parameters.get("title", "") if not title: - return self.create_text_message('Please provide a title for the ppt') - - model = tool_parameters.get('model', 'aippt') + return self._tool.create_text_message("Please provide a title for the ppt") + + model = tool_parameters.get("model", "aippt") if not model: - return self.create_text_message('Please provide a model for the ppt') - - outline = tool_parameters.get('outline', '') + return self._tool.create_text_message("Please provide a model for the ppt") + + outline = tool_parameters.get("outline", "") # create task task_id = self._create_task( - type=self._task_type_map['auto' if not outline else 'markdown'], + type=self._task_type_map["auto" if not outline else "markdown"], title=title, content=outline, - user_id=user_id + user_id=user_id, ) # get suit - color = tool_parameters.get('color') - style = tool_parameters.get('style') + color: str = tool_parameters.get("color") + style: str = tool_parameters.get("style") - if color == '__default__': - color_id = '' + if color == "__default__": + color_id = "" else: - color_id = int(color.split('-')[1]) + color_id = int(color.split("-")[1]) - if style == '__default__': - style_id = '' + if style == "__default__": + style_id = "" else: - style_id = int(style.split('-')[1]) + style_id = int(style.split("-")[1]) suit_id = self._get_suit(style_id=style_id, colour_id=color_id) # generate outline if not outline: - self._generate_outline( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_outline(task_id=task_id, model=model, user_id=user_id) # generate content - self._generate_content( - task_id=task_id, - model=model, - user_id=user_id - ) + self._generate_content(task_id=task_id, model=model, user_id=user_id) # generate ppt - _, ppt_url = self._generate_ppt( - task_id=task_id, - suit_id=suit_id, - user_id=user_id - ) + _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id) - return self.create_text_message('''the ppt has been created successfully,''' - f'''the ppt url is {ppt_url}''' - '''please give the ppt url to user and direct user to download it.''') + return self._tool.create_text_message( + """the ppt has been created successfully,""" + f"""the ppt url is {ppt_url} .""" + """please give the ppt url to user and direct user to download it.""" + ) def _create_task(self, type: int, title: str, content: str, user_id: str) -> str: """ @@ -119,129 +110,121 @@ def _create_task(self, type: int, title: str, content: str, user_id: str) -> str :return: the task ID """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'), + str(self._api_base_url / "ai" / "chat" / "v2" / "task"), headers=headers, - files={ - 'type': ('', str(type)), - 'title': ('', title), - 'content': ('', content) - } + files={"type": ("", str(type)), "title": ("", title), "content": ("", content)}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 
0: raise Exception(f'Failed to create task: {response.get("msg")}') - return response.get('data', {}).get('id') - + return response.get("data", {}).get("id") + def _generate_outline(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "outline" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "outline" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - outline = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + outline = "" + for chunk in response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - outline += data.get('content', '') + outline += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate outline: {data}') - + elif event in {"error", "filter"}: + raise Exception(f"Failed to generate outline: {data}") + return outline - + def _generate_content(self, task_id: str, model: str, user_id: str) -> str: - api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \ - self._api_base_url / 'ai' / 'chat' / 'wx' / 'content' - api_url %= {'task_id': task_id} + api_url = ( + self._api_base_url / "ai" / "chat" / "content" + if model == "aippt" + else self._api_base_url / "ai" / "chat" / "wx" / "content" + ) + api_url %= {"task_id": task_id} headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } - response = requests_get( - url=api_url, - headers=headers, - stream=True, - timeout=(10, 60) - ) + response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - - if model == 'aippt': - content = '' - for chunk in response.iter_lines(delimiter=b'\n\n'): + raise Exception(f"Failed to connect to aippt: {response.text}") + + if model == "aippt": + content = "" + for chunk in 
response.iter_lines(delimiter=b"\n\n"): if not chunk: continue - - event = '' - lines = chunk.decode('utf-8').split('\n') + + event = "" + lines = chunk.decode("utf-8").split("\n") for line in lines: - if line.startswith('event:'): + if line.startswith("event:"): event = line[6:] - elif line.startswith('data:'): + elif line.startswith("data:"): data = line[5:] - if event == 'message': + if event == "message": try: data = json_loads(data) - content += data.get('content', '') + content += data.get("content", "") except Exception as e: pass - elif event == 'close': + elif event == "close": break - elif event == 'error' or event == 'filter': - raise Exception(f'Failed to generate content: {data}') - + elif event in {"error", "filter"}: + raise Exception(f"Failed to generate content: {data}") + return content - elif model == 'wenxin': + elif model == "wenxin": response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate content: {response.get("msg")}') - - return response.get('data', '') - - return '' + + return response.get("data", "") + + return "" def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: """ @@ -252,83 +235,74 @@ def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: :return: the cover url of the ppt and the ppt url """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-channel": "", + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = post( - str(self._api_base_url / 'design' / 'v2' / 'save'), + str(self._api_base_url / "design" / "v2" / "save"), headers=headers, - data={ - 'task_id': task_id, - 'template_id': suit_id - } + data={"task_id": task_id, "template_id": suit_id}, + timeout=(10, 60), ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - id = response.get('data', {}).get('id') - cover_url = response.get('data', {}).get('cover_url') + + id = response.get("data", {}).get("id") + cover_url = response.get("data", {}).get("cover_url") response = post( - str(self._api_base_url / 'download' / 'export' / 'file'), + str(self._api_base_url / "download" / "export" / "file"), headers=headers, - data={ - 'id': id, - 'format': 'ppt', - 'files_to_zip': False, - 'edit': True - } + data={"id": id, "format": "ppt", "files_to_zip": False, "edit": True}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - export_code = response.get('data') + + export_code = response.get("data") if not export_code: - raise Exception('Failed to generate ppt, the export code is empty') - + raise Exception("Failed to generate ppt, the export code is empty") + current_iteration = 0 while current_iteration < 50: # get ppt url response = post( - str(self._api_base_url / 'download' / 'export' / 
'file' / 'result'), + str(self._api_base_url / "download" / "export" / "file" / "result"), headers=headers, - data={ - 'task_key': export_code - } + data={"task_key": export_code}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to generate ppt: {response.get("msg")}') - - if response.get('msg') == '导出中': + + if response.get("msg") == "导出中": current_iteration += 1 sleep(2) continue - - ppt_url = response.get('data', []) + + ppt_url = response.get("data", []) if len(ppt_url) == 0: - raise Exception('Failed to generate ppt, the ppt url is empty') - + raise Exception("Failed to generate ppt, the ppt url is empty") + return cover_url, ppt_url[0] - - raise Exception('Failed to generate ppt, the export is timeout') - + + raise Exception("Failed to generate ppt, the export is timeout") + @classmethod def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: """ @@ -337,65 +311,55 @@ def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: :param credentials: the credentials :return: the API token """ - access_key = credentials['aippt_access_key'] - secret_key = credentials['aippt_secret_key'] + access_key = credentials["aippt_access_key"] + secret_key = credentials["aippt_secret_key"] - cache_key = f'{access_key}#@#{user_id}' + cache_key = f"{access_key}#@#{user_id}" with cls._api_token_cache_lock: # clear expired tokens now = time() for key in list(cls._api_token_cache.keys()): - if cls._api_token_cache[key]['expire'] < now: + if cls._api_token_cache[key]["expire"] < now: del cls._api_token_cache[key] if cache_key in cls._api_token_cache: - return cls._api_token_cache[cache_key]['token'] - + return cls._api_token_cache[cache_key]["token"] + # get token headers = { - 'x-api-key': access_key, - 'x-timestamp': str(int(now)), - 'x-signature': cls._calculate_sign(access_key, secret_key, int(now)) + "x-api-key": access_key, + "x-timestamp": str(int(now)), + "x-signature": cls._calculate_sign(access_key, secret_key, int(now)), } - param = { - 'uid': user_id, - 'channel': '' - } + param = {"uid": user_id, "channel": ""} - response = get( - str(cls._api_base_url / 'grant' / 'token'), - params=param, - headers=headers - ) + response = get(str(cls._api_base_url / "grant" / "token"), params=param, headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') + raise Exception(f"Failed to connect to aippt: {response.text}") response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - token = response.get('data', {}).get('token') - expire = response.get('data', {}).get('time_expire') + + token = response.get("data", {}).get("token") + expire = response.get("data", {}).get("time_expire") with cls._api_token_cache_lock: - cls._api_token_cache[cache_key] = { - 'token': token, - 'expire': now + expire - } + cls._api_token_cache[cache_key] = {"token": token, "expire": now + expire} return token - @classmethod - def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: + @staticmethod + def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode('utf-8'), - 
msg=f'GET@/api/grant/token/@{timestamp}'.encode(), - digestmod=sha1 + key=secret_key.encode("utf-8"), + msg=f"GET@/api/grant/token/@{timestamp}".encode(), + digestmod=sha1, ).digest() - ).decode('utf-8') + ).decode("utf-8") @classmethod def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: @@ -408,47 +372,46 @@ def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[di # clear expired styles now = time() for key in list(cls._style_cache.keys()): - if cls._style_cache[key]['expire'] < now: + if cls._style_cache[key]["expire"] < now: del cls._style_cache[key] key = f'{credentials["aippt_access_key"]}#@#{user_id}' if key in cls._style_cache: - return cls._style_cache[key]['colors'], cls._style_cache[key]['styles'] + return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"] headers = { - 'x-channel': '', - 'x-api-key': credentials['aippt_access_key'], - 'x-token': cls._get_api_token(credentials=credentials, user_id=user_id) + "x-channel": "", + "x-api-key": credentials["aippt_access_key"], + "x-token": cls._get_api_token(credentials=credentials, user_id=user_id), } - response = get( - str(cls._api_base_url / 'template_component' / 'suit' / 'select'), - headers=headers - ) + response = get(str(cls._api_base_url / "template_component" / "suit" / "select"), headers=headers) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - colors = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('name'), - 'en_name': item.get('en_name', item.get('name')), - } for item in response.get('data', {}).get('colour') or []] - styles = [{ - 'id': f'id-{item.get("id")}', - 'name': item.get('title'), - } for item in response.get('data', {}).get('suit_style') or []] - with cls._style_cache_lock: - cls._style_cache[key] = { - 'colors': colors, - 'styles': styles, - 'expire': now + 60 * 60 + colors = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("name"), + "en_name": item.get("en_name", item.get("name")), + } + for item in response.get("data", {}).get("colour") or [] + ] + styles = [ + { + "id": f'id-{item.get("id")}', + "name": item.get("title"), } + for item in response.get("data", {}).get("suit_style") or [] + ] + + with cls._style_cache_lock: + cls._style_cache[key] = {"colors": colors, "styles": styles, "expire": now + 60 * 60} return colors, styles @@ -459,44 +422,41 @@ def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: :param credentials: the credentials :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ - if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'): - raise Exception('Please provide aippt credentials') + if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get( + "aippt_secret_key" + ): + raise Exception("Please provide aippt credentials") + + return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id) - return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) - def _get_suit(self, style_id: int, colour_id: int) -> int: """ Get suit """ headers = { - 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': 
self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__') + "x-channel": "", + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"), } response = get( - str(self._api_base_url / 'template_component' / 'suit' / 'search'), + str(self._api_base_url / "template_component" / "suit" / "search"), headers=headers, - params={ - 'style_id': style_id, - 'colour_id': colour_id, - 'page': 1, - 'page_size': 1 - } + params={"style_id": style_id, "colour_id": colour_id, "page": 1, "page_size": 1}, ) if response.status_code != 200: - raise Exception(f'Failed to connect to aippt: {response.text}') - + raise Exception(f"Failed to connect to aippt: {response.text}") + response = response.json() - if response.get('code') != 0: + if response.get("code") != 0: raise Exception(f'Failed to connect to aippt: {response.get("msg")}') - - if len(response.get('data', {}).get('list') or []) > 0: - return response.get('data', {}).get('list')[0].get('id') - - raise Exception('Failed to get suit, the suit does not exist, please check the style and color') - + + if len(response.get("data", {}).get("list") or []) > 0: + return response.get("data", {}).get("list")[0].get("id") + + raise Exception("Failed to get suit, the suit does not exist, please check the style and color") + def get_runtime_parameters(self) -> list[ToolParameter]: """ Get runtime parameters @@ -504,43 +464,55 @@ def get_runtime_parameters(self) -> list[ToolParameter]: Override this method to add runtime parameters to the tool. """ try: - colors, styles = self.get_styles(user_id='__dify_system__') + colors, styles = self.get_styles(user_id="__dify_system__") except Exception as e: - colors, styles = [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ], [ - {'id': '-1', 'name': '__default__', 'en_name': '__default__'} - ] + colors, styles = ( + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + [{"id": "-1", "name": "__default__", "en_name": "__default__"}], + ) return [ ToolParameter( - name='color', - label=I18nObject(zh_Hans='颜色', en_US='Color'), - human_description=I18nObject(zh_Hans='颜色', en_US='Color'), + name="color", + label=I18nObject(zh_Hans="颜色", en_US="Color"), + human_description=I18nObject(zh_Hans="颜色", en_US="Color"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=colors[0]['id'], + default=colors[0]["id"], options=[ ToolParameterOption( - value=color['id'], - label=I18nObject(zh_Hans=color['name'], en_US=color['en_name']) - ) for color in colors - ] + value=color["id"], label=I18nObject(zh_Hans=color["name"], en_US=color["en_name"]) + ) + for color in colors + ], ), ToolParameter( - name='style', - label=I18nObject(zh_Hans='风格', en_US='Style'), - human_description=I18nObject(zh_Hans='风格', en_US='Style'), + name="style", + label=I18nObject(zh_Hans="风格", en_US="Style"), + human_description=I18nObject(zh_Hans="风格", en_US="Style"), type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, required=False, - default=styles[0]['id'], + default=styles[0]["id"], options=[ - ToolParameterOption( - value=style['id'], - label=I18nObject(zh_Hans=style['name'], en_US=style['name']) - ) for style in styles - ] + ToolParameterOption(value=style["id"], label=I18nObject(zh_Hans=style["name"], en_US=style["name"])) + for style in styles + ], ), - ] \ No newline at end of file + ] + + +class 
AIPPTGenerateTool(BuiltinTool): + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters) + + def get_runtime_parameters(self) -> list[ToolParameter]: + return AIPPTGenerateToolAdapter(self).get_runtime_parameters() + + @classmethod + def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: + return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id) diff --git a/api/core/tools/provider/builtin/aliyuque/_assets/icon.svg b/api/core/tools/provider/builtin/aliyuque/_assets/icon.svg new file mode 100644 index 00000000000000..82b23ebbc66e68 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/_assets/icon.svg @@ -0,0 +1,32 @@ + [32-line SVG icon asset; the vector markup was stripped during text extraction and only the embedded title "绿 lgo" survives] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aliyuque/aliyuque.py b/api/core/tools/provider/builtin/aliyuque/aliyuque.py new file mode 100644 index 00000000000000..56eac1a4b570cf --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/aliyuque.py @@ -0,0 +1,19 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class AliYuqueProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + token = credentials.get("token") + if not token: + raise ToolProviderCredentialValidationError("token is required") + + try: + resp = AliYuqueTool.auth(token) + if resp and resp.get("data", {}).get("id"): + return + + raise ToolProviderCredentialValidationError(resp) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/aliyuque/aliyuque.yaml b/api/core/tools/provider/builtin/aliyuque/aliyuque.yaml new file mode 100644 index 00000000000000..73d39aa96cfd17 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/aliyuque.yaml @@ -0,0 +1,29 @@ +identity: + author: 佐井 + name: aliyuque + label: + en_US: yuque + zh_Hans: 语雀 + pt_BR: yuque + description: + en_US: Yuque, https://www.yuque.com. + zh_Hans: 语雀,https://www.yuque.com。 + pt_BR: Yuque, https://www.yuque.com.
+ icon: icon.svg + tags: + - productivity + - search +credentials_for_provider: + token: + type: secret-input + required: true + label: + en_US: Yuque Team Token + zh_Hans: 语雀团队Token + placeholder: + en_US: Please input your Yuque team token + zh_Hans: 请输入你的语雀团队Token + help: + en_US: Get Alibaba Yuque team token + zh_Hans: 先获取语雀团队Token + url: https://www.yuque.com/settings/tokens diff --git a/api/core/tools/provider/builtin/aliyuque/tools/base.py b/api/core/tools/provider/builtin/aliyuque/tools/base.py new file mode 100644 index 00000000000000..edfb9fea8ec453 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/base.py @@ -0,0 +1,42 @@ +from typing import Any + +import requests + + +class AliYuqueTool: + # yuque service url + server_url = "https://www.yuque.com" + + @staticmethod + def auth(token): + session = requests.Session() + session.headers.update({"Accept": "application/json", "X-Auth-Token": token}) + login = session.request("GET", AliYuqueTool.server_url + "/api/v2/user") + login.raise_for_status() + resp = login.json() + return resp + + def request(self, method: str, token, tool_parameters: dict[str, Any], path: str) -> str: + if not token: + raise Exception("token is required") + session = requests.Session() + session.headers.update({"accept": "application/json", "X-Auth-Token": token}) + new_params = {**tool_parameters} + + replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path} + + for key, value in replacements.items(): + path = path.replace(f"{{{key}}}", str(value)) + del new_params[key] + + if method.upper() in {"POST", "PUT"}: + session.headers.update( + { + "Content-Type": "application/json", + } + ) + response = session.request(method.upper(), self.server_url + path, json=new_params) + else: + response = session.request(method, self.server_url + path, params=new_params) + response.raise_for_status() + return response.text diff --git a/api/core/tools/provider/builtin/aliyuque/tools/create_document.py b/api/core/tools/provider/builtin/aliyuque/tools/create_document.py new file mode 100644 index 00000000000000..01080fd1d57f4d --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/create_document.py @@ -0,0 +1,15 @@ +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueCreateDocumentTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message(self.request("POST", token, tool_parameters, "/api/v2/repos/{book_id}/docs")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml new file mode 100644 index 00000000000000..6ac8ae6696f330 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml @@ -0,0 +1,99 @@ +identity: + name: aliyuque_create_document + author: 佐井 + label: + en_US: Create Document + zh_Hans: 创建文档 + icon: icon.svg +description: + human: + en_US: Creates a new document within a knowledge base without automatic addition to the table of contents. Requires a subsequent call to the "knowledge base directory update API". Supports setting visibility, format, and content. 
# English description of the API + zh_Hans: 在知识库中创建新文档,但不会自动加入目录,需额外调用“知识库目录更新接口”。允许设置公开性、格式及正文内容。 + llm: Creates docs in a KB. + +parameters: + - name: book_id + type: string + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库ID + human_description: + en_US: The unique identifier of the knowledge base where the document will be created. + zh_Hans: 文档将被创建的知识库的唯一标识。 + llm_description: ID of the target knowledge base. + + - name: title + type: string + required: false + form: llm + label: + en_US: Title + zh_Hans: 标题 + human_description: + en_US: The title of the document, defaults to 'Untitled' if not provided. + zh_Hans: 文档标题,如未提供则默认为'无标题'。 + llm_description: Title of the document, defaults to 'Untitled'. + + - name: public + type: select + required: false + form: llm + options: + - value: 0 + label: + en_US: Private + zh_Hans: 私密 + - value: 1 + label: + en_US: Public + zh_Hans: 公开 + - value: 2 + label: + en_US: Enterprise-only + zh_Hans: 企业内公开 + label: + en_US: Visibility + zh_Hans: 公开性 + human_description: + en_US: Document visibility (0 Private, 1 Public, 2 Enterprise-only). + zh_Hans: 文档可见性(0 私密, 1 公开, 2 企业内公开)。 + llm_description: Doc visibility options, 0-private, 1-public, 2-enterprise. + + - name: format + type: select + required: false + form: llm + options: + - value: markdown + label: + en_US: markdown + zh_Hans: markdown + - value: html + label: + en_US: html + zh_Hans: html + - value: lake + label: + en_US: lake + zh_Hans: lake + label: + en_US: Content Format + zh_Hans: 内容格式 + human_description: + en_US: Format of the document content (markdown, HTML, Lake). + zh_Hans: 文档内容格式(markdown, HTML, Lake)。 + llm_description: Content format choices, markdown, HTML, Lake. + + - name: body + type: string + required: true + form: llm + label: + en_US: Body Content + zh_Hans: 正文内容 + human_description: + en_US: The actual content of the document. + zh_Hans: 文档的实际内容。 + llm_description: Content of the document. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py new file mode 100644 index 00000000000000..84237cec30c563 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py @@ -0,0 +1,17 @@ +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueDeleteDocumentTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message( + self.request("DELETE", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") + ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml new file mode 100644 index 00000000000000..dddd62d3048c35 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml @@ -0,0 +1,37 @@ +identity: + name: aliyuque_delete_document + author: 佐井 + label: + en_US: Delete Document + zh_Hans: 删除文档 + icon: icon.svg +description: + human: + en_US: Delete a document by its ID or path. + zh_Hans: 根据id删除文档 + llm: Delete document.
+ +parameters: + - name: book_id + type: string + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库ID + human_description: + en_US: The unique identifier of the knowledge base from which the document will be deleted. + zh_Hans: 文档所在知识库的唯一标识。 + llm_description: ID of the target knowledge base. + + - name: id + type: string + required: true + form: llm + label: + en_US: Document ID or Path + zh_Hans: 文档 ID 或路径 + human_description: + en_US: Document ID or path. + zh_Hans: 文档 ID 或路径。 + llm_description: Document ID or path. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py new file mode 100644 index 00000000000000..c23d30059a8424 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py @@ -0,0 +1,17 @@ +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueDescribeBookIndexPageTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message( + self.request("GET", token, tool_parameters, "/api/v2/repos/{group_login}/{book_slug}/index_page") + ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml new file mode 100644 index 00000000000000..5e490725d18882 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml @@ -0,0 +1,38 @@ +identity: + name: aliyuque_describe_book_index_page + author: 佐井 + label: + en_US: Get Repo Index Page + zh_Hans: 获取知识库首页 + icon: icon.svg + +description: + human: + en_US: Retrieves the homepage of a knowledge base within a group, supporting both book ID and group login with book slug access. + zh_Hans: 获取团队中知识库的首页信息,可通过书籍ID或团队登录名与书籍路径访问。 + llm: Fetches the knowledge base homepage using group and book identifiers with support for alternate access paths. + +parameters: + - name: group_login + type: string + required: true + form: llm + label: + en_US: Group Login + zh_Hans: 团队登录名 + human_description: + en_US: The login name of the group that owns the knowledge base. + zh_Hans: 拥有该知识库的团队登录名。 + llm_description: Team login identifier for the knowledge base owner. + + - name: book_slug + type: string + required: true + form: llm + label: + en_US: Book Slug + zh_Hans: 知识库路径 + human_description: + en_US: The unique slug representing the path of the knowledge base. + zh_Hans: 知识库的唯一路径标识。 + llm_description: Unique path identifier for the knowledge base.
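For context on the endpoints above: the request helper in tools/base.py routes each tool parameter either into the URL or into the query/body. Any parameter whose name appears as a {placeholder} in the path template is substituted into the path and dropped from the remaining parameters. A minimal standalone sketch of that routing (fill_path is an illustrative name, not part of this PR):

from typing import Any


def fill_path(path: str, params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    # Substitute {name} placeholders with matching parameter values and
    # return the filled path plus the parameters that were not consumed.
    remaining = {**params}
    for key in [k for k in remaining if f"{{{k}}}" in path]:
        path = path.replace(f"{{{key}}}", str(remaining.pop(key)))
    return path, remaining


# fill_path("/api/v2/repos/{book_id}/docs/{id}", {"book_id": "42", "id": "intro", "raw": 1})
# returns ("/api/v2/repos/42/docs/intro", {"raw": 1})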
diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py new file mode 100644 index 00000000000000..36f8c10d6fd79d --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py @@ -0,0 +1,15 @@ +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message(self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/toc")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml new file mode 100644 index 00000000000000..0a481b59ebedad --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml @@ -0,0 +1,25 @@ +identity: + name: aliyuque_describe_book_table_of_contents + author: 佐井 + label: + en_US: Get Book's Table of Contents + zh_Hans: 获取知识库的目录 + icon: icon.svg +description: + human: + en_US: Get Book's Table of Contents. + zh_Hans: 获取知识库的目录。 + llm: Get Book's Table of Contents. + +parameters: + - name: book_id + type: string + required: true + form: llm + label: + en_US: Book ID + zh_Hans: 知识库 ID + human_description: + en_US: Book ID. + zh_Hans: 知识库 ID。 + llm_description: Book ID.
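For context, the credential check in AliYuqueProvider above boils down to one authenticated GET against /api/v2/user via AliYuqueTool.auth: a token is accepted when the response carries data.id. A hedged standalone sketch of that check, using the same requests library (the function name is illustrative, not part of the PR):

import requests


def yuque_token_is_valid(token: str) -> bool:
    # Mirrors AliYuqueTool.auth: GET /api/v2/user with the X-Auth-Token header.
    resp = requests.get(
        "https://www.yuque.com/api/v2/user",
        headers={"Accept": "application/json", "X-Auth-Token": token},
        timeout=10,
    )
    resp.raise_for_status()
    return bool(resp.json().get("data", {}).get("id"))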
diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py new file mode 100644 index 00000000000000..a69bf121f7e5ae --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py @@ -0,0 +1,52 @@ +import json +from typing import Any, Union +from urllib.parse import urlparse + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + new_params = {**tool_parameters} + token = new_params.pop("token", None) + if not token or token.lower() == "none": + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + url = new_params.pop("url") + if not url or not url.startswith("http"): + raise Exception("url is not valid") + + parsed_url = urlparse(url) + path_parts = parsed_url.path.strip("/").split("/") + if len(path_parts) < 3: + raise Exception("url is not correct") + doc_id = path_parts[-1] + book_slug = path_parts[-2] + group_id = path_parts[-3] + + new_params["group_login"] = group_id + new_params["book_slug"] = book_slug + index_page = json.loads( + self.request("GET", token, new_params, "/api/v2/repos/{group_login}/{book_slug}/index_page") + ) + book_id = index_page.get("data", {}).get("book", {}).get("id") + if not book_id: + raise Exception(f"can not parse book_id from {index_page}") + + new_params["book_id"] = book_id + new_params["id"] = doc_id + data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}") + data = json.loads(data) + body_only = tool_parameters.get("body_only") or "" + if body_only.lower() == "true": + return self.create_text_message(data.get("data").get("body")) + else: + raw = data.get("data") + del raw["body_lake"] + del raw["body_html"] + return self.create_text_message(json.dumps(data)) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.yaml new file mode 100644 index 00000000000000..6116886a96b790 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.yaml @@ -0,0 +1,50 @@ +identity: + name: aliyuque_describe_document_content + author: 佐井 + label: + en_US: Fetch Document Content + zh_Hans: 获取文档内容 + icon: icon.svg + +description: + human: + en_US: Retrieves document content from Yuque based on the provided document URL, which can be a normal or shared link. + zh_Hans: 根据提供的语雀文档地址(支持正常链接或分享链接)获取文档内容。 + llm: Fetches Yuque document content given a URL. + +parameters: + - name: url + type: string + required: true + form: llm + label: + en_US: Document URL + zh_Hans: 文档地址 + human_description: + en_US: The URL of the document to retrieve content from, can be normal or shared. + zh_Hans: 需要获取内容的文档地址,可以是正常链接或分享链接。 + llm_description: URL of the Yuque document to fetch content. + + - name: body_only + type: string + required: false + form: llm + label: + en_US: return body content only + zh_Hans: 仅返回body内容 + human_description: + en_US: true:Body content only, false:Full response with metadata.
+ zh_Hans: true:仅返回body内容,不返回其他元数据,false:返回所有元数据。 + llm_description: true:Body content only, false:Full response with metadata. + + - name: token + type: secret-input + required: false + form: llm + label: + en_US: Yuque API Token + zh_Hans: 语雀接口Token + human_description: + en_US: The token for calling the Yuque API defaults to the Yuque token bound to the current tool if not provided. + zh_Hans: 调用语雀接口的token,如果不传则默认为当前工具绑定的语雀Token。 + llm_description: If the token for calling the Yuque API is not provided, it will default to the Yuque token bound to the current tool. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py new file mode 100644 index 00000000000000..7a45684bed0498 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py @@ -0,0 +1,17 @@ +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueDescribeDocumentsTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message( + self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") + ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml new file mode 100644 index 00000000000000..0b14c1afba684d --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml @@ -0,0 +1,38 @@ +identity: + name: aliyuque_describe_documents + author: 佐井 + label: + en_US: Get Doc Detail + zh_Hans: 获取文档详情 + icon: icon.svg + +description: + human: + en_US: Retrieves detailed information of a specific document identified by its ID or path within a knowledge base. + zh_Hans: 根据知识库ID和文档ID或路径获取文档详细信息。 + llm: Fetches detailed doc info using ID/path from a knowledge base; supports doc lookup in Yuque. + +parameters: + - name: book_id + type: string + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库 ID + human_description: + en_US: Identifier for the knowledge base where the document resides. + zh_Hans: 文档所属知识库的唯一标识。 + llm_description: ID of the knowledge base holding the document. + + - name: id + type: string + required: true + form: llm + label: + en_US: Document ID or Path + zh_Hans: 文档 ID 或路径 + human_description: + en_US: The unique identifier or path of the document to retrieve. + zh_Hans: 需要获取的文档的ID或其在知识库中的路径。 + llm_description: Unique doc ID or its path for retrieval. 
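The describe_document_content tool above resolves a document URL in two steps: it reads the last three path segments as group login, book slug and document id, then calls the index_page endpoint to recover the numeric book id before fetching the document itself. A standalone sketch of the parsing step (the helper name and example URL are illustrative):

from urllib.parse import urlparse


def split_doc_url(url: str) -> tuple[str, str, str]:
    # Last three path segments: group login, book slug, document id.
    if not url.startswith("http"):
        raise ValueError("url is not valid")
    parts = urlparse(url).path.strip("/").split("/")
    if len(parts) < 3:
        raise ValueError("url is not correct")
    return parts[-3], parts[-2], parts[-1]


# split_doc_url("https://www.yuque.com/mygroup/mybook/intro")
# returns ("mygroup", "mybook", "intro")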
diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py new file mode 100644 index 00000000000000..ca0a3909f80709 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py @@ -0,0 +1,21 @@ +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueUpdateBookTableOfContentsTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + + doc_ids = tool_parameters.get("doc_ids") + if doc_ids: + doc_ids = [int(doc_id.strip()) for doc_id in doc_ids.split(",")] + tool_parameters["doc_ids"] = doc_ids + + return self.create_text_message(self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/toc")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml new file mode 100644 index 00000000000000..f85970348b1f13 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml @@ -0,0 +1,222 @@ +identity: + name: aliyuque_update_book_table_of_contents + author: 佐井 + label: + en_US: Update Book's Table of Contents + zh_Hans: 更新知识库目录 + icon: icon.svg +description: + human: + en_US: Update Book's Table of Contents. + zh_Hans: 更新知识库目录。 + llm: Update Book's Table of Contents. + +parameters: + - name: book_id + type: string + required: true + form: llm + label: + en_US: Book ID + zh_Hans: 知识库 ID + human_description: + en_US: Book ID. + zh_Hans: 知识库 ID。 + llm_description: Book ID. + + - name: action + type: select + required: true + form: llm + options: + - value: appendNode + label: + en_US: appendNode + zh_Hans: appendNode + pt_BR: appendNode + - value: prependNode + label: + en_US: prependNode + zh_Hans: prependNode + pt_BR: prependNode + - value: editNode + label: + en_US: editNode + zh_Hans: editNode + pt_BR: editNode + - value: removeNode + label: + en_US: removeNode + zh_Hans: removeNode + pt_BR: removeNode + label: + en_US: Action Type + zh_Hans: 操作 + human_description: + en_US: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children). + zh_Hans: 操作,创建场景下不支持同级头插 prependNode,删除节点不会删除关联文档,删除节点时action_mode=sibling (删除当前节点), action_mode=child (删除当前节点及子节点) + llm_description: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children). + + + - name: action_mode + type: select + required: false + form: llm + options: + - value: sibling + label: + en_US: sibling + zh_Hans: 同级 + pt_BR: sibling + - value: child + label: + en_US: child + zh_Hans: 子级 + pt_BR: child + label: + en_US: Action Mode + zh_Hans: 操作模式 + human_description: + en_US: Operation mode (sibling:same level, child:child level).
+ zh_Hans: 操作模式 (sibling:同级, child:子级)。 + llm_description: Operation mode (sibling:same level, child:child level). + + - name: target_uuid + type: string + required: false + form: llm + label: + en_US: Target node UUID + zh_Hans: 目标节点 UUID + human_description: + en_US: Target node UUID, defaults to root node if left empty. + zh_Hans: 目标节点 UUID, 不填默认为根节点。 + llm_description: Target node UUID, defaults to root node if left empty. + + - name: node_uuid + type: string + required: false + form: llm + label: + en_US: Node UUID + zh_Hans: 操作节点 UUID + human_description: + en_US: Operation node UUID [required for move/update/delete]. + zh_Hans: 操作节点 UUID [移动/更新/删除必填]。 + llm_description: Operation node UUID [required for move/update/delete]. + + - name: doc_ids + type: string + required: false + form: llm + label: + en_US: Document IDs + zh_Hans: 文档id列表 + human_description: + en_US: Document IDs [required for creating documents], separate multiple IDs with ','. + zh_Hans: 文档 IDs [创建文档必填],多个用','分隔。 + llm_description: Document IDs [required for creating documents], separate multiple IDs with ','. + + + - name: type + type: select + required: false + form: llm + default: DOC + options: + - value: DOC + label: + en_US: DOC + zh_Hans: 文档 + pt_BR: DOC + - value: LINK + label: + en_US: LINK + zh_Hans: 链接 + pt_BR: LINK + - value: TITLE + label: + en_US: TITLE + zh_Hans: 分组 + pt_BR: TITLE + label: + en_US: Node type + zh_Hans: 节点类型 + human_description: + en_US: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group). + zh_Hans: 节点类型 [创建必填] (DOC:文档, LINK:外链, TITLE:分组)。 + llm_description: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group). + + - name: title + type: string + required: false + form: llm + label: + en_US: Node Name + zh_Hans: 节点名称 + human_description: + en_US: Node name [required for creating groups/external links]. + zh_Hans: 节点名称 [创建分组/外链必填]。 + llm_description: Node name [required for creating groups/external links]. + + - name: url + type: string + required: false + form: llm + label: + en_US: Node URL + zh_Hans: 节点URL + human_description: + en_US: Node URL [required for creating external links]. + zh_Hans: 节点 URL [创建外链必填]。 + llm_description: Node URL [required for creating external links]. + + + - name: open_window + type: select + required: false + form: llm + default: 0 + options: + - value: 0 + label: + en_US: Current Page + zh_Hans: 当前页打开 + pt_BR: Current Page + - value: 1 + label: + en_US: New Page + zh_Hans: 新窗口打开 + pt_BR: New Page + label: + en_US: Open in new window + zh_Hans: 是否新窗口打开 + human_description: + en_US: Open in new window [optional for external links] (0:open in current page, 1:open in new window). + zh_Hans: 是否新窗口打开 [外链选填] (0:当前页打开, 1:新窗口打开)。 + llm_description: Open in new window [optional for external links] (0:open in current page, 1:open in new window). + + + - name: visible + type: select + required: false + form: llm + default: 1 + options: + - value: 0 + label: + en_US: Invisible + zh_Hans: 隐藏 + pt_BR: Invisible + - value: 1 + label: + en_US: Visible + zh_Hans: 可见 + pt_BR: Visible + label: + en_US: Visibility + zh_Hans: 是否可见 + human_description: + en_US: Visibility (0:invisible, 1:visible). + zh_Hans: 是否可见 (0:不可见, 1:可见)。 + llm_description: Visibility (0:invisible, 1:visible).
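The update tool above accepts doc_ids as a comma-separated string and normalizes it to a list of integers before issuing the PUT, equivalent to this small sketch (helper name illustrative):

def parse_doc_ids(doc_ids: str) -> list[int]:
    # "101, 102,103" becomes [101, 102, 103]
    return [int(doc_id.strip()) for doc_id in doc_ids.split(",")]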
diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_document.py b/api/core/tools/provider/builtin/aliyuque/tools/update_document.py new file mode 100644 index 00000000000000..d7eba46ad968dd --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_document.py @@ -0,0 +1,17 @@ +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueUpdateDocumentTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message( + self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") + ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml new file mode 100644 index 00000000000000..c2da6b179acdd9 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml @@ -0,0 +1,87 @@ +identity: + name: aliyuque_update_document + author: 佐井 + label: + en_US: Update Document + zh_Hans: 更新文档 + icon: icon.svg +description: + human: + en_US: Update an existing document within a specified knowledge base by providing the document ID or path. + zh_Hans: 通过提供文档ID或路径,更新指定知识库中的现有文档。 + llm: Update doc in a knowledge base via ID/path. +parameters: + - name: book_id + type: string + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库 ID + human_description: + en_US: The unique identifier of the knowledge base where the document resides. + zh_Hans: 文档所属知识库的ID。 + llm_description: ID of the knowledge base holding the doc. + - name: id + type: string + required: true + form: llm + label: + en_US: Document ID or Path + zh_Hans: 文档 ID 或 路径 + human_description: + en_US: The unique identifier or the path of the document to be updated. + zh_Hans: 要更新的文档的唯一ID或路径。 + llm_description: Doc's ID or path for update. + + - name: title + type: string + required: false + form: llm + label: + en_US: Title + zh_Hans: 标题 + human_description: + en_US: The title of the document, defaults to 'Untitled' if not provided. + zh_Hans: 文档标题,默认为'无标题'如未提供。 + llm_description: Title of the document, defaults to 'Untitled'. + + - name: format + type: select + required: false + form: llm + options: + - value: markdown + label: + en_US: markdown + zh_Hans: markdown + pt_BR: markdown + - value: html + label: + en_US: html + zh_Hans: html + pt_BR: html + - value: lake + label: + en_US: lake + zh_Hans: lake + pt_BR: lake + label: + en_US: Content Format + zh_Hans: 内容格式 + human_description: + en_US: Format of the document content (markdown, HTML, Lake). + zh_Hans: 文档内容格式(markdown, HTML, Lake)。 + llm_description: Content format choices, markdown, HTML, Lake. + + - name: body + type: string + required: true + form: llm + label: + en_US: Body Content + zh_Hans: 正文内容 + human_description: + en_US: The actual content of the document. + zh_Hans: 文档的实际内容。 + llm_description: Content of the document. 
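One design point from tools/base.py worth noting for the create and update tools above: GET requests forward the leftover parameters as a query string, while POST and PUT send them as a JSON body under an explicit Content-Type header. A hedged sketch of the write path with illustrative values (yuque_put is not a name from the PR):

import requests


def yuque_put(token: str, path: str, params: dict) -> str:
    # Matches the POST/PUT branch of AliYuqueTool.request: JSON body, not query params.
    session = requests.Session()
    session.headers.update(
        {
            "accept": "application/json",
            "X-Auth-Token": token,
            "Content-Type": "application/json",
        }
    )
    response = session.put("https://www.yuque.com" + path, json=params)
    response.raise_for_status()
    return response.text


# yuque_put(token, "/api/v2/repos/42/docs/intro", {"title": "New title", "body": "..."})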
diff --git a/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg b/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg new file mode 100644 index 00000000000000..785432943bc148 --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/_assets/icon.svg @@ -0,0 +1,7 @@ + + + 形状结合 + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.py b/api/core/tools/provider/builtin/alphavantage/alphavantage.py new file mode 100644 index 00000000000000..a84630e5aa990a --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.alphavantage.tools.query_stock import QueryStockTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class AlphaVantageProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + QueryStockTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id="", + tool_parameters={ + "code": "AAPL", # Apple Inc. + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml b/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml new file mode 100644 index 00000000000000..710510cfd8ed4a --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/alphavantage.yaml @@ -0,0 +1,31 @@ +identity: + author: zhuhao + name: alphavantage + label: + en_US: AlphaVantage + zh_Hans: AlphaVantage + pt_BR: AlphaVantage + description: + en_US: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis. + zh_Hans: AlphaVantage是一个在线平台,它提供金融市场数据和API,便于个人投资者和开发者获取股票报价、技术指标和股票分析。 + pt_BR: AlphaVantage is an online platform that provides financial market data and APIs, making it convenient for individual investors and developers to access stock quotes, technical indicators, and stock analysis. 
+ icon: icon.svg + tags: + - finance +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: AlphaVantage API key + zh_Hans: AlphaVantage API key + pt_BR: AlphaVantage API key + placeholder: + en_US: Please input your AlphaVantage API key + zh_Hans: 请输入你的 AlphaVantage API key + pt_BR: Please input your AlphaVantage API key + help: + en_US: Get your AlphaVantage API key from AlphaVantage + zh_Hans: 从 AlphaVantage 获取您的 AlphaVantage API key + pt_BR: Get your AlphaVantage API key from AlphaVantage + url: https://www.alphavantage.co/support/#api-key diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py new file mode 100644 index 00000000000000..d06611acd05d1d --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.py @@ -0,0 +1,48 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +ALPHAVANTAGE_API_URL = "https://www.alphavantage.co/query" + + +class QueryStockTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + stock_code = tool_parameters.get("code", "") + if not stock_code: + return self.create_text_message("Please tell me your stock code") + + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): + return self.create_text_message("Alpha Vantage API key is required.") + + params = { + "function": "TIME_SERIES_DAILY", + "symbol": stock_code, + "outputsize": "compact", + "datatype": "json", + "apikey": self.runtime.credentials["api_key"], + } + response = requests.get(url=ALPHAVANTAGE_API_URL, params=params) + response.raise_for_status() + result = self._handle_response(response.json()) + return self.create_json_message(result) + + def _handle_response(self, response: dict[str, Any]) -> dict[str, Any]: + result = response.get("Time Series (Daily)", {}) + if not result: + return {} + stock_result = {} + for k, v in result.items(): + stock_result[k] = {} + stock_result[k]["open"] = v.get("1. open") + stock_result[k]["high"] = v.get("2. high") + stock_result[k]["low"] = v.get("3. low") + stock_result[k]["close"] = v.get("4. close") + stock_result[k]["volume"] = v.get("5. volume") + return stock_result diff --git a/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml new file mode 100644 index 00000000000000..d89f34e373f9fa --- /dev/null +++ b/api/core/tools/provider/builtin/alphavantage/tools/query_stock.yaml @@ -0,0 +1,27 @@ +identity: + name: query_stock + author: zhuhao + label: + en_US: query_stock + zh_Hans: query_stock + pt_BR: query_stock +description: + human: + en_US: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol. 
+ zh_Hans: 获取指定股票代码的每日开盘价、每日最高价、每日最低价、每日收盘价和每日交易量等信息。 + pt_BR: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol + llm: Retrieve information such as daily opening price, daily highest price, daily lowest price, daily closing price, and daily trading volume for a specified stock symbol +parameters: + - name: code + type: string + required: true + label: + en_US: stock code + zh_Hans: 股票代码 + pt_BR: stock code + human_description: + en_US: stock code + zh_Hans: 股票代码 + pt_BR: stock code + llm_description: stock code for query from alphavantage + form: llm diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py index 707fc69be30cee..ebb2d1a8c47be9 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.py +++ b/api/core/tools/provider/builtin/arxiv/arxiv.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.yaml b/api/core/tools/provider/builtin/arxiv/arxiv.yaml index d26993b3364ea1..25aec97bb795e3 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.yaml +++ b/api/core/tools/provider/builtin/arxiv/arxiv.yaml @@ -4,9 +4,11 @@ identity: label: en_US: ArXiv zh_Hans: ArXiv + ja_JP: ArXiv description: en_US: Access to a vast repository of scientific papers and articles in various fields of research. zh_Hans: 访问各个研究领域大量科学论文和文章的存储库。 + ja_JP: 多様な研究分野の科学論文や記事の膨大なリポジトリへのアクセス。 icon: icon.svg tags: - search diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index ce28373880ba18..2d65ba2d6f4389 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -8,6 +8,8 @@ from core.tools.tool.builtin_tool import BuiltinTool logger = logging.getLogger(__name__) + + class ArxivAPIWrapper(BaseModel): """Wrapper around ArxivAPI. @@ -86,11 +88,13 @@ def run(self, query: str) -> str: class ArxivSearchInput(BaseModel): query: str = Field(..., description="Search query.") - + + class ArxivSearchTool(BuiltinTool): """ A tool for searching articles on Arxiv. """ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Arxiv search tool with the given user ID and tool parameters. @@ -100,15 +104,16 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe tool_parameters (dict[str, Any]): The parameters for the tool, including the 'query' parameter. Returns: - ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. + ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, + which can be a single message or a list of messages. 
""" - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + arxiv = ArxivAPIWrapper() - + response = arxiv.run(query) - + return self.create_text_message(self.summary(user_id=user_id, content=response)) diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.yaml b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.yaml index 7439a48658c6a0..afc1925df3b45a 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.yaml +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.yaml @@ -4,10 +4,12 @@ identity: label: en_US: Arxiv Search zh_Hans: Arxiv 搜索 + ja_JP: Arxiv 検索 description: human: en_US: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name. zh_Hans: 一个用于从Arxiv存储库搜索科学论文和文章的工具。 输入可以是Arxiv ID或作者姓名。 + ja_JP: Arxivリポジトリから科学論文や記事を検索するためのツールです。入力はArxiv IDまたは著者名にすることができます。 llm: A tool for searching scientific papers and articles from the Arxiv repository. Input can be an Arxiv ID or an author's name. parameters: - name: query @@ -16,8 +18,10 @@ parameters: label: en_US: Query string zh_Hans: 查询字符串 + ja_JP: クエリ文字列 human_description: en_US: The Arxiv ID or author's name used for searching. zh_Hans: 用于搜索的Arxiv ID或作者姓名。 + ja_JP: 検索に使用されるArxiv IDまたは著者名。 llm_description: The Arxiv ID or author's name used for searching. form: llm diff --git a/api/core/tools/provider/builtin/aws/aws.py b/api/core/tools/provider/builtin/aws/aws.py index 13ede9601509f5..f81b5dbd27d17c 100644 --- a/api/core/tools/provider/builtin/aws/aws.py +++ b/api/core/tools/provider/builtin/aws/aws.py @@ -11,15 +11,14 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "sagemaker_endpoint" : "", + "sagemaker_endpoint": "", "query": "misaka mikoto", - "candidate_texts" : "hello$$$hello world", - "topk" : 5, - "aws_region" : "" + "candidate_texts": "hello$$$hello world", + "topk": 5, + "aws_region": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index 9c006733bdd95d..a04f5c0fe9f1af 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -3,6 +3,7 @@ from typing import Any, Union import boto3 +from botocore.exceptions import BotoCoreError from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage @@ -11,40 +12,43 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class GuardrailParameters(BaseModel): guardrail_id: str = Field(..., description="The identifier of the guardrail") guardrail_version: str = Field(..., description="The version of the guardrail") source: str = Field(..., description="The source of the content") text: str = Field(..., description="The text to apply the guardrail to") - aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client") + aws_region: str = Field(..., description="AWS region for the Bedrock client") + class ApplyGuardrailTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, 
list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the ApplyGuardrail tool """ try: # Validate and parse input parameters params = GuardrailParameters(**tool_parameters) - + # Initialize AWS client - bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region) + bedrock_client = boto3.client("bedrock-runtime", region_name=params.aws_region) # Apply guardrail response = bedrock_client.apply_guardrail( guardrailIdentifier=params.guardrail_id, guardrailVersion=params.guardrail_version, source=params.source, - content=[{"text": {"text": params.text}}] + content=[{"text": {"text": params.text}}], ) + logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}") + # Check for empty response if not response: return self.create_text_message(text="Received empty response from AWS Bedrock.") - + # Process the result action = response.get("action", "No action specified") outputs = response.get("outputs", []) @@ -55,9 +59,12 @@ def _invoke(self, formatted_assessments = [] for assessment in assessments: for policy_type, policy_data in assessment.items(): - if isinstance(policy_data, dict) and 'topics' in policy_data: - for topic in policy_data['topics']: - formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}") + if isinstance(policy_data, dict) and "topics" in policy_data: + for topic in policy_data["topics"]: + formatted_assessments.append( + f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}," + f" Action: {topic['action']}" + ) else: formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}") @@ -65,19 +72,19 @@ def _invoke(self, result += f"Output: {output}\n " if formatted_assessments: result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n " -# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" + # result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" return self.create_text_message(text=result) - except boto3.exceptions.BotoCoreError as e: - error_message = f'AWS service error: {str(e)}' + except BotoCoreError as e: + error_message = f"AWS service error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except json.JSONDecodeError as e: - error_message = f'JSON parsing error: {str(e)}' + error_message = f"JSON parsing error: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) except Exception as e: - error_message = f'An unexpected error occurred: {str(e)}' + error_message = f"An unexpected error occurred: {str(e)}" logger.error(error_message, exc_info=True) return self.create_text_message(text=error_message) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml index 2b7c8abb442f77..66044e4ea84fe1 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml @@ -54,3 +54,14 @@ parameters: zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。 llm_description: The content used for requesting guardrail review, which can be either user input or LLM output. 
form: llm + - name: aws_region + type: string + required: true + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: Please enter the AWS region for the Bedrock client, for example 'us-east-1'. + zh_Hans: 请输入 Bedrock 客户端的 AWS 区域,例如 'us-east-1'。 + llm_description: Please enter the AWS region for the Bedrock client, for example 'us-east-1'. + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 005ba3deb53311..48755753ace7c1 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -11,78 +11,81 @@ class LambdaTranslateUtilsTool(BuiltinTool): lambda_client: Any = None def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name): - msg = { - "src_content":text_content, - "src_lang": src_lang, - "dest_lang":dest_lang, + msg = { + "src_content": text_content, + "src_lang": src_lang, + "dest_lang": dest_lang, "dictionary_id": dictionary_name, - "request_type" : request_type, - "model_id" : model_id + "request_type": request_type, + "model_id": model_id, } - invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, - InvocationType='RequestResponse', - Payload=json.dumps(msg)) - response_body = invoke_response['Payload'] + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] response_str = response_body.read().decode("unicode_escape") return response_str - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: self.lambda_client = boto3.client("lambda") line = 1 - text_content = tool_parameters.get('text_content', '') + text_content = tool_parameters.get("text_content", "") if not text_content: - return self.create_text_message('Please input text_content') - + return self.create_text_message("Please input text_content") + line = 2 - src_lang = tool_parameters.get('src_lang', '') + src_lang = tool_parameters.get("src_lang", "") if not src_lang: - return self.create_text_message('Please input src_lang') - + return self.create_text_message("Please input src_lang") + line = 3 - dest_lang = tool_parameters.get('dest_lang', '') + dest_lang = tool_parameters.get("dest_lang", "") if not dest_lang: - return self.create_text_message('Please input dest_lang') - + return self.create_text_message("Please input dest_lang") + line = 4 - lambda_name = tool_parameters.get('lambda_name', '') + lambda_name = tool_parameters.get("lambda_name", "") if not lambda_name: - return self.create_text_message('Please input lambda_name') - + return self.create_text_message("Please input lambda_name") + line = 5 - request_type = tool_parameters.get('request_type', '') + request_type = tool_parameters.get("request_type", "") if not request_type: - return self.create_text_message('Please input request_type') - + return 
self.create_text_message("Please input request_type") + line = 6 - model_id = tool_parameters.get('model_id', '') + model_id = tool_parameters.get("model_id", "") if not model_id: - return self.create_text_message('Please input model_id') + return self.create_text_message("Please input model_id") line = 7 - dictionary_name = tool_parameters.get('dictionary_name', '') + dictionary_name = tool_parameters.get("dictionary_name", "") if not dictionary_name: - return self.create_text_message('Please input dictionary_name') - - result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name) + return self.create_text_message("Please input dictionary_name") + + result = self._invoke_lambda( + text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name + ) return self.create_text_message(text=result) except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml index a35c9f49fb9720..3bb133c7ec8d16 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml @@ -10,7 +10,7 @@ description: human: en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag - pt_BR: A util tools for LLM translation, specfic Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag + pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag llm: A util tools for translation. 
parameters: - name: text_content diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py new file mode 100644 index 00000000000000..f43f3b6fe05694 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.py @@ -0,0 +1,70 @@ +import json +import logging +from typing import Any, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +console_handler = logging.StreamHandler() +logger.addHandler(console_handler) + + +class LambdaYamlToJsonTool(BuiltinTool): + lambda_client: Any = None + + def _invoke_lambda(self, lambda_name: str, yaml_content: str) -> str: + msg = {"body": yaml_content} + logger.info(json.dumps(msg)) + + invoke_response = self.lambda_client.invoke( + FunctionName=lambda_name, InvocationType="RequestResponse", Payload=json.dumps(msg) + ) + response_body = invoke_response["Payload"] + + response_str = response_body.read().decode("utf-8") + resp_json = json.loads(response_str) + + logger.info(resp_json) + if resp_json["statusCode"] != 200: + raise Exception(f"Invalid status code: {response_str}") + + return resp_json["body"] + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.lambda_client: + aws_region = tool_parameters.get("aws_region") # todo: move aws_region out, and update client region + if aws_region: + self.lambda_client = boto3.client("lambda", region_name=aws_region) + else: + self.lambda_client = boto3.client("lambda") + + yaml_content = tool_parameters.get("yaml_content", "") + if not yaml_content: + return self.create_text_message("Please input yaml_content") + + lambda_name = tool_parameters.get("lambda_name", "") + if not lambda_name: + return self.create_text_message("Please input lambda_name") + logger.debug(f"{json.dumps(tool_parameters, indent=2, ensure_ascii=False)}") + + result = self._invoke_lambda(lambda_name, yaml_content) + logger.debug(result) + + return self.create_text_message(result) + except Exception as e: + return self.create_text_message(f"Exception: {str(e)}") + + console_handler.flush() diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml new file mode 100644 index 00000000000000..919c285348df83 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/lambda_yaml_to_json.yaml @@ -0,0 +1,53 @@ +identity: + name: lambda_yaml_to_json + author: AWS + label: + en_US: LambdaYamlToJson + zh_Hans: LambdaYamlToJson + pt_BR: LambdaYamlToJson + icon: icon.svg +description: + human: + en_US: A tool to convert yaml to json using AWS Lambda. + zh_Hans: 将 YAML 转为 JSON 的工具(通过AWS Lambda)。 + pt_BR: A tool to convert yaml to json using AWS Lambda. + llm: A tool to convert yaml to json. 
+parameters: + - name: yaml_content + type: string + required: true + label: + en_US: YAML content to convert for + zh_Hans: YAML 内容 + pt_BR: YAML content to convert for + human_description: + en_US: YAML content to convert for + zh_Hans: YAML 内容 + pt_BR: YAML content to convert for + llm_description: YAML content to convert for + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of lambda + zh_Hans: Lambda 所在的region + pt_BR: region of lambda + human_description: + en_US: region of lambda + zh_Hans: Lambda 所在的region + pt_BR: region of lambda + llm_description: region of lambda + form: form + - name: lambda_name + type: string + required: false + label: + en_US: name of lambda + zh_Hans: Lambda 名称 + pt_BR: name of lambda + human_description: + en_US: name of lambda + zh_Hans: Lambda 名称 + pt_BR: name of lambda + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index d4bc446e5b13d8..bffcd058b509bf 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -1,4 +1,5 @@ import json +import operator from typing import Any, Union import boto3 @@ -9,37 +10,33 @@ class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint:str = None - topk:int = None + sagemaker_endpoint: str = None + topk: int = None - def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): - inputs = [query_input]*len(docs) + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): + inputs = [query_input] * len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, - Body=json.dumps( - { - "inputs": inputs, - "docs": docs - } - ), + Body=json.dumps({"inputs": inputs, "docs": docs}), ContentType="application/json", ) - json_str = response_model['Body'].read().decode('utf8') + json_str = response_model["Body"].read().decode("utf8") json_obj = json.loads(json_str) - scores = json_obj['scores'] + scores = json_obj["scores"] return scores if isinstance(scores, list) else [scores] - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ line = 0 try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region') + aws_region = tool_parameters.get("aws_region") if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -47,25 +44,25 @@ def _invoke(self, line = 1 if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") line = 2 if not self.topk: - self.topk = tool_parameters.get('topk', 5) + self.topk = tool_parameters.get("topk", 5) line = 3 - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - + return self.create_text_message("Please input query") + line = 4 - candidate_texts = tool_parameters.get('candidate_texts') + candidate_texts = tool_parameters.get("candidate_texts") if not candidate_texts: - return self.create_text_message('Please input 
candidate_texts') - + return self.create_text_message("Please input candidate_texts") + line = 5 candidate_docs = json.loads(candidate_texts) - docs = [ item.get('content') for item in candidate_docs ] + docs = [item.get("content") for item in candidate_docs] line = 6 scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint) @@ -75,12 +72,10 @@ def _invoke(self, candidate_docs[idx]["score"] = scores[idx] line = 8 - sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True) line = 9 - results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False) - return self.create_text_message(text=results_str) - + return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] + except Exception as e: - return self.create_text_message(f'Exception {str(e)}, line : {line}') - \ No newline at end of file + return self.create_text_message(f"Exception {str(e)}, line : {line}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py new file mode 100644 index 00000000000000..1fafe09b4d96bf --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py @@ -0,0 +1,101 @@ +import json +from enum import Enum +from typing import Any, Optional, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class TTSModelType(Enum): + PresetVoice = "PresetVoice" + CloneVoice = "CloneVoice" + CloneVoice_CrossLingual = "CloneVoice_CrossLingual" + InstructVoice = "InstructVoice" + + +class SageMakerTTSTool(BuiltinTool): + sagemaker_client: Any = None + sagemaker_endpoint: str = None + s3_client: Any = None + comprehend_client: Any = None + + def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None): + map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"} + + response = self.comprehend_client.detect_dominant_language(Text=content) + language_code = response["Languages"][0]["LanguageCode"] + return map_dict.get(language_code, "<|zh|>") + + def _build_tts_payload( + self, + model_type: str, + content_text: str, + model_role: str, + prompt_text: str, + prompt_audio: str, + instruct_text: str, + ): + if model_type == TTSModelType.PresetVoice.value and model_role: + return {"tts_text": content_text, "role": model_role} + if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio: + return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} + if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio: + lang_tag = self._detect_lang_code(content_text) + return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag} + if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role: + return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text} + + raise RuntimeError(f"Invalid params for {model_type}") + + def _invoke_sagemaker(self, payload: dict, endpoint: str): + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=endpoint, + Body=json.dumps(payload), + ContentType="application/json", + ) + json_str = response_model["Body"].read().decode("utf8") + json_obj = json.loads(json_str) + return json_obj + + def _invoke( + self, + 
user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.sagemaker_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + self.comprehend_client = boto3.client("comprehend", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + self.s3_client = boto3.client("s3") + self.comprehend_client = boto3.client("comprehend") + + if not self.sagemaker_endpoint: + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") + + tts_text = tool_parameters.get("tts_text") + tts_infer_type = tool_parameters.get("tts_infer_type") + + voice = tool_parameters.get("voice") + mock_voice_audio = tool_parameters.get("mock_voice_audio") + mock_voice_text = tool_parameters.get("mock_voice_text") + voice_instruct_prompt = tool_parameters.get("voice_instruct_prompt") + payload = self._build_tts_payload( + tts_infer_type, tts_text, voice, mock_voice_text, mock_voice_audio, voice_instruct_prompt + ) + + result = self._invoke_sagemaker(payload, self.sagemaker_endpoint) + + return self.create_text_message(text=result["s3_presign_url"]) + + except Exception as e: + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml new file mode 100644 index 00000000000000..a6a61dd4aa519a --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_tts.yaml @@ -0,0 +1,149 @@ +identity: + name: sagemaker_tts + author: AWS + label: + en_US: SagemakerTTS + zh_Hans: Sagemaker语音合成 + pt_BR: SagemakerTTS + icon: icon.svg +description: + human: + en_US: A tool for Speech synthesis - https://github.com/aws-samples/dify-aws-tool + zh_Hans: Sagemaker语音合成工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本 + pt_BR: A tool for Speech synthesis. + llm: A tool for Speech synthesis. 
You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool +parameters: + - name: sagemaker_endpoint + type: string + required: true + label: + en_US: sagemaker endpoint for tts + zh_Hans: 语音生成的SageMaker端点 + pt_BR: sagemaker endpoint for tts + human_description: + en_US: sagemaker endpoint for tts + zh_Hans: 语音生成的SageMaker端点 + pt_BR: sagemaker endpoint for tts + llm_description: sagemaker endpoint for tts + form: form + - name: tts_text + type: string + required: true + label: + en_US: tts text + zh_Hans: 语音合成原文 + pt_BR: tts text + human_description: + en_US: tts text + zh_Hans: 语音合成原文 + pt_BR: tts text + llm_description: tts text + form: llm + - name: tts_infer_type + type: select + required: false + label: + en_US: tts infer type + zh_Hans: 合成方式 + pt_BR: tts infer type + human_description: + en_US: tts infer type + zh_Hans: 合成方式 + pt_BR: tts infer type + llm_description: tts infer type + options: + - value: PresetVoice + label: + en_US: preset voice + zh_Hans: 预置音色 + - value: CloneVoice + label: + en_US: clone voice + zh_Hans: 克隆音色 + - value: CloneVoice_CrossLingual + label: + en_US: clone crossLingual voice + zh_Hans: 克隆音色(跨语言) + - value: InstructVoice + label: + en_US: instruct voice + zh_Hans: 指令音色 + form: form + - name: voice + type: select + required: false + label: + en_US: preset voice + zh_Hans: 预置音色 + pt_BR: preset voice + human_description: + en_US: preset voice + zh_Hans: 预置音色 + pt_BR: preset voice + llm_description: preset voice + options: + - value: 中文男 + label: + en_US: zh-cn male + zh_Hans: 中文男 + - value: 中文女 + label: + en_US: zh-cn female + zh_Hans: 中文女 + - value: 粤语女 + label: + en_US: zh-TW female + zh_Hans: 粤语女 + form: form + - name: mock_voice_audio + type: string + required: false + label: + en_US: clone voice link + zh_Hans: 克隆音频链接 + pt_BR: clone voice link + human_description: + en_US: clone voice link + zh_Hans: 克隆音频链接 + pt_BR: clone voice link + llm_description: clone voice link + form: llm + - name: mock_voice_text + type: string + required: false + label: + en_US: text of clone voice + zh_Hans: 克隆音频对应文本 + pt_BR: text of clone voice + human_description: + en_US: text of clone voice + zh_Hans: 克隆音频对应文本 + pt_BR: text of clone voice + llm_description: text of clone voice + form: llm + - name: voice_instruct_prompt + type: string + required: false + label: + en_US: instruct prompt for voice + zh_Hans: 音色指令文本 + pt_BR: instruct prompt for voice + human_description: + en_US: instruct prompt for voice + zh_Hans: 音色指令文本 + pt_BR: instruct prompt for voice + llm_description: instruct prompt for voice + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + pt_BR: region of sagemaker endpoint + human_description: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + pt_BR: region of sagemaker endpoint + llm_description: region of sagemaker endpoint + form: form diff --git a/api/core/tools/provider/builtin/azuredalle/azuredalle.py b/api/core/tools/provider/builtin/azuredalle/azuredalle.py index 2981a54d3c5716..1fab0d03a28ff3 100644 --- a/api/core/tools/provider/builtin/azuredalle/azuredalle.py +++ b/api/core/tools/provider/builtin/azuredalle/azuredalle.py @@ -13,12 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "square", - "n": 1 - }, + user_id="", + 
tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "square", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 2ffdd38b72bc22..cfa3cfb092803a 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -9,47 +9,48 @@ class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ client = AzureOpenAI( - api_version=self.runtime.credentials['azure_openai_api_version'], - azure_endpoint=self.runtime.credentials['azure_openai_base_url'], - api_key=self.runtime.credentials['azure_openai_api_key'], + api_version=self.runtime.credentials["azure_openai_api_version"], + azure_endpoint=self.runtime.credentials["azure_openai_base_url"], + api_key=self.runtime.credentials["azure_openai_api_key"], ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in {"standard", "hd"}: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in {"natural", "vivid"}: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} # call openapi dalle3 - model = self.runtime.credentials['azure_openai_api_model_name'] + model = self.runtime.credentials["azure_openai_api_model_name"] response = client.images.generate( prompt=prompt, model=model, @@ -58,21 +59,25 @@ def _invoke(self, extra_body=extra_body, style=style, quality=quality, - response_format='b64_json' + response_format="b64_json", ) result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value)) - result.append(self.create_text_message(f'\nGenerate image source to Seed ID: {seed_id}')) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ) + ) + 
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml index 63a8c99d97f90c..e256748e8f7188 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 3 - zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档 pt_BR: Imagem prompt, você pode verificar a documentação oficial do DallE 3 llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png b/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png new file mode 100644 index 00000000000000..8eb8f21513ba7d Binary files /dev/null and b/api/core/tools/provider/builtin/baidu_translate/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py b/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py new file mode 100644 index 00000000000000..ce907c3c616e07 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/_baidu_translate_tool_base.py @@ -0,0 +1,11 @@ +from hashlib import md5 + + +class BaiduTranslateToolBase: + def _get_sign(self, appid, secret, salt, query): + """ + get baidu translate sign + """ + # concatenate the string in the order of appid+q+salt+secret + sign_str = appid + query + salt + secret + return md5(sign_str.encode("utf-8")).hexdigest() diff --git a/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py new file mode 100644 index 00000000000000..cccd2f8c8fc478 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.py @@ -0,0 +1,17 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.baidu_translate.tools.translate import BaiduTranslateTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class BaiduTranslateProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + BaiduTranslateTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke(user_id="", tool_parameters={"q": "这是一段测试文本", "from": "auto", "to": "en"}) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml new file mode 100644 index 00000000000000..06dadeeefc9cde --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/baidu_translate.yaml @@ -0,0 +1,39 @@ +identity: + author: Xiao Ley + name: baidu_translate + label: + en_US: Baidu Translate + zh_Hans: 百度翻译 + 
description: + en_US: Translate text using Baidu + zh_Hans: 使用百度进行翻译 + icon: icon.png + tags: + - utilities +credentials_for_provider: + appid: + type: secret-input + required: true + label: + en_US: Baidu translate appid + zh_Hans: 百度翻译 appid + placeholder: + en_US: Please input your Baidu translate appid + zh_Hans: 请输入你的百度翻译 appid + help: + en_US: Get your Baidu translate appid from Baidu translate + zh_Hans: 从百度翻译开放平台获取你的 appid + url: https://api.fanyi.baidu.com + secret: + type: secret-input + required: true + label: + en_US: Baidu translate secret + zh_Hans: 百度翻译 secret + placeholder: + en_US: Please input your Baidu translate secret + zh_Hans: 请输入你的百度翻译 secret + help: + en_US: Get your Baidu translate secret from Baidu translate + zh_Hans: 从百度翻译开放平台获取你的 secret + url: https://api.fanyi.baidu.com diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py new file mode 100644 index 00000000000000..bce259f31d772e --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.py @@ -0,0 +1,78 @@ +import random +from hashlib import md5 +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase +from core.tools.tool.builtin_tool import BuiltinTool + + +class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + BAIDU_FIELD_TRANSLATE_URL = "https://fanyi-api.baidu.com/api/trans/vip/fieldtranslate" + + appid = self.runtime.credentials.get("appid", "") + if not appid: + raise ValueError("invalid baidu translate appid") + + secret = self.runtime.credentials.get("secret", "") + if not secret: + raise ValueError("invalid baidu translate secret") + + q = tool_parameters.get("q", "") + if not q: + raise ValueError("Please input text to translate") + + from_ = tool_parameters.get("from", "") + if not from_: + raise ValueError("Please select source language") + + to = tool_parameters.get("to", "") + if not to: + raise ValueError("Please select destination language") + + domain = tool_parameters.get("domain", "") + if not domain: + raise ValueError("Please select domain") + + salt = str(random.randint(32768, 16777215)) + sign = self._get_sign(appid, secret, salt, q, domain) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "q": q, + "from": from_, + "to": to, + "appid": appid, + "salt": salt, + "domain": domain, + "sign": sign, + "needIntervene": 1, + } + try: + response = requests.post(BAIDU_FIELD_TRANSLATE_URL, headers=headers, data=params) + result = response.json() + + if "trans_result" in result: + result_text = result["trans_result"][0]["dst"] + else: + result_text = f'{result["error_code"]}: {result["error_msg"]}' + + return self.create_text_message(str(result_text)) + except requests.RequestException as e: + raise ValueError(f"Translation service error: {e}") + except Exception: + raise ValueError("Translation service error, please check the network") + + def _get_sign(self, appid, secret, salt, query, domain): + sign_str = appid + query + salt + domain + secret + return md5(sign_str.encode("utf-8")).hexdigest() diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml
b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml new file mode 100644 index 00000000000000..de51fddbaea422 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/fieldtranslate.yaml @@ -0,0 +1,123 @@ +identity: + name: field_translate + author: Xiao Ley + label: + en_US: Field translate + zh_Hans: 百度领域翻译 +description: + human: + en_US: A tool for Baidu Field translate (Currently, the fields of "novel" and "wiki" only support Chinese to English translation. If the language direction is set to English to Chinese, the default output will be a universal translation result). + zh_Hans: 百度领域翻译,提供多种领域的文本翻译(目前“网络文学领域”和“人文社科领域”仅支持中到英,如设置语言方向为英到中,则默认输出通用翻译结果) + llm: A tool for Baidu Field translate +parameters: + - name: q + type: string + required: true + label: + en_US: Text content + zh_Hans: 文本内容 + human_description: + en_US: Text content to be translated + zh_Hans: 需要翻译的文本内容 + llm_description: Text content to be translated + form: llm + - name: from + type: select + required: true + label: + en_US: source language + zh_Hans: 源语言 + human_description: + en_US: The source language of the input text + zh_Hans: 输入的文本的源语言 + default: auto + form: form + options: + - value: auto + label: + en_US: auto + zh_Hans: 自动检测 + - value: zh + label: + en_US: Chinese + zh_Hans: 中文 + - value: en + label: + en_US: English + zh_Hans: 英语 + - name: to + type: select + required: true + label: + en_US: destination language + zh_Hans: 目标语言 + human_description: + en_US: The destination language of the input text + zh_Hans: 输入文本的目标语言 + default: en + form: form + options: + - value: zh + label: + en_US: Chinese + zh_Hans: 中文 + - value: en + label: + en_US: English + zh_Hans: 英语 + - name: domain + type: select + required: true + label: + en_US: domain + zh_Hans: 领域 + human_description: + en_US: The domain of the input text + zh_Hans: 输入文本的领域 + default: novel + form: form + options: + - value: it + label: + en_US: it + zh_Hans: 信息技术领域 + - value: finance + label: + en_US: finance + zh_Hans: 金融财经领域 + - value: machinery + label: + en_US: machinery + zh_Hans: 机械制造领域 + - value: senimed + label: + en_US: senimed + zh_Hans: 生物医药领域 + - value: novel + label: + en_US: novel (only support Chinese to English translation) + zh_Hans: 网络文学领域(仅支持中到英) + - value: academic + label: + en_US: academic + zh_Hans: 学术论文领域 + - value: aerospace + label: + en_US: aerospace + zh_Hans: 航空航天领域 + - value: wiki + label: + en_US: wiki (only support Chinese to English translation) + zh_Hans: 人文社科领域(仅支持中到英) + - value: news + label: + en_US: news + zh_Hans: 新闻咨询领域 + - value: law + label: + en_US: law + zh_Hans: 法律法规领域 + - value: contract + label: + en_US: contract + zh_Hans: 合同领域 diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/language.py b/api/core/tools/provider/builtin/baidu_translate/tools/language.py new file mode 100644 index 00000000000000..3bbaee88b3adf1 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/language.py @@ -0,0 +1,95 @@ +import random +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase +from core.tools.tool.builtin_tool import BuiltinTool + + +class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + BAIDU_LANGUAGE_URL = 
"https://fanyi-api.baidu.com/api/trans/vip/language" + + appid = self.runtime.credentials.get("appid", "") + if not appid: + raise ValueError("invalid baidu translate appid") + + secret = self.runtime.credentials.get("secret", "") + if not secret: + raise ValueError("invalid baidu translate secret") + + q = tool_parameters.get("q", "") + if not q: + raise ValueError("Please input text to translate") + + description_language = tool_parameters.get("description_language", "English") + + salt = str(random.randint(32768, 16777215)) + sign = self._get_sign(appid, secret, salt, q) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "q": q, + "appid": appid, + "salt": salt, + "sign": sign, + } + + try: + response = requests.post(BAIDU_LANGUAGE_URL, params=params, headers=headers) + result = response.json() + if "error_code" not in result: + raise ValueError("Translation service error, please check the network") + + result_text = "" + if result["error_code"] != 0: + result_text = f'{result["error_code"]}: {result["error_msg"]}' + else: + result_text = result["data"]["src"] + result_text = self.mapping_result(description_language, result_text) + + return self.create_text_message(result_text) + except requests.RequestException as e: + raise ValueError(f"Translation service error: {e}") + except Exception: + raise ValueError("Translation service error, please check the network") + + def mapping_result(self, description_language: str, result: str) -> str: + """ + mapping result + """ + mapping = { + "English": { + "zh": "Chinese", + "en": "English", + "jp": "Japanese", + "kor": "Korean", + "th": "Thai", + "vie": "Vietnamese", + "ru": "Russian", + }, + "Chinese": { + "zh": "中文", + "en": "英文", + "jp": "日文", + "kor": "韩文", + "th": "泰语", + "vie": "越南语", + "ru": "俄语", + }, + } + + language_mapping = mapping.get(description_language) + if not language_mapping: + return result + + return language_mapping.get(result, result) diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml b/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml new file mode 100644 index 00000000000000..60cca2e288a622 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/language.yaml @@ -0,0 +1,43 @@ +identity: + name: language + author: Xiao Ley + label: + en_US: Baidu Language + zh_Hans: 百度语种识别 +description: + human: + en_US: A tool for Baidu Language, support Chinese, English, Japanese, Korean, Thai, Vietnamese and Russian + zh_Hans: 使用百度进行语种识别,支持的语种:中文、英语、日语、韩语、泰语、越南语和俄语 + llm: A tool for Baidu Language +parameters: + - name: q + type: string + required: true + label: + en_US: Text content + zh_Hans: 文本内容 + human_description: + en_US: Text content to be recognized + zh_Hans: 需要识别语言的文本内容 + llm_description: Text content to be recognized + form: llm + - name: description_language + type: select + required: true + label: + en_US: Description language + zh_Hans: 描述语言 + human_description: + en_US: Describe the language used to identify the results + zh_Hans: 描述识别结果所用的语言 + default: Chinese + form: form + options: + - value: Chinese + label: + en_US: Chinese + zh_Hans: 中文 + - value: English + label: + en_US: English + zh_Hans: 英语 diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/translate.py b/api/core/tools/provider/builtin/baidu_translate/tools/translate.py new file mode 100644 index 00000000000000..7cd816a3bcd4da --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/translate.py @@ -0,0 +1,67 @@ +import random 
+from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.baidu_translate._baidu_translate_tool_base import BaiduTranslateToolBase +from core.tools.tool.builtin_tool import BuiltinTool + + +class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + BAIDU_TRANSLATE_URL = "https://fanyi-api.baidu.com/api/trans/vip/translate" + + appid = self.runtime.credentials.get("appid", "") + if not appid: + raise ValueError("invalid baidu translate appid") + + secret = self.runtime.credentials.get("secret", "") + if not secret: + raise ValueError("invalid baidu translate secret") + + q = tool_parameters.get("q", "") + if not q: + raise ValueError("Please input text to translate") + + from_ = tool_parameters.get("from", "") + if not from_: + raise ValueError("Please select source language") + + to = tool_parameters.get("to", "") + if not to: + raise ValueError("Please select destination language") + + salt = str(random.randint(32768, 16777215)) + sign = self._get_sign(appid, secret, salt, q) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + params = { + "q": q, + "from": from_, + "to": to, + "appid": appid, + "salt": salt, + "sign": sign, + } + try: + response = requests.post(BAIDU_TRANSLATE_URL, params=params, headers=headers) + result = response.json() + + if "trans_result" in result: + result_text = result["trans_result"][0]["dst"] + else: + result_text = f'{result["error_code"]}: {result["error_msg"]}' + + return self.create_text_message(str(result_text)) + except requests.RequestException as e: + raise ValueError(f"Translation service error: {e}") + except Exception: + raise ValueError("Translation service error, please check the network") diff --git a/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml b/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml new file mode 100644 index 00000000000000..c8ff32cb6bb1f1 --- /dev/null +++ b/api/core/tools/provider/builtin/baidu_translate/tools/translate.yaml @@ -0,0 +1,275 @@ +identity: + name: translate + author: Xiao Ley + label: + en_US: Translate + zh_Hans: 百度翻译 +description: + human: + en_US: A tool for Baidu Translate + zh_Hans: 百度翻译 + llm: A tool for Baidu Translate +parameters: + - name: q + type: string + required: true + label: + en_US: Text content + zh_Hans: 文本内容 + human_description: + en_US: Text content to be translated + zh_Hans: 需要翻译的文本内容 + llm_description: Text content to be translated + form: llm + - name: from + type: select + required: true + label: + en_US: source language + zh_Hans: 源语言 + human_description: + en_US: The source language of the input text + zh_Hans: 输入的文本的源语言 + default: auto + form: form + options: + - value: auto + label: + en_US: auto + zh_Hans: 自动检测 + - value: zh + label: + en_US: Chinese + zh_Hans: 中文 + - value: en + label: + en_US: English + zh_Hans: 英语 + - value: cht + label: + en_US: Traditional Chinese + zh_Hans: 繁体中文 + - value: yue + label: + en_US: Yue + zh_Hans: 粤语 + - value: wyw + label: + en_US: Wyw + zh_Hans: 文言文 + - value: jp + label: + en_US: Japanese + zh_Hans: 日语 + - value: kor + label: + en_US: Korean + zh_Hans: 韩语 + - value: fra + label: + en_US: French + zh_Hans: 法语 + - value: spa + label: + en_US: Spanish + zh_Hans: 西班牙语 + - value: th + label: + en_US: Thai + zh_Hans: 泰语 + - value: ara + 
label: + en_US: Arabic + zh_Hans: 阿拉伯语 + - value: ru + label: + en_US: Russian + zh_Hans: 俄语 + - value: pt + label: + en_US: Portuguese + zh_Hans: 葡萄牙语 + - value: de + label: + en_US: German + zh_Hans: 德语 + - value: it + label: + en_US: Italian + zh_Hans: 意大利语 + - value: el + label: + en_US: Greek + zh_Hans: 希腊语 + - value: nl + label: + en_US: Dutch + zh_Hans: 荷兰语 + - value: pl + label: + en_US: Polish + zh_Hans: 波兰语 + - value: bul + label: + en_US: Bulgarian + zh_Hans: 保加利亚语 + - value: est + label: + en_US: Estonian + zh_Hans: 爱沙尼亚语 + - value: dan + label: + en_US: Danish + zh_Hans: 丹麦语 + - value: fin + label: + en_US: Finnish + zh_Hans: 芬兰语 + - value: cs + label: + en_US: Czech + zh_Hans: 捷克语 + - value: rom + label: + en_US: Romanian + zh_Hans: 罗马尼亚语 + - value: slo + label: + en_US: Slovak + zh_Hans: 斯洛文尼亚语 + - value: swe + label: + en_US: Swedish + zh_Hans: 瑞典语 + - value: hu + label: + en_US: Hungarian + zh_Hans: 匈牙利语 + - value: vie + label: + en_US: Vietnamese + zh_Hans: 越南语 + - name: to + type: select + required: true + label: + en_US: destination language + zh_Hans: 目标语言 + human_description: + en_US: The destination language of the input text + zh_Hans: 输入文本的目标语言 + default: en + form: form + options: + - value: zh + label: + en_US: Chinese + zh_Hans: 中文 + - value: en + label: + en_US: English + zh_Hans: 英语 + - value: cht + label: + en_US: Traditional Chinese + zh_Hans: 繁体中文 + - value: yue + label: + en_US: Yue + zh_Hans: 粤语 + - value: wyw + label: + en_US: Wyw + zh_Hans: 文言文 + - value: jp + label: + en_US: Japanese + zh_Hans: 日语 + - value: kor + label: + en_US: Korean + zh_Hans: 韩语 + - value: fra + label: + en_US: French + zh_Hans: 法语 + - value: spa + label: + en_US: Spanish + zh_Hans: 西班牙语 + - value: th + label: + en_US: Thai + zh_Hans: 泰语 + - value: ara + label: + en_US: Arabic + zh_Hans: 阿拉伯语 + - value: ru + label: + en_US: Russian + zh_Hans: 俄语 + - value: pt + label: + en_US: Portuguese + zh_Hans: 葡萄牙语 + - value: de + label: + en_US: German + zh_Hans: 德语 + - value: it + label: + en_US: Italian + zh_Hans: 意大利语 + - value: el + label: + en_US: Greek + zh_Hans: 希腊语 + - value: nl + label: + en_US: Dutch + zh_Hans: 荷兰语 + - value: pl + label: + en_US: Polish + zh_Hans: 波兰语 + - value: bul + label: + en_US: Bulgarian + zh_Hans: 保加利亚语 + - value: est + label: + en_US: Estonian + zh_Hans: 爱沙尼亚语 + - value: dan + label: + en_US: Danish + zh_Hans: 丹麦语 + - value: fin + label: + en_US: Finnish + zh_Hans: 芬兰语 + - value: cs + label: + en_US: Czech + zh_Hans: 捷克语 + - value: rom + label: + en_US: Romanian + zh_Hans: 罗马尼亚语 + - value: slo + label: + en_US: Slovak + zh_Hans: 斯洛文尼亚语 + - value: swe + label: + en_US: Swedish + zh_Hans: 瑞典语 + - value: hu + label: + en_US: Hungarian + zh_Hans: 匈牙利语 + - value: vie + label: + en_US: Vietnamese + zh_Hans: 越南语 diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index f85a5ed4722523..8bed2c556cf879 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -8,142 +8,135 @@ class BingSearchTool(BuiltinTool): - url: str = 'https://api.bing.microsoft.com/v7.0/search' - - def _invoke_bing(self, - user_id: str, - server_url: str, - subscription_key: str, query: str, limit: int, - result_type: str, market: str, lang: str, - filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + url: str = "https://api.bing.microsoft.com/v7.0/search" + + def _invoke_bing( + self, + user_id: str, 
+ server_url: str, + subscription_key: str, + query: str, + limit: int, + result_type: str, + market: str, + lang: str, + filters: list[str], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke bing search + invoke bing search """ - market_code = f'{lang}-{market}' - accept_language = f'{lang},{market_code};q=0.9' - headers = { - 'Ocp-Apim-Subscription-Key': subscription_key, - 'Accept-Language': accept_language - } + market_code = f"{lang}-{market}" + accept_language = f"{lang},{market_code};q=0.9" + headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language} query = quote(query) server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' response = get(server_url, headers=headers) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') - + raise Exception(f"Error {response.status_code}: {response.text}") + response = response.json() - search_results = response['webPages']['value'][:limit] if 'webPages' in response else [] - related_searches = response['relatedSearches']['value'] if 'relatedSearches' in response else [] - entities = response['entities']['value'] if 'entities' in response else [] - news = response['news']['value'] if 'news' in response else [] - computation = response['computation']['value'] if 'computation' in response else None + search_results = response["webPages"]["value"][:limit] if "webPages" in response else [] + related_searches = response["relatedSearches"]["value"] if "relatedSearches" in response else [] + entities = response["entities"]["value"] if "entities" in response else [] + news = response["news"]["value"] if "news" in response else [] + computation = response["computation"]["value"] if "computation" in response else None - if result_type == 'link': + if result_type == "link": results = [] if search_results: for result in search_results: url = f': {result["url"]}' if "url" in result else "" - results.append(self.create_text_message( - text=f'{result["name"]}{url}' - )) - + results.append(self.create_text_message(text=f'{result["name"]}{url}')) if entities: for entity in entities: url = f': {entity["url"]}' if "url" in entity else "" - results.append(self.create_text_message( - text=f'{entity.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}')) if news: for news_item in news: url = f': {news_item["url"]}' if "url" in news_item else "" - results.append(self.create_text_message( - text=f'{news_item.get("name", "")}{url}' - )) + results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}')) if related_searches: for related in related_searches: url = f': {related["displayText"]}' if "displayText" in related else "" - results.append(self.create_text_message( - text=f'{related.get("displayText", "")}{url}' - )) - + results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}')) + return results else: # construct text - text = '' + text = "" if search_results: for i, result in enumerate(search_results): - text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n' + text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n' - if computation and 'expression' in computation and 'value' in computation: - text += '\nComputation:\n' + if computation and "expression" in computation and "value" in computation: + text += "\nComputation:\n" text += f'{computation["expression"]} = 
{computation["value"]}\n' if entities: - text += '\nEntities:\n' + text += "\nEntities:\n" for entity in entities: url = f'- {entity["url"]}' if "url" in entity else "" text += f'{entity.get("name", "")}{url}\n' if news: - text += '\nNews:\n' + text += "\nNews:\n" for news_item in news: url = f'- {news_item["url"]}' if "url" in news_item else "" text += f'{news_item.get("name", "")}{url}\n' if related_searches: - text += '\n\nRelated Searches:\n' + text += "\n\nRelated Searches:\n" for related in related_searches: url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else "" text += f'{related.get("displayText", "")}{url}\n' return self.create_text_message(text=self.summary(user_id=user_id, content=text)) - def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dict[str, Any]) -> None: - key = credentials.get('subscription_key') + key = credentials.get("subscription_key") if not key: - raise Exception('subscription_key is required') - - server_url = credentials.get('server_url') + raise Exception("subscription_key is required") + + server_url = credentials.get("server_url") if not server_url: server_url = self.url - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' + raise Exception("query is required") - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if credentials.get('allow_entities', False): - filter.append('Entities') + if credentials.get("allow_entities", False): + filter.append("Entities") - if credentials.get('allow_computation', False): - filter.append('Computation') + if credentials.get("allow_computation", False): + filter.append("Computation") - if credentials.get('allow_news', False): - filter.append('News') + if credentials.get("allow_news", False): + filter.append("News") - if credentials.get('allow_related_searches', False): - filter.append('RelatedSearches') + if credentials.get("allow_related_searches", False): + filter.append("RelatedSearches") - if credentials.get('allow_web_pages', False): - filter.append('WebPages') + if credentials.get("allow_web_pages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + self._invoke_bing( - user_id='test', + user_id="test", server_url=server_url, subscription_key=key, query=query, @@ -151,50 +144,51 @@ def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dic result_type=result_type, market=market, lang=lang, - filters=filter + filters=filter, ) - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - key = self.runtime.credentials.get('subscription_key', None) + key = self.runtime.credentials.get("subscription_key", None) if not key: - raise Exception('subscription_key is required') - - server_url = self.runtime.credentials.get('server_url', None) + raise 
Exception("subscription_key is required") + + server_url = self.runtime.credentials.get("server_url", None) if not server_url: server_url = self.url - - query = tool_parameters.get('query') + + query = tool_parameters.get("query") if not query: - raise Exception('query is required') - - limit = min(tool_parameters.get('limit', 5), 10) - result_type = tool_parameters.get('result_type', 'text') or 'text' - - market = tool_parameters.get('market', 'US') - lang = tool_parameters.get('language', 'en') + raise Exception("query is required") + + limit = min(tool_parameters.get("limit", 5), 10) + result_type = tool_parameters.get("result_type", "text") or "text" + + market = tool_parameters.get("market", "US") + lang = tool_parameters.get("language", "en") filter = [] - if tool_parameters.get('enable_computation', False): - filter.append('Computation') - if tool_parameters.get('enable_entities', False): - filter.append('Entities') - if tool_parameters.get('enable_news', False): - filter.append('News') - if tool_parameters.get('enable_related_search', False): - filter.append('RelatedSearches') - if tool_parameters.get('enable_webpages', False): - filter.append('WebPages') + if tool_parameters.get("enable_computation", False): + filter.append("Computation") + if tool_parameters.get("enable_entities", False): + filter.append("Entities") + if tool_parameters.get("enable_news", False): + filter.append("News") + if tool_parameters.get("enable_related_search", False): + filter.append("RelatedSearches") + if tool_parameters.get("enable_webpages", False): + filter.append("WebPages") if not filter: - raise Exception('At least one filter is required') - + raise Exception("At least one filter is required") + return self._invoke_bing( user_id=user_id, server_url=server_url, @@ -204,5 +198,5 @@ def _invoke(self, result_type=result_type, market=market, lang=lang, - filters=filter - ) \ No newline at end of file + filters=filter, + ) diff --git a/api/core/tools/provider/builtin/brave/brave.py b/api/core/tools/provider/builtin/brave/brave.py index e5eada80ee430f..c24ee67334083b 100644 --- a/api/core/tools/provider/builtin/brave/brave.py +++ b/api/core/tools/provider/builtin/brave/brave.py @@ -13,11 +13,10 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/brave/brave.yaml b/api/core/tools/provider/builtin/brave/brave.yaml index 93d315f8390ea9..2b0dcc0188caf8 100644 --- a/api/core/tools/provider/builtin/brave/brave.yaml +++ b/api/core/tools/provider/builtin/brave/brave.yaml @@ -29,3 +29,11 @@ credentials_for_provider: zh_Hans: 从 Brave 获取您的 Brave Search API key pt_BR: Get your Brave Search API key from Brave url: https://brave.com/search/api/ + base_url: + type: text-input + required: false + label: + en_US: Brave server's Base URL + zh_Hans: Brave服务器的API URL + placeholder: + en_US: https://api.search.brave.com/res/v1/web/search diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.py b/api/core/tools/provider/builtin/brave/tools/brave_search.py index 21cbf2c7dae11d..c34362ae52ecac 100644 --- a/api/core/tools/provider/builtin/brave/tools/brave_search.py +++ b/api/core/tools/provider/builtin/brave/tools/brave_search.py @@ -7,6 +7,8 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from 
core.tools.tool.builtin_tool import BuiltinTool +BRAVE_BASE_URL = "https://api.search.brave.com/res/v1/web/search" + class BraveSearchWrapper(BaseModel): """Wrapper around the Brave search engine.""" @@ -15,8 +17,10 @@ class BraveSearchWrapper(BaseModel): """The API key to use for the Brave search engine.""" search_kwargs: dict = Field(default_factory=dict) """Additional keyword arguments to pass to the search request.""" - base_url: str = "https://api.search.brave.com/res/v1/web/search" + base_url: str = BRAVE_BASE_URL """The base URL for the Brave search engine.""" + ensure_ascii: bool = True + """Ensure the JSON output is ASCII encoded.""" def run(self, query: str) -> str: """Query the Brave search engine and return the results as a JSON string. @@ -36,8 +40,8 @@ def run(self, query: str) -> str: } for item in web_search_results ] - return json.dumps(final_results) - + return json.dumps(final_results, ensure_ascii=self.ensure_ascii) + def _search_request(self, query: str) -> list[dict]: headers = { "X-Subscription-Token": self.api_key, @@ -55,6 +59,7 @@ def _search_request(self, query: str) -> list[dict]: return response.json().get("web", {}).get("results", []) + class BraveSearch(BaseModel): """Tool that queries the BraveSearch.""" @@ -68,7 +73,7 @@ class BraveSearch(BaseModel): @classmethod def from_api_key( - cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any + cls, api_key: str, base_url: str, search_kwargs: Optional[dict] = None, ensure_ascii: bool = True, **kwargs: Any ) -> "BraveSearch": """Create a tool from an api key. @@ -80,7 +85,9 @@ def from_api_key( Returns: A tool. """ - wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {}) + wrapper = BraveSearchWrapper( + api_key=api_key, base_url=base_url, search_kwargs=search_kwargs or {}, ensure_ascii=ensure_ascii + ) return cls(search_wrapper=wrapper, **kwargs) def _run( @@ -90,6 +97,7 @@ def _run( """Use the tool.""" return self.search_wrapper.run(query) + class BraveSearchTool(BuiltinTool): """ Tool for performing a search using Brave search engine. @@ -106,14 +114,21 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. 
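
The effect of the new `ensure_ascii` flag, which threads straight through to `json.dumps`: with Python's default `ensure_ascii=True`, non-ASCII search results (e.g. Chinese titles from Brave) are serialized as `\uXXXX` escapes, while disabling it keeps them readable downstream. A quick illustration with sample data:

```python
import json

results = [{"title": "苏州园林", "link": "https://example.com"}]  # hypothetical result
print(json.dumps(results))                      # "title": "\u82cf\u5dde\u56ed\u6797"
print(json.dumps(results, ensure_ascii=False))  # "title": "苏州园林"
```
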
""" - query = tool_parameters.get('query', '') - count = tool_parameters.get('count', 3) - api_key = self.runtime.credentials['brave_search_api_key'] + query = tool_parameters.get("query", "") + count = tool_parameters.get("count", 3) + api_key = self.runtime.credentials["brave_search_api_key"] + base_url = self.runtime.credentials.get("base_url", BRAVE_BASE_URL) + ensure_ascii = tool_parameters.get("ensure_ascii", True) + + if len(base_url) == 0: + base_url = BRAVE_BASE_URL if not query: - return self.create_text_message('Please input query') + return self.create_text_message("Please input query") - tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) + tool = BraveSearch.from_api_key( + api_key=api_key, base_url=base_url, search_kwargs={"count": count}, ensure_ascii=ensure_ascii + ) results = tool._run(query) @@ -121,4 +136,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe return self.create_text_message(f"No results found for '{query}' in Tavily") else: return self.create_text_message(text=results) - diff --git a/api/core/tools/provider/builtin/brave/tools/brave_search.yaml b/api/core/tools/provider/builtin/brave/tools/brave_search.yaml index b2a734c12d5f1c..5222a375f84cee 100644 --- a/api/core/tools/provider/builtin/brave/tools/brave_search.yaml +++ b/api/core/tools/provider/builtin/brave/tools/brave_search.yaml @@ -39,3 +39,15 @@ parameters: pt_BR: O número de resultados de pesquisa a serem retornados, permitindo que os usuários controlem a amplitude de sua saída de pesquisa. llm_description: Specifies the amount of search results to be displayed, offering users the ability to adjust the scope of their search findings. form: llm + - name: ensure_ascii + type: boolean + default: true + label: + en_US: Ensure ASCII + zh_Hans: 确保 ASCII + pt_BR: Ensure ASCII + human_description: + en_US: Ensure the JSON output is ASCII encoded + zh_Hans: 确保输出的 JSON 是 ASCII 编码 + pt_BR: Ensure the JSON output is ASCII encoded + form: form diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index 0865bc700ac91c..dfa3fbea6aaeb9 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -1,58 +1,37 @@ import matplotlib.pyplot as plt -from fontTools.ttLib import TTFont -from matplotlib.font_manager import findSystemFonts +from matplotlib.font_manager import FontProperties, fontManager -from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.chart.tools.line import LinearChartTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -# use a business theme -plt.style.use('seaborn-v0_8-darkgrid') -plt.rcParams['axes.unicode_minus'] = False - -def init_fonts(): - fonts = findSystemFonts() - popular_unicode_fonts = [ - 'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif', - 'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans', - 'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono' +def set_chinese_font(): + font_list = [ + "PingFang SC", + "SimHei", + "Microsoft YaHei", + "STSong", + "SimSun", + "Arial Unicode MS", + "Noto Sans CJK SC", + "Noto Sans CJK JP", ] - supported_fonts = [] - - for font_path in fonts: - try: - font = TTFont(font_path) - # get family name - family_name = font['name'].getName(1, 3, 
1).toUnicode() - if family_name in popular_unicode_fonts: - supported_fonts.append(family_name) - except: - pass - - plt.rcParams['font.family'] = 'sans-serif' - # sort by order of popular_unicode_fonts - for font in popular_unicode_fonts: - if font in supported_fonts: - plt.rcParams['font.sans-serif'] = font - break - -init_fonts() + for font in font_list: + if font in fontManager.ttflist: + chinese_font = FontProperties(font) + if chinese_font.get_name() == font: + return chinese_font + + return FontProperties() + + +# use a business theme +plt.style.use("seaborn-v0_8-darkgrid") +plt.rcParams["axes.unicode_minus"] = False +font_properties = set_chinese_font() +plt.rcParams["font.family"] = font_properties.get_name() + class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - try: - LinearChartTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - "data": "1,3,5,7,9,2,4,6,8,10", - }, - ) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py index 749ec761c692ec..20ce5e138b5bfe 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.py +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -8,12 +8,13 @@ class BarChartTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') + return self.create_text_message("Please input data") + data = data.split(";") # if all data is int, convert to int if all(i.isdigit() for i in data): @@ -21,29 +22,29 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ else: data = [float(i) for i in data] - axis = tool_parameters.get('x_axis') or None + axis = tool_parameters.get("x_axis") or None if axis: - axis = axis.split(';') + axis = axis.split(";") if len(axis) != len(data): axis = None flg, ax = plt.subplots(figsize=(10, 8)) if axis: - axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] - ax.set_xticklabels(axis, rotation=45, ha='right') - ax.bar(axis, data) + axis = [label[:10] + "..." 
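
One caveat on the `set_chinese_font` helper above: `fontManager.ttflist` is a list of `FontEntry` objects, so the membership test `font in fontManager.ttflist` compares a string against entries and may never match, leaving the default `FontProperties()`. A sketch of the same lookup done against entry names instead (an alternative, not what the hunk ships):

```python
from matplotlib.font_manager import FontProperties, fontManager


def pick_cjk_font(candidates: list[str]) -> FontProperties:
    # ttflist holds FontEntry objects; compare against entry.name,
    # not the entry itself.
    installed = {entry.name for entry in fontManager.ttflist}
    for name in candidates:
        if name in installed:
            return FontProperties(family=name)
    return FontProperties()  # fall back to matplotlib's default font
```
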
if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha="right") + # ensure all labels, including duplicates, are correctly displayed + ax.bar(range(len(data)), data) + ax.set_xticks(range(len(data))) else: ax.bar(range(len(data)), data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the bar chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) + self.create_text_message("the bar chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py index 608bd6623cf71c..39e8caac7ef609 100644 --- a/api/core/tools/provider/builtin/chart/tools/line.py +++ b/api/core/tools/provider/builtin/chart/tools/line.py @@ -8,18 +8,19 @@ class LinearChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') + return self.create_text_message("Please input data") + data = data.split(";") - axis = tool_parameters.get('x_axis') or None + axis = tool_parameters.get("x_axis") or None if axis: - axis = axis.split(';') + axis = axis.split(";") if len(axis) != len(data): axis = None @@ -32,20 +33,18 @@ def _invoke(self, flg, ax = plt.subplots(figsize=(10, 8)) if axis: - axis = [label[:10] + '...' if len(label) > 10 else label for label in axis] - ax.set_xticklabels(axis, rotation=45, ha='right') + axis = [label[:10] + "..." 
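
The bar-chart fix above deserves a word: passing string labels directly to `ax.bar(axis, data)` makes matplotlib treat them as categorical units, so duplicate labels collapse onto a single bar. Plotting against positional indices and applying the labels as tick text keeps every data point. A minimal sketch of the idiom:

```python
import matplotlib.pyplot as plt

data = [3, 5, 2]
labels = ["Q1", "Q1", "Q2"]  # note the duplicate label

fig, ax = plt.subplots()
ax.bar(range(len(data)), data)  # one bar per data point, duplicates included
ax.set_xticks(range(len(data)))
ax.set_xticklabels(labels, rotation=45, ha="right")
fig.savefig("bars.png")
```
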
if len(label) > 10 else label for label in axis] + ax.set_xticklabels(axis, rotation=45, ha="right") ax.plot(axis, data) else: ax.plot(data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the linear chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) + self.create_text_message("the linear chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py index 4c551229e98565..2c3b8a733eac9a 100644 --- a/api/core/tools/provider/builtin/chart/tools/pie.py +++ b/api/core/tools/provider/builtin/chart/tools/pie.py @@ -8,15 +8,16 @@ class PieChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - data = tool_parameters.get('data', '') + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + data = tool_parameters.get("data", "") if not data: - return self.create_text_message('Please input data') - data = data.split(';') - categories = tool_parameters.get('categories') or None + return self.create_text_message("Please input data") + data = data.split(";") + categories = tool_parameters.get("categories") or None # if all data is int, convert to int if all(i.isdigit() for i in data): @@ -27,7 +28,7 @@ def _invoke(self, flg, ax = plt.subplots() if categories: - categories = categories.split(';') + categories = categories.split(";") if len(categories) != len(data): categories = None @@ -37,12 +38,11 @@ def _invoke(self, ax.pie(data) buf = io.BytesIO() - flg.savefig(buf, format='png') + flg.savefig(buf, format="png") buf.seek(0) plt.close(flg) return [ - self.create_text_message('the pie chart is saved as an image.'), - self.create_blob_message(blob=buf.read(), - meta={'mime_type': 'image/png'}) - ] \ No newline at end of file + self.create_text_message("the pie chart is saved as an image."), + self.create_blob_message(blob=buf.read(), meta={"mime_type": "image/png"}), + ] diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py index 37645bf0d0189d..632c9fc7f1451b 100644 --- a/api/core/tools/provider/builtin/code/tools/simple_code.py +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -8,15 +8,15 @@ class SimpleCode(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ - invoke simple code + invoke simple code """ - language = tool_parameters.get('language', CodeLanguage.PYTHON3) - code = tool_parameters.get('code', '') + language = tool_parameters.get("language", CodeLanguage.PYTHON3) + code = tool_parameters.get("code", "") - if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: - raise ValueError(f'Only python3 and javascript are supported, not {language}') - - result = CodeExecutor.execute_code(language, '', code) + if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}: + raise ValueError(f"Only python3 and javascript are supported, not {language}") - return self.create_text_message(result) \ No newline at end of file + result = CodeExecutor.execute_code(language, "", code) + + return 
self.create_text_message(result) diff --git a/api/core/tools/provider/builtin/cogview/cogview.py b/api/core/tools/provider/builtin/cogview/cogview.py index 801817ec06ed36..6941ce86495693 100644 --- a/api/core/tools/provider/builtin/cogview/cogview.py +++ b/api/core/tools/provider/builtin/cogview/cogview.py @@ -1,4 +1,5 @@ -""" Provide the input parameters type for the cogview provider class """ +"""Provide the input parameters type for the cogview provider class""" + from typing import Any from core.tools.errors import ToolProviderCredentialValidationError @@ -7,7 +8,8 @@ class COGVIEWProvider(BuiltinToolProviderController): - """ cogview provider """ + """cogview provider""" + def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CogView3Tool().fork_tool_runtime( @@ -15,13 +17,12 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。", "size": "square", - "n": 1 + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) from e - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py new file mode 100644 index 00000000000000..7f69e833cb9046 --- /dev/null +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo.py @@ -0,0 +1,24 @@ +from typing import Any, Union + +from zhipuai import ZhipuAI + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class CogVideoTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + client = ZhipuAI( + base_url=self.runtime.credentials["zhipuai_base_url"], + api_key=self.runtime.credentials["zhipuai_api_key"], + ) + if not tool_parameters.get("prompt") and not tool_parameters.get("image_url"): + return self.create_text_message("require at least one of prompt and image_url") + + response = client.videos.generations( + model="cogvideox", prompt=tool_parameters.get("prompt"), image_url=tool_parameters.get("image_url") + ) + + return self.create_json_message(response.dict()) diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo.yaml b/api/core/tools/provider/builtin/cogview/tools/cogvideo.yaml new file mode 100644 index 00000000000000..3df0cfcea938fa --- /dev/null +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo.yaml @@ -0,0 +1,32 @@ +identity: + name: cogvideo + author: hjlarry + label: + en_US: CogVideo + zh_Hans: CogVideo 视频生成 +description: + human: + en_US: Use the CogVideox model provided by ZhipuAI to generate videos based on user prompts and images. + zh_Hans: 使用智谱cogvideox模型,根据用户输入的提示词和图片,生成视频。 + llm: A tool for generating videos. The input is user's prompt or image url or both of them, the output is a task id. You can use another tool with this task id to check the status and get the video. +parameters: + - name: prompt + type: string + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The prompt text used to generate video. + zh_Hans: 用于生成视频的提示词。 + llm_description: The prompt text used to generate video. Optional. + form: llm + - name: image_url + type: string + label: + en_US: image url + zh_Hans: 图片链接 + human_description: + en_US: The image url used to generate video. 
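
As the yaml descriptions note, cogvideo and cogvideo_job form an asynchronous pair: the first submits a generation task, the second fetches the result by id. A sketch of the submission half, assuming the SDK response exposes the task id as `id` (which is what the job tool's yaml says it consumes); the tool itself also passes a configurable `base_url` from its credentials:

```python
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="...")  # credentials elided

# Kick off an async generation; the response carries a task id,
# not the finished video.
response = client.videos.generations(model="cogvideox", prompt="a cat surfing a wave")
task_id = response.id  # assumed field name, per the job tool's yaml
```
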
+ zh_Hans: 输入一个图片链接,生成的视频将基于该图片和提示词。 + llm_description: The image url used to generate video. Optional. + form: llm diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py new file mode 100644 index 00000000000000..a521f1c28a41b6 --- /dev/null +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.py @@ -0,0 +1,30 @@ +from typing import Any, Union + +import httpx +from zhipuai import ZhipuAI + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class CogVideoJobTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + client = ZhipuAI( + api_key=self.runtime.credentials["zhipuai_api_key"], + base_url=self.runtime.credentials["zhipuai_base_url"], + ) + + response = client.videos.retrieve_videos_result(id=tool_parameters.get("id")) + result = [self.create_json_message(response.dict())] + if response.task_status == "SUCCESS": + for item in response.video_result: + video_cover_image = self.create_image_message(item.cover_image_url) + result.append(video_cover_image) + video = self.create_blob_message( + blob=httpx.get(item.url).content, meta={"mime_type": "video/mp4"}, save_as=self.VariableKey.VIDEO + ) + result.append(video) + + return result diff --git a/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.yaml b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.yaml new file mode 100644 index 00000000000000..fb2eb3ab130b81 --- /dev/null +++ b/api/core/tools/provider/builtin/cogview/tools/cogvideo_job.yaml @@ -0,0 +1,21 @@ +identity: + name: cogvideo_job + author: hjlarry + label: + en_US: CogVideo Result + zh_Hans: CogVideo 结果获取 +description: + human: + en_US: Get the result of CogVideo tool generation. + zh_Hans: 根据 CogVideo 工具返回的 id 获取视频生成结果。 + llm: Get the result of CogVideo tool generation. The input is the id which is returned by the CogVideo tool. The output is the url of video and video cover image. +parameters: + - name: id + type: string + label: + en_US: id + human_description: + en_US: The id returned by the CogVideo. + zh_Hans: CogVideo 工具返回的 id。 + llm_description: The id returned by the cogvideo. 
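
And the retrieval half: cogvideo_job.py reads `task_status` and, on success, `video_result[i].url` and `cover_image_url`. Outside Dify the same calls make a simple polling loop (the interval is arbitrary, and a production loop would also bail out on failure states):

```python
import time

from zhipuai import ZhipuAI

client = ZhipuAI(api_key="...")
task_id = "..."  # id returned by the generation sketch above

while True:
    result = client.videos.retrieve_videos_result(id=task_id)
    if result.task_status == "SUCCESS":
        video_urls = [item.url for item in result.video_result]
        covers = [item.cover_image_url for item in result.video_result]
        break
    time.sleep(5)
```
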
+ form: llm diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 89ffcf3347878a..12b4173fa40270 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -1,69 +1,93 @@ import random from typing import Any, Union -from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI +from zhipuai import ZhipuAI + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool class CogView3Tool(BuiltinTool): - """ CogView3 Tool """ + """CogView3 Tool""" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke CogView3 tool """ client = ZhipuAI( - base_url=self.runtime.credentials['zhipuai_base_url'], - api_key=self.runtime.credentials['zhipuai_api_key'], + base_url=self.runtime.credentials["zhipuai_base_url"], + api_key=self.runtime.credentials["zhipuai_api_key"], ) size_mapping = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical_768": "768x1344", + "vertical_864": "864x1152", + "horizontal_1344": "1344x768", + "horizontal_1152": "1152x864", + "widescreen_1440": "1440x720", + "tallscreen_720": "720x1440", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') - # get size - size = size_mapping[tool_parameters.get('size', 'square')] + return self.create_text_message("Please input prompt") + # get size key + size_key = tool_parameters.get("size", "square") + # cogview-3-plus get size + if size_key != "cogview_3": + size = size_mapping[size_key] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in {"standard", "hd"}: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in {"natural", "vivid"}: + return self.create_text_message("Invalid style") # set extra body - seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) - extra_body = {'seed': seed_id} - response = client.images.generations( - prompt=prompt, - model="cogview-3", - size=size, - n=n, - extra_body=extra_body, - style=style, - quality=quality, - response_format='b64_json' - ) + seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) + extra_body = {"seed": seed_id} + # cogview-3-plus + if size_key != "cogview_3": + response = client.images.generations( + prompt=prompt, + model="cogview-3-plus", + size=size, + n=n, + extra_body=extra_body, + style=style, + quality=quality, + response_format="b64_json", + ) + # cogview-3 + else: + response = client.images.generations( + prompt=prompt, + model="cogview-3", + n=n, + extra_body=extra_body, + style=style, + quality=quality, + response_format="b64_json", + ) result = [] for image in 
response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml index ba0b271a1c716c..9ab5c2729bf7a9 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of CogView 3 - zh_Hans: 图像提示词,您可以查看CogView 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 CogView 3 的官方文档 pt_BR: Image prompt, you can check the official documentation of CogView 3 llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm @@ -42,21 +42,46 @@ parameters: pt_BR: Image size form: form options: + - value: cogview_3 + label: + en_US: Square_cogview_3(1024x1024) + zh_Hans: 方_cogview_3(1024x1024) + pt_BR: Square_cogview_3(1024x1024) - value: square label: - en_US: Squre(1024x1024) + en_US: Square(1024x1024) zh_Hans: 方(1024x1024) - pt_BR: Squre(1024x1024) - - value: vertical + pt_BR: Square(1024x1024) + - value: vertical_768 + label: + en_US: Vertical(768x1344) + zh_Hans: 竖屏(768x1344) + pt_BR: Vertical(768x1344) + - value: vertical_864 + label: + en_US: Vertical(864x1152) + zh_Hans: 竖屏(864x1152) + pt_BR: Vertical(864x1152) + - value: horizontal_1344 + label: + en_US: Horizontal(1344x768) + zh_Hans: 横屏(1344x768) + pt_BR: Horizontal(1344x768) + - value: horizontal_1152 + label: + en_US: Horizontal(1152x864) + zh_Hans: 横屏(1152x864) + pt_BR: Horizontal(1152x864) + - value: widescreen_1440 label: - en_US: Vertical(1024x1792) - zh_Hans: 竖屏(1024x1792) - pt_BR: Vertical(1024x1792) - - value: horizontal + en_US: Widescreen(1440x720) + zh_Hans: 宽屏(1440x720) + pt_BR: Widescreen(1440x720) + - value: tallscreen_720 label: - en_US: Horizontal(1792x1024) - zh_Hans: 横屏(1792x1024) - pt_BR: Horizontal(1792x1024) + en_US: Tallscreen(720x1440) + zh_Hans: 高屏(720x1440) + pt_BR: Tallscreen(720x1440) default: square - name: n type: number diff --git a/api/core/tools/provider/builtin/comfyui/_assets/icon.png b/api/core/tools/provider/builtin/comfyui/_assets/icon.png new file mode 100644 index 00000000000000..958ec5c5cfe296 Binary files /dev/null and b/api/core/tools/provider/builtin/comfyui/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.py b/api/core/tools/provider/builtin/comfyui/comfyui.py new file mode 100644 index 00000000000000..bab690af8292b7 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/comfyui.py @@ -0,0 +1,21 @@ +from typing import Any + +import websocket +from yarl import URL + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class ComfyUIProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + ws = 
websocket.WebSocket() + base_url = URL(credentials.get("base_url")) + ws_address = f"ws://{base_url.authority}/ws?clientId=test123" + + try: + ws.connect(ws_address) + except Exception as e: + raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}") + finally: + ws.close() diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.yaml b/api/core/tools/provider/builtin/comfyui/comfyui.yaml new file mode 100644 index 00000000000000..24ae43cd44051e --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/comfyui.yaml @@ -0,0 +1,23 @@ +identity: + author: Qun + name: comfyui + label: + en_US: ComfyUI + zh_Hans: ComfyUI + description: + en_US: ComfyUI is a tool for generating images which can be deployed locally. + zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。 + icon: icon.png + tags: + - image +credentials_for_provider: + base_url: + type: text-input + required: true + label: + en_US: The URL of ComfyUI Server + zh_Hans: ComfyUI服务器的URL + placeholder: + en_US: Please input your ComfyUI server's Base URL + zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL + url: https://docs.dify.ai/guides/tools/tool-configuration/comfyui diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py new file mode 100644 index 00000000000000..bed9cd1882fa29 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -0,0 +1,128 @@ +import json +import random +import uuid + +import httpx +from websocket import WebSocket +from yarl import URL + +from core.file.file_manager import download +from core.file.models import File + + +class ComfyUiClient: + def __init__(self, base_url: str): + self.base_url = URL(base_url) + + def get_history(self, prompt_id: str) -> dict: + res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id}) + history = res.json()[prompt_id] + return history + + def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes: + response = httpx.get( + str(self.base_url / "view"), + params={"filename": filename, "subfolder": subfolder, "type": folder_type}, + ) + return response.content + + def upload_image(self, image_file: File) -> dict: + file = download(image_file) + files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"} + res = httpx.post(str(self.base_url / "upload/image"), files=files) + return res.json() + + def queue_prompt(self, client_id: str, prompt: dict) -> str: + res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt}) + prompt_id = res.json()["prompt_id"] + return prompt_id + + def open_websocket_connection(self) -> tuple[WebSocket, str]: + client_id = str(uuid.uuid4()) + ws = WebSocket() + ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" + ws.connect(ws_address) + return ws, client_id + + def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict: + prompt = origin_prompt.copy() + id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} + k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] + positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0] + prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt + + if negative_prompt != "": + negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] + prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt + + return prompt + + def 
set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict: + prompt = origin_prompt.copy() + for index, image_node_id in enumerate(image_ids): + prompt[image_node_id]["inputs"]["image"] = image_names[index] + return prompt + + def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict: + prompt = origin_prompt.copy() + id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} + load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"] + for load_image, image_name in zip(load_image_nodes, image_names): + prompt.get(load_image)["inputs"]["image"] = image_name + return prompt + + def set_prompt_seed_by_id(self, origin_prompt: dict, seed_id: str) -> dict: + prompt = origin_prompt.copy() + if seed_id not in prompt: + raise Exception("Not a valid seed node") + if "seed" in prompt[seed_id]["inputs"]: + prompt[seed_id]["inputs"]["seed"] = random.randint(10**14, 10**15 - 1) + elif "noise_seed" in prompt[seed_id]["inputs"]: + prompt[seed_id]["inputs"]["noise_seed"] = random.randint(10**14, 10**15 - 1) + else: + raise Exception("Not a valid seed node") + return prompt + + def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): + node_ids = list(prompt.keys()) + finished_nodes = [] + + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message["type"] == "progress": + data = message["data"] + current_step = data["value"] + print("In K-Sampler -> Step: ", current_step, " of: ", data["max"]) + if message["type"] == "execution_cached": + data = message["data"] + for itm in data["nodes"]: + if itm not in finished_nodes: + finished_nodes.append(itm) + print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") + if message["type"] == "executing": + data = message["data"] + if data["node"] not in finished_nodes: + finished_nodes.append(data["node"]) + print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") + + if data["node"] is None and data["prompt_id"] == prompt_id: + break # Execution is done + else: + continue + + def generate_image_by_prompt(self, prompt: dict) -> list[bytes]: + try: + ws, client_id = self.open_websocket_connection() + prompt_id = self.queue_prompt(client_id, prompt) + self.track_progress(prompt, ws, prompt_id) + history = self.get_history(prompt_id) + images = [] + for output in history["outputs"].values(): + for img in output.get("images", []): + image_data = self.get_image(img["filename"], img["subfolder"], img["type"]) + images.append(image_data) + return images + finally: + ws.close() diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.py new file mode 100644 index 00000000000000..eaa4b0d0275568 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.py @@ -0,0 +1,475 @@ +import json +import os +import random +import uuid +from copy import deepcopy +from enum import Enum +from typing import Any, Union + +import websocket +from httpx import get, post +from yarl import URL + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + +SD_TXT2IMG_OPTIONS = {} +LORA_NODE = { + "inputs": 
{"lora_name": "", "strength_model": 1, "strength_clip": 1, "model": ["11", 0], "clip": ["11", 1]}, + "class_type": "LoraLoader", + "_meta": {"title": "Load LoRA"}, +} +FluxGuidanceNode = { + "inputs": {"guidance": 3.5, "conditioning": ["6", 0]}, + "class_type": "FluxGuidance", + "_meta": {"title": "FluxGuidance"}, +} + + +class ModelType(Enum): + SD15 = 1 + SDXL = 2 + SD3 = 3 + FLUX = 4 + + +class ComfyuiStableDiffusionTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + # base url + base_url = self.runtime.credentials.get("base_url", "") + if not base_url: + return self.create_text_message("Please input base_url") + + if tool_parameters.get("model"): + self.runtime.credentials["model"] = tool_parameters["model"] + + model = self.runtime.credentials.get("model", None) + if not model: + return self.create_text_message("Please input model") + + # prompt + prompt = tool_parameters.get("prompt", "") + if not prompt: + return self.create_text_message("Please input prompt") + + # get negative prompt + negative_prompt = tool_parameters.get("negative_prompt", "") + + # get size + width = tool_parameters.get("width", 1024) + height = tool_parameters.get("height", 1024) + + # get steps + steps = tool_parameters.get("steps", 1) + + # get sampler_name + sampler_name = tool_parameters.get("sampler_name", "euler") + + # scheduler + scheduler = tool_parameters.get("scheduler", "normal") + + # get cfg + cfg = tool_parameters.get("cfg", 7.0) + + # get model type + model_type = tool_parameters.get("model_type", ModelType.SD15.name) + + # get lora + # supports up to 3 loras + lora_list = [] + lora_strength_list = [] + if tool_parameters.get("lora_1"): + lora_list.append(tool_parameters["lora_1"]) + lora_strength_list.append(tool_parameters.get("lora_strength_1", 1)) + if tool_parameters.get("lora_2"): + lora_list.append(tool_parameters["lora_2"]) + lora_strength_list.append(tool_parameters.get("lora_strength_2", 1)) + if tool_parameters.get("lora_3"): + lora_list.append(tool_parameters["lora_3"]) + lora_strength_list.append(tool_parameters.get("lora_strength_3", 1)) + + return self.text2img( + base_url=base_url, + model=model, + model_type=model_type, + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + steps=steps, + sampler_name=sampler_name, + scheduler=scheduler, + cfg=cfg, + lora_list=lora_list, + lora_strength_list=lora_strength_list, + ) + + def get_checkpoints(self) -> list[str]: + """ + get checkpoints + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "models" / "checkpoints") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return response.json() + except Exception as e: + return [] + + def get_loras(self) -> list[str]: + """ + get loras + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "models" / "loras") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return response.json() + except Exception as e: + return [] + + def get_sample_methods(self) -> tuple[list[str], list[str]]: + """ + get sample method + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [], [] + api_url = str(URL(base_url) / "object_info" / "KSampler") + 
response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [], [] + else: + data = response.json()["KSampler"]["input"]["required"] + return data["sampler_name"][0], data["scheduler"][0] + except Exception as e: + return [], [] + + def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + validate models + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + raise ToolProviderCredentialValidationError("Please input base_url") + model = self.runtime.credentials.get("model", None) + if not model: + raise ToolProviderCredentialValidationError("Please input model") + + api_url = str(URL(base_url) / "models" / "checkpoints") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + raise ToolProviderCredentialValidationError("Failed to get models") + else: + models = response.json() + if len([d for d in models if d == model]) > 0: + return self.create_text_message(json.dumps(models)) + else: + raise ToolProviderCredentialValidationError(f"model {model} does not exist") + except Exception as e: + raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") + + def get_history(self, base_url, prompt_id): + """ + get history + """ + url = str(URL(base_url) / "history") + respond = get(url, params={"prompt_id": prompt_id}, timeout=(2, 10)) + return respond.json() + + def download_image(self, base_url, filename, subfolder, folder_type): + """ + download image + """ + url = str(URL(base_url) / "view") + response = get(url, params={"filename": filename, "subfolder": subfolder, "type": folder_type}, timeout=(2, 10)) + return response.content + + def queue_prompt_image(self, base_url, client_id, prompt): + """ + send prompt task and rotate + """ + # initiate task execution + url = str(URL(base_url) / "prompt") + respond = post(url, data=json.dumps({"client_id": client_id, "prompt": prompt}), timeout=(2, 10)) + prompt_id = respond.json()["prompt_id"] + + ws = websocket.WebSocket() + if "https" in base_url: + ws_url = base_url.replace("https", "ws") + else: + ws_url = base_url.replace("http", "ws") + ws.connect(str(URL(f"{ws_url}") / "ws") + f"?clientId={client_id}", timeout=120) + + # websocket rotate execution status + output_images = {} + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message["type"] == "executing": + data = message["data"] + if data["node"] is None and data["prompt_id"] == prompt_id: + break # Execution is done + elif message["type"] == "status": + data = message["data"] + if data["status"]["exec_info"]["queue_remaining"] == 0 and data.get("sid"): + break # Execution is done + else: + continue # previews are binary data + + # download image when execution finished + history = self.get_history(base_url, prompt_id)[prompt_id] + for o in history["outputs"]: + for node_id in history["outputs"]: + node_output = history["outputs"][node_id] + if "images" in node_output: + images_output = [] + for image in node_output["images"]: + image_data = self.download_image(base_url, image["filename"], image["subfolder"], image["type"]) + images_output.append(image_data) + output_images[node_id] = images_output + + ws.close() + + return output_images + + def text2img( + self, + base_url: str, + model: str, + model_type: str, + prompt: str, + negative_prompt: str, + width: int, + height: int, + steps: int, + sampler_name: str, + scheduler: str, + cfg: float, + lora_list: list, + lora_strength_list: list, + ) -> 
Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + generate image + """ + if not SD_TXT2IMG_OPTIONS: + current_dir = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(current_dir, "txt2img.json")) as file: + SD_TXT2IMG_OPTIONS.update(json.load(file)) + + draw_options = deepcopy(SD_TXT2IMG_OPTIONS) + draw_options["3"]["inputs"]["steps"] = steps + draw_options["3"]["inputs"]["sampler_name"] = sampler_name + draw_options["3"]["inputs"]["scheduler"] = scheduler + draw_options["3"]["inputs"]["cfg"] = cfg + # generate different image when using same prompt next time + draw_options["3"]["inputs"]["seed"] = random.randint(0, 100000000) + draw_options["4"]["inputs"]["ckpt_name"] = model + draw_options["5"]["inputs"]["width"] = width + draw_options["5"]["inputs"]["height"] = height + draw_options["6"]["inputs"]["text"] = prompt + draw_options["7"]["inputs"]["text"] = negative_prompt + # if the model is SD3 or FLUX series, the Latent class should be corresponding to SD3 Latent + if model_type in {ModelType.SD3.name, ModelType.FLUX.name}: + draw_options["5"]["class_type"] = "EmptySD3LatentImage" + + if lora_list: + # last Lora node link to KSampler node + draw_options["3"]["inputs"]["model"][0] = "10" + # last Lora node link to positive and negative Clip node + draw_options["6"]["inputs"]["clip"][0] = "10" + draw_options["7"]["inputs"]["clip"][0] = "10" + # every Lora node link to next Lora node, and Checkpoints node link to first Lora node + for i, (lora, strength) in enumerate(zip(lora_list, lora_strength_list), 10): + if i - 10 == len(lora_list) - 1: + next_node_id = "4" + else: + next_node_id = str(i + 1) + lora_node = deepcopy(LORA_NODE) + lora_node["inputs"]["lora_name"] = lora + lora_node["inputs"]["strength_model"] = strength + lora_node["inputs"]["strength_clip"] = strength + lora_node["inputs"]["model"][0] = next_node_id + lora_node["inputs"]["clip"][0] = next_node_id + draw_options[str(i)] = lora_node + + # FLUX need to add FluxGuidance Node + if model_type == ModelType.FLUX.name: + last_node_id = str(10 + len(lora_list)) + draw_options[last_node_id] = deepcopy(FluxGuidanceNode) + draw_options[last_node_id]["inputs"]["conditioning"][0] = "6" + draw_options["3"]["inputs"]["positive"][0] = last_node_id + + try: + client_id = str(uuid.uuid4()) + result = self.queue_prompt_image(base_url, client_id, prompt=draw_options) + + # get first image + image = b"" + for node in result: + for img in result[node]: + if img: + image = img + break + + return self.create_blob_message( + blob=image, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) + + except Exception as e: + return self.create_text_message(f"Failed to generate image: {str(e)}") + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), + human_description=I18nObject( + en_US="Image prompt, you can check the official documentation of Stable Diffusion", + zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image prompt of Stable Diffusion, you should describe the image " + "you want to generate as a list of words as possible as detailed, " + "the prompt must be written in English.", + required=True, + ), + ] + if self.runtime.credentials: + try: + models = self.get_checkpoints() + if len(models) != 0: + parameters.append( + ToolParameter( + name="model", + 
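
The LoRA loop in text2img is easiest to see by its output. For two LoRAs it assigns node ids from 10 upward, points the KSampler at the first LoRA, chains each LoRA to the next, and anchors the last one on the checkpoint node; the prompt encoders read CLIP from the head of the chain. Trimmed to just the rewired link fields (all other inputs omitted, names hypothetical):

```python
# draw_options after the loop, for lora_list == ["a.safetensors", "b.safetensors"]:
draw_options = {
    "3":  {"inputs": {"model": ["10", 0]}},   # KSampler pulls from the first LoRA
    "10": {"inputs": {"lora_name": "a.safetensors", "model": ["11", 0], "clip": ["11", 1]}},
    "11": {"inputs": {"lora_name": "b.safetensors", "model": ["4", 0], "clip": ["4", 1]}},
    "4":  {"inputs": {"ckpt_name": "model.safetensors"}},  # checkpoint ends the chain
    "6":  {"inputs": {"clip": ["10", 1]}},    # positive prompt encodes via the chain head
    "7":  {"inputs": {"clip": ["10", 1]}},    # negative prompt likewise
}
```
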
label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="Model of Stable Diffusion or FLUX, " + "you can check the official documentation of Stable Diffusion or FLUX", + zh_Hans="Stable Diffusion 或者 FLUX 的模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model of Stable Diffusion or FLUX, " + "you can check the official documentation of Stable Diffusion or FLUX", + required=True, + default=models[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models + ], + ) + ) + loras = self.get_loras() + if len(loras) != 0: + for n in range(1, 4): + parameters.append( + ToolParameter( + name=f"lora_{n}", + label=I18nObject(en_US=f"Lora {n}", zh_Hans=f"Lora {n}"), + human_description=I18nObject( + en_US="Lora of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的 Lora 模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Lora of Stable Diffusion, " + "you can check the official documentation of " + "Stable Diffusion", + required=False, + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in loras + ], + ) + ) + sample_methods, schedulers = self.get_sample_methods() + if len(sample_methods) != 0: + parameters.append( + ToolParameter( + name="sampler_name", + label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), + human_description=I18nObject( + en_US="Sampling method of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Sampling method of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + required=True, + default=sample_methods[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) + for i in sample_methods + ], + ) + ) + if len(schedulers) != 0: + parameters.append( + ToolParameter( + name="scheduler", + label=I18nObject(en_US="Scheduler", zh_Hans="Scheduler"), + human_description=I18nObject( + en_US="Scheduler of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + zh_Hans="Stable Diffusion 的Scheduler,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Scheduler of Stable Diffusion, " + "you can check the official documentation of Stable Diffusion", + required=True, + default=schedulers[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in schedulers + ], + ) + ) + parameters.append( + ToolParameter( + name="model_type", + label=I18nObject(en_US="Model Type", zh_Hans="Model Type"), + human_description=I18nObject( + en_US="Model Type of Stable Diffusion or Flux, " + "you can check the official documentation of Stable Diffusion or Flux", + zh_Hans="Stable Diffusion 或 FLUX 的模型类型," + "您可以查看 Stable Diffusion 或 Flux 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model Type of Stable Diffusion or Flux, " + "you can check the official documentation of Stable Diffusion or Flux", + required=True, + 
default=ModelType.SD15.name, + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) + for i in ModelType.__members__ + ], + ) + ) + except Exception: + pass + + return parameters diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml new file mode 100644 index 00000000000000..75fe746965196a --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml @@ -0,0 +1,212 @@ +identity: + name: txt2img + author: Qun + label: + en_US: Txt2Img + zh_Hans: Txt2Img + pt_BR: Txt2Img +description: + human: + en_US: A pre-defined ComfyUI workflow that can use one model and up to 3 Loras to generate images. Supports SD1.5, SDXL, SD3 and FLUX models that contain text encoders/CLIP, but does not support models that require a triple CLIP loader. + zh_Hans: 一个预定义的 ComfyUI 工作流,可以使用一个模型和最多3个loras来生成图像。支持包含文本编码器/clip的SD1.5、SDXL、SD3和FLUX,但不支持需要三重clip加载器的模型。 + pt_BR: A pre-defined ComfyUI workflow that can use one model and up to 3 Loras to generate images. Supports SD1.5, SDXL, SD3 and FLUX models that contain text encoders/CLIP, but does not support models that require a triple CLIP loader. + llm: draw the image you want based on your prompt. +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: Image prompt, you can check the official documentation of Stable Diffusion or FLUX + zh_Hans: 图像提示词,您可以查看 Stable Diffusion 或者 FLUX 的官方文档 + pt_BR: Image prompt, you can check the official documentation of Stable Diffusion or FLUX + llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words, as detailed as possible; the prompt must be written in English. 
+ form: llm + - name: model + type: string + required: true + label: + en_US: Model Name + zh_Hans: 模型名称 + pt_BR: Model Name + human_description: + en_US: Model Name + zh_Hans: 模型名称 + pt_BR: Model Name + form: form + - name: model_type + type: string + required: true + label: + en_US: Model Type + zh_Hans: 模型类型 + pt_BR: Model Type + human_description: + en_US: Model Type + zh_Hans: 模型类型 + pt_BR: Model Type + form: form + - name: lora_1 + type: string + required: false + label: + en_US: Lora 1 + zh_Hans: Lora 1 + pt_BR: Lora 1 + human_description: + en_US: Lora 1 + zh_Hans: Lora 1 + pt_BR: Lora 1 + form: form + - name: lora_strength_1 + type: number + required: false + label: + en_US: Lora Strength 1 + zh_Hans: Lora Strength 1 + pt_BR: Lora Strength 1 + human_description: + en_US: Lora Strength 1 + zh_Hans: Lora模型的权重 + pt_BR: Lora Strength 1 + form: form + - name: steps + type: number + required: false + label: + en_US: Steps + zh_Hans: Steps + pt_BR: Steps + human_description: + en_US: Steps + zh_Hans: Steps + pt_BR: Steps + form: form + default: 20 + - name: width + type: number + required: false + label: + en_US: Width + zh_Hans: Width + pt_BR: Width + human_description: + en_US: Width + zh_Hans: Width + pt_BR: Width + form: form + default: 1024 + - name: height + type: number + required: false + label: + en_US: Height + zh_Hans: Height + pt_BR: Height + human_description: + en_US: Height + zh_Hans: Height + pt_BR: Height + form: form + default: 1024 + - name: negative_prompt + type: string + required: false + label: + en_US: Negative prompt + zh_Hans: Negative prompt + pt_BR: Negative prompt + human_description: + en_US: Negative prompt + zh_Hans: Negative prompt + pt_BR: Negative prompt + form: form + default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines + - name: cfg + type: number + required: false + label: + en_US: CFG Scale + zh_Hans: CFG Scale + pt_BR: CFG Scale + human_description: + en_US: CFG Scale + zh_Hans: 提示词相关性(CFG Scale) + pt_BR: CFG Scale + form: form + default: 7.0 + - name: sampler_name + type: string + required: false + label: + en_US: Sampling method + zh_Hans: Sampling method + pt_BR: Sampling method + human_description: + en_US: Sampling method + zh_Hans: Sampling method + pt_BR: Sampling method + form: form + - name: scheduler + type: string + required: false + label: + en_US: Scheduler + zh_Hans: Scheduler + pt_BR: Scheduler + human_description: + en_US: Scheduler + zh_Hans: Scheduler + pt_BR: Scheduler + form: form + - name: lora_2 + type: string + required: false + label: + en_US: Lora 2 + zh_Hans: Lora 2 + pt_BR: Lora 2 + human_description: + en_US: Lora 2 + zh_Hans: Lora 2 + pt_BR: Lora 2 + form: form + - name: lora_strength_2 + type: number + required: false + label: + en_US: Lora Strength 2 + zh_Hans: Lora Strength 2 + pt_BR: Lora Strength 2 + human_description: + en_US: Lora Strength 2 + zh_Hans: Lora模型的权重 + pt_BR: Lora Strength 2 + form: form + - name: lora_3 + type: string + required: false + label: + en_US: Lora 3 + zh_Hans: Lora 3 + pt_BR: Lora 3 + human_description: + en_US: Lora 3 + zh_Hans: Lora 3 + pt_BR: Lora 3 + form: form + - name: lora_strength_3 + type: number + required: false + label: + en_US: Lora Strength 3 + zh_Hans: Lora Strength 3 + pt_BR: Lora Strength 3 + human_description: + en_US: Lora Strength 3 + zh_Hans: Lora模型的权重 + pt_BR: Lora Strength 3 + form: form diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py new file 
mode 100644 index 00000000000000..87837362779baa --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py @@ -0,0 +1,84 @@ +import json +from typing import Any + +from core.file import FileType +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolParameterValidationError +from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient +from core.tools.tool.builtin_tool import BuiltinTool + + +def sanitize_json_string(s): + escape_dict = { + "\n": "\\n", + "\r": "\\r", + "\t": "\\t", + "\b": "\\b", + "\f": "\\f", + } + for char, escaped in escape_dict.items(): + s = s.replace(char, escaped) + + return s + + +class ComfyUIWorkflowTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + comfyui = ComfyUiClient(self.runtime.credentials["base_url"]) + + positive_prompt = tool_parameters.get("positive_prompt", "") + negative_prompt = tool_parameters.get("negative_prompt", "") + images = tool_parameters.get("images") or [] + workflow = tool_parameters.get("workflow_json") + image_names = [] + for image in images: + if image.type != FileType.IMAGE: + continue + image_name = comfyui.upload_image(image).get("name") + image_names.append(image_name) + + set_prompt_with_ksampler = True + if "{{positive_prompt}}" in workflow: + set_prompt_with_ksampler = False + workflow = workflow.replace("{{positive_prompt}}", positive_prompt.replace('"', "'")) + workflow = workflow.replace("{{negative_prompt}}", negative_prompt.replace('"', "'")) + + try: + prompt = json.loads(workflow) + except json.JSONDecodeError: + cleaned_string = sanitize_json_string(workflow) + try: + prompt = json.loads(cleaned_string) + except: + return self.create_text_message("the Workflow JSON is not correct") + + if set_prompt_with_ksampler: + try: + prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt) + except: + raise ToolParameterValidationError( + "Failed set prompt with KSampler, try replace prompt to {{positive_prompt}} in your workflow json" + ) + + if image_names: + if image_ids := tool_parameters.get("image_ids"): + image_ids = image_ids.split(",") + try: + prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids) + except: + raise ToolParameterValidationError("the Image Node ID List not match your upload image files.") + else: + prompt = comfyui.set_prompt_images_by_default(prompt, image_names) + + if seed_id := tool_parameters.get("seed_id"): + prompt = comfyui.set_prompt_seed_by_id(prompt, seed_id) + + images = comfyui.generate_image_by_prompt(prompt) + result = [] + for img in images: + result.append( + self.create_blob_message( + blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml new file mode 100644 index 00000000000000..9428acbe943642 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml @@ -0,0 +1,63 @@ +identity: + name: workflow + author: hjlarry + label: + en_US: workflow + zh_Hans: 工作流 +description: + human: + en_US: Run ComfyUI workflow. + zh_Hans: 运行ComfyUI工作流。 + llm: Run ComfyUI workflow. 
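The {{positive_prompt}} path in _invoke above is easiest to see with a concrete fragment. This is a hedged sketch, not part of the patch; the node ids and prompt texts are made up. The tool performs a plain string replacement on the JSON text before json.loads, which is also why double quotes in the prompts are first swapped for single quotes:

import json

workflow_json = """
{
  "6": {"inputs": {"text": "{{positive_prompt}}", "clip": ["4", 1]}, "class_type": "CLIPTextEncode"},
  "7": {"inputs": {"text": "{{negative_prompt}}", "clip": ["4", 1]}, "class_type": "CLIPTextEncode"}
}
"""
positive = 'a "galaxy" bottle'.replace('"', "'")  # keep the JSON parseable
workflow_json = workflow_json.replace("{{positive_prompt}}", positive)
workflow_json = workflow_json.replace("{{negative_prompt}}", "text, watermark")
prompt = json.loads(workflow_json)  # now a ready-to-queue graph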
diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml new file mode 100644 index 00000000000000..9428acbe943642 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml @@ -0,0 +1,63 @@ +identity: + name: workflow + author: hjlarry + label: + en_US: workflow + zh_Hans: 工作流 +description: + human: + en_US: Run ComfyUI workflow. + zh_Hans: 运行ComfyUI工作流。 + llm: Run ComfyUI workflow. +parameters: + - name: positive_prompt + type: string + label: + en_US: Prompt + zh_Hans: 提示词 + llm_description: Image prompt, you should describe the image you want to generate as a list of words, as detailed as possible; the prompt must be written in English. + form: llm + - name: negative_prompt + type: string + label: + en_US: Negative Prompt + zh_Hans: 负面提示词 + llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words, as detailed as possible; the prompt must be written in English. + form: llm + - name: images + type: files + label: + en_US: Input Images + zh_Hans: 输入的图片 + llm_description: The input images, passed to the ComfyUI workflow to generate another image. + form: llm + - name: workflow_json + type: string + required: true + label: + en_US: Workflow JSON + human_description: + en_US: exported from ComfyUI workflow + zh_Hans: 从ComfyUI的工作流中导出 + form: form + - name: image_ids + type: string + label: + en_US: Image Node ID List + zh_Hans: 图片节点ID列表 + placeholder: + en_US: Use commas to separate multiple node IDs + zh_Hans: 多个节点ID时使用半角逗号分隔 + human_description: + en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list. + zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI + form: form + - name: seed_id + type: string + label: + en_US: Seed Node ID + zh_Hans: 种子节点ID + human_description: + en_US: If you need to generate different images each time, you need to enter the ID of the seed node. + zh_Hans: 如果需要每次生成时使用不同的种子,需要输入包含种子的节点的ID + form: form diff --git a/api/core/tools/provider/builtin/comfyui/tools/txt2img.json b/api/core/tools/provider/builtin/comfyui/tools/txt2img.json new file mode 100644 index 00000000000000..8ea869ff106c38 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/txt2img.json @@ -0,0 +1,107 @@ +{ + "3": { + "inputs": { + "seed": 156680208700286, + "steps": 20, + "cfg": 8, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1, + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "7", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "ckpt_name": "3dAnimationDiffusion_v10.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "5": { + "inputs": { + "width": 512, + "height": 512, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "6": { + "inputs": { + "text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "7": { + "inputs": { + "text": "text, watermark", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "8": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + } +} \ No newline at end of file
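A note on the graph encoding above, since it is what text2img rewires: in API-format workflow JSON, every input that comes from another node is a two-element array [source_node_id, source_output_index]. A small informal sketch, with values taken from the KSampler node "3" above:

ksampler_inputs = {
    "model": ["4", 0],         # output 0 of node 4, the checkpoint loader
    "positive": ["6", 0],      # output 0 of node 6, the positive CLIPTextEncode
    "negative": ["7", 0],      # output 0 of node 7, the negative CLIPTextEncode
    "latent_image": ["5", 0],  # output 0 of node 5, the latent image node
}
source_node_id, output_index = ksampler_inputs["model"]
assert (source_node_id, output_index) == ("4", 0)

This is why text2img can splice in the Lora chain simply by assigning, for example, draw_options["3"]["inputs"]["model"][0] = "10".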
diff --git a/api/core/tools/provider/builtin/crossref/_assets/icon.svg b/api/core/tools/provider/builtin/crossref/_assets/icon.svg new file mode 100644 index 00000000000000..aa629de7cb1660 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/_assets/icon.svg @@ -0,0 +1,49 @@ + [49 lines of SVG markup omitted; the vector data did not survive extraction] diff --git a/api/core/tools/provider/builtin/crossref/crossref.py b/api/core/tools/provider/builtin/crossref/crossref.py new file mode 100644 index 00000000000000..8ba3c1b48ae6d7 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/crossref.py @@ -0,0 +1,20 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.crossref.tools.query_doi import CrossRefQueryDOITool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class CrossRefProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + CrossRefQueryDOITool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id="", + tool_parameters={ + "doi": "10.1007/s00894-022-05373-8", + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/crossref/crossref.yaml b/api/core/tools/provider/builtin/crossref/crossref.yaml new file mode 100644 index 00000000000000..da67fbec3a480b --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/crossref.yaml @@ -0,0 +1,29 @@ +identity: + author: Sakura4036 + name: crossref + label: + en_US: CrossRef + zh_Hans: CrossRef + description: + en_US: Crossref is a cross-publisher, DOI-based reference-linking registration and query system created in 2000. It establishes cross-database links between a paper's reference list and the full text of the works it cites, making it convenient for readers to access those full texts. + zh_Hans: Crossref是于2000年创建的使用DOI技术的跨出版商参考文献链接注册查询系统。Crossref建立了在论文的参考文献列表和引文全文之间的跨数据库链接,使得读者能够非常便捷地获取文献全文。 + icon: icon.svg + tags: + - search +credentials_for_provider: + mailto: + type: text-input + required: true + label: + en_US: Email address + zh_Hans: email地址 + pt_BR: Email address + placeholder: + en_US: Please input your email address + zh_Hans: 请输入你的email地址 + pt_BR: Please input your email address + help: + en_US: According to the requirements of Crossref, an email address is required + zh_Hans: 根据Crossref的要求,需要提供一个邮箱地址 + pt_BR: According to the requirements of Crossref, an email address is required + url: https://api.crossref.org/swagger-ui/index.html diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.py b/api/core/tools/provider/builtin/crossref/tools/query_doi.py new file mode 100644 index 00000000000000..746139dd69d27b --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.py @@ -0,0 +1,28 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolParameterValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class CrossRefQueryDOITool(BuiltinTool): + """ + Tool for querying the metadata of a publication using its DOI. 
+ """ + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + doi = tool_parameters.get("doi") + if not doi: + raise ToolParameterValidationError("doi is required.") + # doc: https://github.com/CrossRef/rest-api-doc + url = f"https://api.crossref.org/works/{doi}" + response = requests.get(url) + response.raise_for_status() + response = response.json() + message = response.get("message", {}) + + return self.create_json_message(message) diff --git a/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml b/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml new file mode 100644 index 00000000000000..9c16da25edf2b3 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_doi.yaml @@ -0,0 +1,23 @@ +identity: + name: crossref_query_doi + author: Sakura4036 + label: + en_US: CrossRef Query DOI + zh_Hans: CrossRef DOI 查询 + pt_BR: CrossRef Query DOI +description: + human: + en_US: A tool for searching literature information using CrossRef by DOI. + zh_Hans: 一个使用CrossRef通过DOI获取文献信息的工具。 + pt_BR: A tool for searching literature information using CrossRef by DOI. + llm: A tool for searching literature information using CrossRef by DOI. +parameters: + - name: doi + type: string + required: true + label: + en_US: DOI + zh_Hans: DOI + pt_BR: DOI + llm_description: DOI for searching in CrossRef + form: llm diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.py b/api/core/tools/provider/builtin/crossref/tools/query_title.py new file mode 100644 index 00000000000000..e2452381832938 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.py @@ -0,0 +1,143 @@ +import time +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +def convert_time_str_to_seconds(time_str: str) -> int: + """ + Convert a time string to seconds. + example: 1s -> 1, 1m30s -> 90, 1h30m -> 5400, 1h30m30s -> 5430 + """ + time_str = time_str.lower().strip().replace(" ", "") + seconds = 0 + if "h" in time_str: + hours, time_str = time_str.split("h") + seconds += int(hours) * 3600 + if "m" in time_str: + minutes, time_str = time_str.split("m") + seconds += int(minutes) * 60 + if "s" in time_str: + seconds += int(time_str.replace("s", "")) + return seconds + + +class CrossRefQueryTitleAPI: + """ + Tool for querying the metadata of a publication using its title. + Crossref API doc: https://github.com/CrossRef/rest-api-doc + """ + + query_url_template: str = "https://api.crossref.org/works?query.bibliographic={query}&rows={rows}&offset={offset}&sort={sort}&order={order}&mailto={mailto}" + rate_limit: int = 50 + rate_interval: float = 1 + max_limit: int = 1000 + + def __init__(self, mailto: str): + self.mailto = mailto + + def _query( + self, + query: str, + rows: int = 5, + offset: int = 0, + sort: str = "relevance", + order: str = "desc", + fuzzy_query: bool = False, + ) -> list[dict]: + """ + Query the metadata of a publication using its title. 
+ :param query: the title of the publication + :param rows: the number of results to return + :param offset: the number of results to skip + :param sort: the sort field + :param order: the sort order + :param fuzzy_query: whether to return all items that match the query + """ + url = self.query_url_template.format( + query=query, rows=rows, offset=offset, sort=sort, order=order, mailto=self.mailto + ) + response = requests.get(url) + response.raise_for_status() + rate_limit = int(response.headers["x-ratelimit-limit"]) + # convert time string to seconds + rate_interval = convert_time_str_to_seconds(response.headers["x-ratelimit-interval"]) + + self.rate_limit = rate_limit + self.rate_interval = rate_interval + + response = response.json() + if response["status"] != "ok": + return [] + + message = response["message"] + if fuzzy_query: + # fuzzy query: return all matching items + return message["items"] + else: + for paper in message["items"]: + title = paper["title"][0] + if title.lower() != query.lower(): + continue + return [paper] + return [] + + def query( + self, query: str, rows: int = 5, sort: str = "relevance", order: str = "desc", fuzzy_query: bool = False + ) -> list[dict]: + """ + Query the metadata of a publication using its title. + :param query: the title of the publication + :param rows: the number of results to return + :param sort: the sort field + :param order: the sort order + :param fuzzy_query: whether to return all items that match the query + """ + rows = min(rows, self.max_limit) + if rows > self.rate_limit: + # split into multiple requests to respect the rate limit (ceiling division) + query_times = (rows + self.rate_limit - 1) // self.rate_limit + results = [] + + for i in range(query_times): + result = self._query( + query, + rows=self.rate_limit, + offset=i * self.rate_limit, + sort=sort, + order=order, + fuzzy_query=fuzzy_query, + ) + if fuzzy_query: + results.extend(result) + else: + # fuzzy_query=False: at most one exact match, return as soon as it is found + if result: + return result + time.sleep(self.rate_interval) + return results + else: + # a single request is enough + return self._query(query, rows, sort=sort, order=order, fuzzy_query=fuzzy_query) + + +class CrossRefQueryTitleTool(BuiltinTool): + """ + Tool for querying the metadata of a publication using its title. + """ + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get("query") + fuzzy_query = tool_parameters.get("fuzzy_query", False) + rows = tool_parameters.get("limit", 3) + sort = tool_parameters.get("sort", "relevance") + order = tool_parameters.get("order", "desc") + mailto = self.runtime.credentials["mailto"] + + result = CrossRefQueryTitleAPI(mailto).query(query, rows, sort, order, fuzzy_query) + + return [self.create_json_message(r) for r in result]
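To make the batching behaviour above concrete, here is a hedged standalone usage sketch (not part of the patch; the email address and paper titles are placeholders):

api = CrossRefQueryTitleAPI(mailto="you@example.com")  # Crossref asks for a contact address

# exact-title lookup: returns at most one record whose title matches exactly
exact = api.query("Attention Is All You Need", fuzzy_query=False)

# fuzzy lookup for 120 records: with rate_limit = 50 this becomes
# ceil(120 / 50) = 3 requests of up to 50 rows each, separated by
# rate_interval seconds as reported in Crossref's x-ratelimit headers
fuzzy = api.query("transformer language models", rows=120, fuzzy_query=True)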
diff --git a/api/core/tools/provider/builtin/crossref/tools/query_title.yaml b/api/core/tools/provider/builtin/crossref/tools/query_title.yaml new file mode 100644 index 00000000000000..5579c77f5293d3 --- /dev/null +++ b/api/core/tools/provider/builtin/crossref/tools/query_title.yaml @@ -0,0 +1,105 @@ +identity: + name: crossref_query_title + author: Sakura4036 + label: + en_US: CrossRef Title Query + zh_Hans: CrossRef 标题查询 + pt_BR: CrossRef Title Query +description: + human: + en_US: A tool for querying literature information using CrossRef by title. + zh_Hans: 一个使用CrossRef通过标题搜索文献信息的工具。 + pt_BR: A tool for querying literature information using CrossRef by title. + llm: A tool for querying literature information using CrossRef by title. +parameters: + - name: query + type: string + required: true + label: + en_US: Title + zh_Hans: 查询语句 + pt_BR: Title + human_description: + en_US: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years + zh_Hans: 用于搜索文献信息,有助于查找引用。包括标题,作者,ISSN和出版年份 + pt_BR: Query bibliographic information, useful for citation look up. Includes titles, authors, ISSNs and publication years + llm_description: keywords for querying literature by title in CrossRef + form: llm + - name: fuzzy_query + type: boolean + default: false + label: + en_US: Whether to fuzzy search + zh_Hans: 是否模糊搜索 + pt_BR: Whether to fuzzy search + human_description: + en_US: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none + zh_Hans: 用于选择搜索类型,模糊搜索返回更多结果,精确搜索返回1条结果或无 + pt_BR: used for selecting the query type, fuzzy query returns more results, precise query returns 1 or none + form: form + - name: limit + type: number + required: false + label: + en_US: Max query number + zh_Hans: 最大搜索数 + pt_BR: Max query number + human_description: + en_US: Max query number (for fuzzy search, the maximum number of results to return; for precise search, the maximum number of matches to check) + zh_Hans: 最大搜索数(模糊搜索返回的最大结果数或精确搜索最大匹配数) + pt_BR: Max query number (for fuzzy search, the maximum number of results to return; for precise search, the maximum number of matches to check) + form: llm + default: 50 + - name: sort + type: select + required: true + options: + - value: relevance + label: + en_US: relevance + zh_Hans: 相关性 + pt_BR: relevance + - value: published + label: + en_US: publication date + zh_Hans: 出版日期 + pt_BR: publication date + - value: references-count + label: + en_US: references-count + zh_Hans: 引用次数 + pt_BR: references-count + default: relevance + label: + en_US: Sorting field + zh_Hans: 排序字段 + pt_BR: Sorting field + human_description: + en_US: Sorting of query results + zh_Hans: 检索结果的排序字段 + pt_BR: Sorting of query results + form: form + - name: order + type: select + required: true + options: + - value: desc + label: + en_US: descending + zh_Hans: 降序 + pt_BR: descending + - value: asc + label: + en_US: ascending + zh_Hans: 升序 + pt_BR: ascending + default: desc + label: + en_US: Order + zh_Hans: 排序 + pt_BR: Order + human_description: + en_US: Order of query results + zh_Hans: 检索结果的排序方式 + pt_BR: Order of query results + form: form diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py index 1c8019364de9d2..5bd16e49e85e29 100644 --- a/api/core/tools/provider/builtin/dalle/dalle.py +++ b/api/core/tools/provider/builtin/dalle/dalle.py @@ -13,13 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "small", - "n": 1 - }, + user_id="", + tool_parameters={"prompt": "cute girl, blue eyes, white hair, anime style", "size": "small", "n": 1}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/dalle/dalle.yaml b/api/core/tools/provider/builtin/dalle/dalle.yaml index f09a9177f2225b..37cf93c28aae58 100644 --- a/api/core/tools/provider/builtin/dalle/dalle.yaml +++ b/api/core/tools/provider/builtin/dalle/dalle.yaml @@ -29,7 +29,7 @@ credentials_for_provider: en_US: Please input your OpenAI API key zh_Hans: 请输入你的 OpenAI API key pt_BR: Please input your OpenAI API key - 
openai_organizaion_id: + openai_organization_id: type: text-input required: false label: diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index 450e78228135b5..fbd7397292155e 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -9,59 +9,58 @@ class DallE2Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - openai_organization = self.runtime.credentials.get('openai_organizaion_id', None) + openai_organization = self.runtime.credentials.get("openai_organization_id", None) if not openai_organization: openai_organization = None - openai_base_url = self.runtime.credentials.get('openai_base_url', None) + openai_base_url = self.runtime.credentials.get("openai_base_url", None) if not openai_base_url: openai_base_url = None else: - openai_base_url = str(URL(openai_base_url) / 'v1') + openai_base_url = str(URL(openai_base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['openai_api_key'], + api_key=self.runtime.credentials["openai_api_key"], base_url=openai_base_url, - organization=openai_organization + organization=openai_organization, ) SIZE_MAPPING = { - 'small': '256x256', - 'medium': '512x512', - 'large': '1024x1024', + "small": "256x256", + "medium": "512x512", + "large": "1024x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') - + return self.create_text_message("Please input prompt") + # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'large')] + size = SIZE_MAPPING[tool_parameters.get("size", "large")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # call openapi dalle2 - response = client.images.generate( - prompt=prompt, - model='dall-e-2', - size=size, - n=n, - response_format='b64_json' - ) + response = client.images.generate(prompt=prompt, model="dall-e-2", size=size, n=n, response_format="b64_json") result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value)) + result.append( + self.create_blob_message( + blob=b64decode(image.b64_json), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ) + ) return result diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml index 90c73ecc57bb7b..e43e5df8cddd9b 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.yaml @@ -24,7 +24,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 2 - zh_Hans: 图像提示词,您可以查看DallE 2 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 2 的官方文档 pt_BR: Image prompt, you can check the official documentation of DallE 2 llm_description: Image prompt of DallE 2, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py 
index f985deade55aba..af9aa6abb4bc3d 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -10,69 +10,64 @@ class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - openai_organization = self.runtime.credentials.get('openai_organizaion_id', None) + openai_organization = self.runtime.credentials.get("openai_organization_id", None) if not openai_organization: openai_organization = None - openai_base_url = self.runtime.credentials.get('openai_base_url', None) + openai_base_url = self.runtime.credentials.get("openai_base_url", None) if not openai_base_url: openai_base_url = None else: - openai_base_url = str(URL(openai_base_url) / 'v1') + openai_base_url = str(URL(openai_base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['openai_api_key'], + api_key=self.runtime.credentials["openai_api_key"], base_url=openai_base_url, - organization=openai_organization + organization=openai_organization, ) SIZE_MAPPING = { - 'square': '1024x1024', - 'vertical': '1024x1792', - 'horizontal': '1792x1024', + "square": "1024x1024", + "vertical": "1024x1792", + "horizontal": "1792x1024", } # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') + return self.create_text_message("Please input prompt") # get size - size = SIZE_MAPPING[tool_parameters.get('size', 'square')] + size = SIZE_MAPPING[tool_parameters.get("size", "square")] # get n - n = tool_parameters.get('n', 1) + n = tool_parameters.get("n", 1) # get quality - quality = tool_parameters.get('quality', 'standard') - if quality not in ['standard', 'hd']: - return self.create_text_message('Invalid quality') + quality = tool_parameters.get("quality", "standard") + if quality not in {"standard", "hd"}: + return self.create_text_message("Invalid quality") # get style - style = tool_parameters.get('style', 'vivid') - if style not in ['natural', 'vivid']: - return self.create_text_message('Invalid style') + style = tool_parameters.get("style", "vivid") + if style not in {"natural", "vivid"}: + return self.create_text_message("Invalid style") # call openapi dalle3 response = client.images.generate( - prompt=prompt, - model='dall-e-3', - size=size, - n=n, - style=style, - quality=quality, - response_format='b64_json' + prompt=prompt, model="dall-e-3", size=size, n=n, style=style, quality=quality, response_format="b64_json" ) result = [] for image in response.data: mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) - blob_message = self.create_blob_message(blob=blob_image, - meta={'mime_type': mime_type}, - save_as=self.VARIABLE_KEY.IMAGE.value) + blob_message = self.create_blob_message( + blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE + ) result.append(blob_message) return result @@ -86,7 +81,7 @@ def _decode_image(base64_image: str) -> tuple[str, bytes]: :return: A tuple containing the MIME type and the decoded image bytes """ if DallE3Tool._is_plain_base64(base64_image): - return 'image/png', base64.b64decode(base64_image) + return "image/png", base64.b64decode(base64_image) else: return DallE3Tool._extract_mime_and_data(base64_image) @@ -98,7 
+93,7 @@ def _is_plain_base64(encoded_str: str) -> bool: :param encoded_str: Base64 encoded image string :return: True if the string is plain base64, False otherwise """ - return not encoded_str.startswith('data:image') + return not encoded_str.startswith("data:image") @staticmethod def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]: @@ -108,13 +103,13 @@ def _extract_mime_and_data(encoded_str: str) -> tuple[str, bytes]: :param encoded_str: Base64 encoded image string with MIME type prefix :return: A tuple containing the MIME type and the decoded image bytes """ - mime_type = encoded_str.split(';')[0].split(':')[1] - image_data_base64 = encoded_str.split(',')[1] + mime_type = encoded_str.split(";")[0].split(":")[1] + image_data_base64 = encoded_str.split(",")[1] decoded_data = base64.b64decode(image_data_base64) return mime_type, decoded_data @staticmethod def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = ''.join(random.choices(characters, k=length)) + characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + random_id = "".join(random.choices(characters, k=length)) return random_id diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml index 7ba5c56889c7e4..0cea8af761e1e5 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.yaml @@ -25,7 +25,7 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of DallE 3 - zh_Hans: 图像提示词,您可以查看DallE 3 的官方文档 + zh_Hans: 图像提示词,您可以查看 DallE 3 的官方文档 pt_BR: Image prompt, you can check the official documentation of DallE 3 llm_description: Image prompt of DallE 3, you should describe the image you want to generate as a list of words as possible as detailed form: llm diff --git a/api/core/tools/provider/builtin/devdocs/devdocs.py b/api/core/tools/provider/builtin/devdocs/devdocs.py index 95d7939d0d9539..446c1e548935c0 100644 --- a/api/core/tools/provider/builtin/devdocs/devdocs.py +++ b/api/core/tools/provider/builtin/devdocs/devdocs.py @@ -11,7 +11,7 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "doc": "python~3.12", "topic": "library/code", @@ -19,4 +19,3 @@ def _validate_credentials(self, credentials: dict) -> None: ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py index 1a244c5db3f69a..57cf6d7a308dba 100644 --- a/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py +++ b/api/core/tools/provider/builtin/devdocs/tools/searchDevDocs.py @@ -13,7 +13,9 @@ class SearchDevDocsInput(BaseModel): class SearchDevDocsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invokes the DevDocs search tool with the given user ID and tool parameters. @@ -22,15 +24,16 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn tool_parameters (dict[str, Any]): The parameters for the tool, including 'doc' and 'topic'. 
Returns: - ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. + ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, + which can be a single message or a list of messages. """ - doc = tool_parameters.get('doc', '') - topic = tool_parameters.get('topic', '') + doc = tool_parameters.get("doc", "") + topic = tool_parameters.get("topic", "") if not doc: - return self.create_text_message('Please provide the documentation name.') + return self.create_text_message("Please provide the documentation name.") if not topic: - return self.create_text_message('Please provide the topic path.') + return self.create_text_message("Please provide the topic path.") url = f"https://documents.devdocs.io/{doc}/{topic}.html" response = requests.get(url) @@ -39,4 +42,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn content = response.text return self.create_text_message(self.summary(user_id=user_id, content=content)) else: - return self.create_text_message(f"Failed to retrieve the documentation. Status code: {response.status_code}") \ No newline at end of file + return self.create_text_message( + f"Failed to retrieve the documentation. Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/did/did.py b/api/core/tools/provider/builtin/did/did.py index b4bf172131448d..5af78794f625b7 100644 --- a/api/core/tools/provider/builtin/did/did.py +++ b/api/core/tools/provider/builtin/did/did.py @@ -7,15 +7,12 @@ class DIDProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the D-ID talks tool - TalksTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + TalksTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "source_url": "https://www.d-id.com/wp-content/uploads/2023/11/Hero-image-1.png", "text_input": "Hello, welcome to use D-ID tool in Dify", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/did/did_appx.py b/api/core/tools/provider/builtin/did/did_appx.py index 964e82b729319e..c68878630d67b0 100644 --- a/api/core/tools/provider/builtin/did/did_appx.py +++ b/api/core/tools/provider/builtin/did/did_appx.py @@ -12,14 +12,14 @@ class DIDApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.d-id.com' + self.base_url = base_url or "https://api.d-id.com" if not self.api_key: - raise ValueError('API key is required') + raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = {'Content-Type': 'application/json', 'Authorization': f'Basic {self.api_key}'} + headers = {"Content-Type": "application/json", "Authorization": f"Basic {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( @@ -44,44 +44,44 @@ def _request( return None def talks(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/talks' + endpoint = f"{self.base_url}/talks" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} 
body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID talks after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(id=id, target='talks', poll_interval=poll_interval) + return self._monitor_job_status(id=id, target="talks", poll_interval=poll_interval) return id def animations(self, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs): - endpoint = f'{self.base_url}/animations' + endpoint = f"{self.base_url}/animations" headers = self._prepare_headers(idempotency_key) - data = kwargs['params'] - logger.debug(f'Send request to {endpoint=} body={data}') - response = self._request('POST', endpoint, data, headers) + data = kwargs["params"] + logger.debug(f"Send request to {endpoint=} body={data}") + response = self._request("POST", endpoint, data, headers) if response is None: - raise HTTPError('Failed to initiate D-ID talks after multiple retries') - id: str = response['id'] + raise HTTPError("Failed to initiate D-ID animations after multiple retries") + id: str = response["id"] if wait: - return self._monitor_job_status(target='animations', id=id, poll_interval=poll_interval) + return self._monitor_job_status(target="animations", id=id, poll_interval=poll_interval) return id def check_did_status(self, target: str, id: str): - endpoint = f'{self.base_url}/{target}/{id}' + endpoint = f"{self.base_url}/{target}/{id}" headers = self._prepare_headers() - response = self._request('GET', endpoint, headers=headers) + response = self._request("GET", endpoint, headers=headers) if response is None: - raise HTTPError(f'Failed to check status for talks {id} after multiple retries') + raise HTTPError(f"Failed to check status for {target} {id} after multiple retries") return response def _monitor_job_status(self, target: str, id: str, poll_interval: int): while True: status = self.check_did_status(target=target, id=id) - if status['status'] == 'done': + if status["status"] == "done": return status - elif status['status'] == 'error' or status['status'] == 'rejected': - raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error",{}).get("description")}') + elif status["status"] == "error" or status["status"] == "rejected": + raise HTTPError(f'{target} {id} failed: {status["status"]} {status.get("error", {}).get("description")}') time.sleep(poll_interval) diff --git a/api/core/tools/provider/builtin/did/tools/animations.py b/api/core/tools/provider/builtin/did/tools/animations.py index e1d9de603fbb7a..bc9d17e40d2878 100644 --- a/api/core/tools/provider/builtin/did/tools/animations.py +++ b/api/core/tools/provider/builtin/did/tools/animations.py @@ -10,33 +10,33 @@ class AnimationsTool(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"]) - driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions_str = tool_parameters.get("driver_expressions") driver_expressions = 
json.loads(driver_expressions_str) if driver_expressions_str else None config = { - 'stitch': tool_parameters.get('stitch', True), - 'mute': tool_parameters.get('mute'), - 'result_format': tool_parameters.get('result_format') or 'mp4', + "stitch": tool_parameters.get("stitch", True), + "mute": tool_parameters.get("mute"), + "result_format": tool_parameters.get("result_format") or "mp4", } - config = {k: v for k, v in config.items() if v is not None and v != ''} + config = {k: v for k, v in config.items() if v is not None and v != ""} options = { - 'source_url': tool_parameters['source_url'], - 'driver_url': tool_parameters.get('driver_url'), - 'config': config, + "source_url": tool_parameters["source_url"], + "driver_url": tool_parameters.get("driver_url"), + "config": config, } - options = {k: v for k, v in options.items() if v is not None and v != ''} + options = {k: v for k, v in options.items() if v is not None and v != ""} - if not options.get('source_url'): - raise ValueError('Source URL is required') + if not options.get("source_url"): + raise ValueError("Source URL is required") - if config.get('logo_url'): - if not config.get('logo_x'): - raise ValueError('Logo X position is required when logo URL is provided') - if not config.get('logo_y'): - raise ValueError('Logo Y position is required when logo URL is provided') + if config.get("logo_url"): + if not config.get("logo_x"): + raise ValueError("Logo X position is required when logo URL is provided") + if not config.get("logo_y"): + raise ValueError("Logo Y position is required when logo URL is provided") animations_result = app.animations(params=options, wait=True) @@ -44,6 +44,6 @@ def _invoke( animations_result = json.dumps(animations_result, ensure_ascii=False, indent=4) if not animations_result: - return self.create_text_message('D-ID animations request failed.') + return self.create_text_message("D-ID animations request failed.") return self.create_text_message(animations_result) diff --git a/api/core/tools/provider/builtin/did/tools/talks.py b/api/core/tools/provider/builtin/did/tools/talks.py index 06b2c4cb2f6049..d6f0c7ff179793 100644 --- a/api/core/tools/provider/builtin/did/tools/talks.py +++ b/api/core/tools/provider/builtin/did/tools/talks.py @@ -10,49 +10,49 @@ class TalksTool(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = DIDApp(api_key=self.runtime.credentials['did_api_key'], base_url=self.runtime.credentials['base_url']) + app = DIDApp(api_key=self.runtime.credentials["did_api_key"], base_url=self.runtime.credentials["base_url"]) - driver_expressions_str = tool_parameters.get('driver_expressions') + driver_expressions_str = tool_parameters.get("driver_expressions") driver_expressions = json.loads(driver_expressions_str) if driver_expressions_str else None script = { - 'type': tool_parameters.get('script_type') or 'text', - 'input': tool_parameters.get('text_input'), - 'audio_url': tool_parameters.get('audio_url'), - 'reduce_noise': tool_parameters.get('audio_reduce_noise', False), + "type": tool_parameters.get("script_type") or "text", + "input": tool_parameters.get("text_input"), + "audio_url": tool_parameters.get("audio_url"), + "reduce_noise": tool_parameters.get("audio_reduce_noise", False), } - script = {k: v for k, v in script.items() if v is not None and v != ''} + script = {k: v for k, v in script.items() if v is not None and v != ""} config = { - 'stitch': tool_parameters.get('stitch', True), - 'sharpen': 
tool_parameters.get('sharpen'), - 'fluent': tool_parameters.get('fluent'), - 'result_format': tool_parameters.get('result_format') or 'mp4', - 'pad_audio': tool_parameters.get('pad_audio'), - 'driver_expressions': driver_expressions, + "stitch": tool_parameters.get("stitch", True), + "sharpen": tool_parameters.get("sharpen"), + "fluent": tool_parameters.get("fluent"), + "result_format": tool_parameters.get("result_format") or "mp4", + "pad_audio": tool_parameters.get("pad_audio"), + "driver_expressions": driver_expressions, } - config = {k: v for k, v in config.items() if v is not None and v != ''} + config = {k: v for k, v in config.items() if v is not None and v != ""} options = { - 'source_url': tool_parameters['source_url'], - 'driver_url': tool_parameters.get('driver_url'), - 'script': script, - 'config': config, + "source_url": tool_parameters["source_url"], + "driver_url": tool_parameters.get("driver_url"), + "script": script, + "config": config, } - options = {k: v for k, v in options.items() if v is not None and v != ''} + options = {k: v for k, v in options.items() if v is not None and v != ""} - if not options.get('source_url'): - raise ValueError('Source URL is required') + if not options.get("source_url"): + raise ValueError("Source URL is required") - if script.get('type') == 'audio': - script.pop('input', None) - if not script.get('audio_url'): - raise ValueError('Audio URL is required for audio script type') + if script.get("type") == "audio": + script.pop("input", None) + if not script.get("audio_url"): + raise ValueError("Audio URL is required for audio script type") - if script.get('type') == 'text': - script.pop('audio_url', None) - script.pop('reduce_noise', None) - if not script.get('input'): - raise ValueError('Text input is required for text script type') + if script.get("type") == "text": + script.pop("audio_url", None) + script.pop("reduce_noise", None) + if not script.get("input"): + raise ValueError("Text input is required for text script type") talks_result = app.talks(params=options, wait=True) @@ -60,6 +60,6 @@ def _invoke( talks_result = json.dumps(talks_result, ensure_ascii=False, indent=4) if not talks_result: - return self.create_text_message('D-ID talks request failed.') + return self.create_text_message("D-ID talks request failed.") return self.create_text_message(talks_result) diff --git a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py index c247c3bd6bcff0..f33ad5be59b403 100644 --- a/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py +++ b/api/core/tools/provider/builtin/dingtalk/tools/dingtalk_group_bot.py @@ -13,38 +13,43 @@ class DingTalkGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - Dingtalk custom group robot API docs: - https://open.dingtalk.com/document/orgapp/custom-robot-access + invoke tools + Dingtalk custom group robot API docs: + https://open.dingtalk.com/document/orgapp/custom-robot-access """ - content = tool_parameters.get('content') + content = tool_parameters.get("content") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - access_token = tool_parameters.get('access_token') + 
access_token = tool_parameters.get("access_token") if not access_token: - return self.create_text_message('Invalid parameter access_token. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter access_token. " + "Regarding information about security details, " + "please refer to the DingTalk docs: " + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - sign_secret = tool_parameters.get('sign_secret') + sign_secret = tool_parameters.get("sign_secret") if not sign_secret: - return self.create_text_message('Invalid parameter sign_secret. ' - 'Regarding information about security details,' - 'please refer to the DingTalk docs:' - 'https://open.dingtalk.com/document/robots/customize-robot-security-settings') + return self.create_text_message( + "Invalid parameter sign_secret. " + "Regarding information about security details, " + "please refer to the DingTalk docs: " + "https://open.dingtalk.com/document/robots/customize-robot-security-settings" + ) - msgtype = 'text' - api_url = 'https://oapi.dingtalk.com/robot/send' + msgtype = "text" + api_url = "https://oapi.dingtalk.com/robot/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'access_token': access_token, + "access_token": access_token, } self._apply_security_mechanism(params, sign_secret) @@ -53,7 +58,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] "msgtype": msgtype, "text": { "content": content, - } + }, } try: @@ -62,7 +67,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) @@ -70,14 +76,14 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] def _apply_security_mechanism(params: dict[str, Any], sign_secret: str): try: timestamp = str(round(time.time() * 1000)) - secret_enc = sign_secret.encode('utf-8') - string_to_sign = f'{timestamp}\n{sign_secret}' - string_to_sign_enc = string_to_sign.encode('utf-8') + secret_enc = sign_secret.encode("utf-8") + string_to_sign = f"{timestamp}\n{sign_secret}" + string_to_sign_enc = string_to_sign.encode("utf-8") hmac_code = hmac.new(secret_enc, string_to_sign_enc, digestmod=hashlib.sha256).digest() sign = urllib.parse.quote_plus(base64.b64encode(hmac_code)) - params['timestamp'] = timestamp - params['sign'] = sign + params["timestamp"] = timestamp + params["sign"] = sign except Exception: msg = "Failed to apply security mechanism to the request." logging.exception(msg)
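The signing scheme used by _apply_security_mechanism follows DingTalk's documented recipe: HMAC-SHA256 over "timestamp + newline + secret", base64-encoded and then URL-encoded. A standalone sketch, with a made-up secret:

import base64
import hashlib
import hmac
import time
import urllib.parse

sign_secret = "SECxxxxxxxxxxxxxxxx"  # hypothetical webhook signing secret
timestamp = str(round(time.time() * 1000))
string_to_sign_enc = f"{timestamp}\n{sign_secret}".encode("utf-8")
hmac_code = hmac.new(sign_secret.encode("utf-8"), string_to_sign_enc, digestmod=hashlib.sha256).digest()
sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
# the robot endpoint is then called with ?access_token=...&timestamp=<timestamp>&sign=<sign>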
logging.exception(msg) diff --git a/api/core/tools/provider/builtin/discord/_assets/icon.svg b/api/core/tools/provider/builtin/discord/_assets/icon.svg new file mode 100644 index 00000000000000..177a0591f9cb08 --- /dev/null +++ b/api/core/tools/provider/builtin/discord/_assets/icon.svg @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/discord/discord.py b/api/core/tools/provider/builtin/discord/discord.py new file mode 100644 index 00000000000000..c94824b591cd95 --- /dev/null +++ b/api/core/tools/provider/builtin/discord/discord.py @@ -0,0 +1,9 @@ +from typing import Any + +from core.tools.provider.builtin.discord.tools.discord_webhook import DiscordWebhookTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class DiscordProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + DiscordWebhookTool() diff --git a/api/core/tools/provider/builtin/discord/discord.yaml b/api/core/tools/provider/builtin/discord/discord.yaml new file mode 100644 index 00000000000000..18b249b5229a0e --- /dev/null +++ b/api/core/tools/provider/builtin/discord/discord.yaml @@ -0,0 +1,16 @@ +identity: + author: Ice Yao + name: discord + label: + en_US: Discord + zh_Hans: Discord + pt_BR: Discord + description: + en_US: Discord Webhook + zh_Hans: Discord Webhook + pt_BR: Discord Webhook + icon: icon.svg + tags: + - social + - productivity +credentials_for_provider: diff --git a/api/core/tools/provider/builtin/discord/tools/discord_webhook.py b/api/core/tools/provider/builtin/discord/tools/discord_webhook.py new file mode 100644 index 00000000000000..c1834a1a265be2 --- /dev/null +++ b/api/core/tools/provider/builtin/discord/tools/discord_webhook.py @@ -0,0 +1,49 @@ +from typing import Any, Union + +import httpx + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class DiscordWebhookTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Incoming Webhooks + API Document: + https://discord.com/developers/docs/resources/webhook#execute-webhook + """ + + content = tool_parameters.get("content", "") + if not content: + return self.create_text_message("Invalid parameter content") + + webhook_url = tool_parameters.get("webhook_url", "") + if not webhook_url.startswith("https://discord.com/api/webhooks/"): + return self.create_text_message( + f"Invalid parameter webhook_url {webhook_url}, \ + not a valid Discord webhook URL" + ) + + headers = { + "Content-Type": "application/json", + } + payload = { + "username": tool_parameters.get("username") or user_id, + "content": content, + "avatar_url": tool_parameters.get("avatar_url") or None, + } + + try: + res = httpx.post(webhook_url, headers=headers, json=payload) + if res.is_success: + return self.create_text_message("Text message was sent successfully") + else: + return self.create_text_message( + f"Failed to send the text message, \ + status code: {res.status_code}, response: {res.text}" + ) + except Exception as e: + return self.create_text_message("Failed to send message through webhook. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/discord/tools/discord_webhook.yaml b/api/core/tools/provider/builtin/discord/tools/discord_webhook.yaml new file mode 100644 index 00000000000000..6847b973cabd19 --- /dev/null +++ b/api/core/tools/provider/builtin/discord/tools/discord_webhook.yaml @@ -0,0 +1,65 @@ +identity: + name: discord_webhook + author: Ice Yao + label: + en_US: Incoming Webhook to send message + zh_Hans: 通过入站Webhook发送消息 + pt_BR: Incoming Webhook to send message + icon: icon.svg +description: + human: + en_US: Sending a message on Discord via the Incoming Webhook + zh_Hans: 通过入站Webhook在Discord上发送消息 + pt_BR: Sending a message on Discord via the Incoming Webhook + llm: A tool for sending messages to a chat on Discord. +parameters: + - name: webhook_url + type: string + required: true + label: + en_US: Discord Incoming Webhook url + zh_Hans: Discord入站Webhook的url + pt_BR: Discord Incoming Webhook url + human_description: + en_US: Discord Incoming Webhook url + zh_Hans: Discord入站Webhook的url + pt_BR: Discord Incoming Webhook url + form: form + - name: content + type: string + required: true + label: + en_US: content + zh_Hans: 消息内容 + pt_BR: content + human_description: + en_US: Content to sent to the channel or person. + zh_Hans: 消息内容文本 + pt_BR: Content to sent to the channel or person. + llm_description: Content of the message + form: llm + - name: username + type: string + required: false + label: + en_US: Discord Webhook Username + zh_Hans: Discord Webhook用户名 + pt_BR: Discord Webhook Username + human_description: + en_US: Discord Webhook Username + zh_Hans: Discord Webhook用户名 + pt_BR: Discord Webhook Username + llm_description: Discord Webhook Username + form: llm + - name: avatar_url + type: string + required: false + label: + en_US: Discord Webhook Avatar + zh_Hans: Discord Webhook头像 + pt_BR: Discord Webhook Avatar + human_description: + en_US: Discord Webhook Avatar URL + zh_Hans: Discord Webhook头像地址 + pt_BR: Discord Webhook Avatar URL + form: form diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py index 2292e89fa6ed13..8269167127b8e5 100644 --- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py +++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py index 878b0d86453a2b..8bdd638f4a01d1 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.py @@ -13,8 +13,8 @@ class DuckDuckGoAITool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: query_dict = { - "keywords": tool_parameters.get('query'), - "model": tool_parameters.get('model'), + "keywords": tool_parameters.get("query"), + "model": tool_parameters.get("model"), } response = DDGS().chat(**query_dict) return self.create_text_message(text=response) diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.yaml index 21cbae6bd3e002..dd049d3b5a13d2 100644 --- 
a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_ai.yaml @@ -37,7 +37,7 @@ parameters: - value: mixtral-8x7b label: en_US: Mixtral - default: gpt-3.5 + default: gpt-4o-mini label: en_US: Choose Model zh_Hans: 选择模型 diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py index bca53f6b4bd745..54bb38755a5b5c 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py @@ -2,7 +2,6 @@ from duckduckgo_search import DDGS -from core.file.file_obj import FileTransferMethod from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -14,18 +13,15 @@ class DuckDuckGoImageSearchTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: query_dict = { - "keywords": tool_parameters.get('query'), - "timelimit": tool_parameters.get('timelimit'), - "size": tool_parameters.get('size'), - "max_results": tool_parameters.get('max_results'), + "keywords": tool_parameters.get("query"), + "timelimit": tool_parameters.get("timelimit"), + "size": tool_parameters.get("size"), + "max_results": tool_parameters.get("max_results"), } response = DDGS().images(**query_dict) - result = [] + markdown_result = "\n\n" + json_result = [] for res in response: - res['transfer_method'] = FileTransferMethod.REMOTE_URL - msg = ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE_LINK, - message=res.get('image'), - save_as='', - meta=res) - result.append(msg) - return result + markdown_result += f"![{res.get('title') or ''}]({res.get('image') or ''})" + json_result.append(self.create_json_message(res)) + return [self.create_text_message(markdown_result)] + json_result diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py index dfaeb734d8f667..cbd65d2e7756e0 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py @@ -21,10 +21,11 @@ class DuckDuckGoSearchTool(BuiltinTool): """ Tool for performing a search using DuckDuckGo search engine. 
""" + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: - query = tool_parameters.get('query') - max_results = tool_parameters.get('max_results', 5) - require_summary = tool_parameters.get('require_summary', False) + query = tool_parameters.get("query") + max_results = tool_parameters.get("max_results", 5) + require_summary = tool_parameters.get("require_summary", False) response = DDGS().text(query, max_results=max_results) if require_summary: results = "\n".join([res.get("body") for res in response]) @@ -34,7 +35,11 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe def summary_results(self, user_id: str, content: str, query: str) -> str: prompt = SUMMARY_PROMPT.format(query=query, content=content) - summary = self.invoke_model(user_id=user_id, prompt_messages=[ - SystemPromptMessage(content=prompt), - ], stop=[]) + summary = self.invoke_model( + user_id=user_id, + prompt_messages=[ + SystemPromptMessage(content=prompt), + ], + stop=[], + ) return summary.message.content diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py index 9822b37cf0231d..396ce21b183afc 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_translate.py @@ -13,8 +13,8 @@ class DuckDuckGoTranslateTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: query_dict = { - "keywords": tool_parameters.get('query'), - "to": tool_parameters.get('translate_to'), + "keywords": tool_parameters.get("query"), + "to": tool_parameters.get("translate_to"), } - response = DDGS().translate(**query_dict)[0].get('translated', 'Unable to translate!') + response = DDGS().translate(**query_dict)[0].get("translated", "Unable to translate!") return self.create_text_message(text=response) diff --git a/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py b/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py index e8ab02f55ea4ed..e82da8ca534b96 100644 --- a/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py +++ b/api/core/tools/provider/builtin/feishu/tools/feishu_group_bot.py @@ -8,35 +8,35 @@ class FeishuGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools - API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot + invoke tools + API document: https://open.feishu.cn/document/client-docs/bot-v3/add-custom-bot """ url = "https://open.feishu.cn/open-apis/bot/v2/hook" - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - hook_key = tool_parameters.get('hook_key', '') + hook_key = tool_parameters.get("hook_key", "") if not is_valid_uuid(hook_key): - return self.create_text_message( - f'Invalid parameter hook_key ${hook_key}, not a valid UUID') + return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID") - msg_type = 'text' - api_url = f'{url}/{hook_key}' + msg_type = "text" + api_url = f"{url}/{hook_key}" 
headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { "msg_type": msg_type, "content": { "text": content, - } + }, } try: @@ -45,6 +45,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/_assets/icon.png b/api/core/tools/provider/builtin/feishu_base/_assets/icon.png new file mode 100644 index 00000000000000..787427e7218058 Binary files /dev/null and b/api/core/tools/provider/builtin/feishu_base/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/feishu_base/_assets/icon.svg b/api/core/tools/provider/builtin/feishu_base/_assets/icon.svg deleted file mode 100644 index 2663a0f59ee6a4..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/_assets/icon.svg +++ /dev/null @@ -1,47 +0,0 @@ - - - - diff --git a/api/core/tools/provider/builtin/feishu_base/feishu_base.py b/api/core/tools/provider/builtin/feishu_base/feishu_base.py index febb769ff83cc9..f301ec5355d48f 100644 --- a/api/core/tools/provider/builtin/feishu_base/feishu_base.py +++ b/api/core/tools/provider/builtin/feishu_base/feishu_base.py @@ -1,8 +1,7 @@ -from core.tools.provider.builtin.feishu_base.tools.get_tenant_access_token import GetTenantAccessTokenTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import auth class FeishuBaseProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - GetTenantAccessTokenTool() - pass \ No newline at end of file + auth(credentials) diff --git a/api/core/tools/provider/builtin/feishu_base/feishu_base.yaml b/api/core/tools/provider/builtin/feishu_base/feishu_base.yaml index f3dcbb6136b3a3..456dd8c88fc348 100644 --- a/api/core/tools/provider/builtin/feishu_base/feishu_base.yaml +++ b/api/core/tools/provider/builtin/feishu_base/feishu_base.yaml @@ -5,10 +5,32 @@ identity: en_US: Feishu Base zh_Hans: 飞书多维表格 description: - en_US: Feishu Base - zh_Hans: 飞书多维表格 - icon: icon.svg + en_US: | + Feishu Base requires the following permissions: bitable:app. 
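A note on the new credential flow above: _validate_credentials now delegates to auth() from core/tools/utils/feishu_api_utils.py, a module this diff does not include. As a rough sketch under that assumption, such a validator would exchange app_id and app_secret for a tenant_access_token against the same internal endpoint the removed get_tenant_access_token tool called (its deleted source appears further down), and raise when Feishu returns a non-zero code:

import httpx


def auth(credentials: dict) -> None:
    # Hypothetical stand-in for feishu_api_utils.auth; the real helper is not shown in this diff.
    url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
    payload = {
        "app_id": credentials.get("app_id"),
        "app_secret": credentials.get("app_secret"),
    }
    data = httpx.post(url, json=payload, timeout=30).json()
    # Feishu replies with code 0 and a tenant_access_token on success.
    if data.get("code") != 0:
        raise ValueError(f"Feishu credential validation failed: {data.get('msg')}")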
+ zh_Hans: | + 飞书多维表格,需要开通以下权限: bitable:app。 + icon: icon.png tags: - social - productivity credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.larkoffice.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py deleted file mode 100644 index be43b43ce47337..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.py +++ /dev/null @@ -1,52 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class AddBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') - - table_id = tool_parameters.get('table_id', '') - if not table_id: - return self.create_text_message('Invalid parameter table_id') - - fields = tool_parameters.get('fields', '') - if not fields: - return self.create_text_message('Invalid parameter fields') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } - - params = {} - payload = { - "fields": json.loads(fields) - } - - try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to add base record, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to add base record. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.yaml b/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.yaml deleted file mode 100644 index 3ce0154efd69dc..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/add_base_record.yaml +++ /dev/null @@ -1,66 +0,0 @@ -identity: - name: add_base_record - author: Doug Lea - label: - en_US: Add Base Record - zh_Hans: 在多维表格数据表中新增一条记录 -description: - human: - en_US: Add Base Record - zh_Hans: | - 在多维表格数据表中新增一条记录,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-record/create - llm: Add a new record in the multidimensional table data table. 
-parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: fields - type: string - required: true - label: - en_US: fields - zh_Hans: 数据表的列字段内容 - human_description: - en_US: The fields of the Base data table are the columns of the data table. - zh_Hans: | - 要增加一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} - 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 - 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure - llm_description: | - 要增加一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} - 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 - 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_records.py b/api/core/tools/provider/builtin/feishu_base/tools/add_records.py new file mode 100644 index 00000000000000..905f8b78806d05 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_records.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class AddRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + records = tool_parameters.get("records") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.add_records(app_token, table_id, table_name, records, user_id_type) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml new file mode 100644 index 00000000000000..f2a93490dc0c31 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/add_records.yaml @@ -0,0 +1,91 @@ +identity: + name: add_records + author: Doug Lea + label: + en_US: Add Records + zh_Hans: 新增多条记录 +description: + human: + en_US: Add Multiple Records to Multidimensional Table + zh_Hans: 在多维表格数据表中新增多条记录 + llm: A tool for adding multiple 
records to a multidimensional table. (在多维表格数据表中新增多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the data table. Either table_id or table_name must be provided; they cannot both be empty. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the data table. Either table_name or table_id must be provided; they cannot both be empty. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: records + type: string + required: true + label: + en_US: records + zh_Hans: 记录列表 + human_description: + en_US: | + List of records to be added in this request. Example value: [{"multi-line-text":"text content","single_select":"option 1","date":1674206443000}] + For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure). + zh_Hans: | + 本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + llm_description: | + 本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
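For illustration, the new add_records flow can be driven directly with the client these tools construct. The placeholder credentials and tokens below are not real, and this assumes FeishuRequest.add_records accepts the same JSON string the records parameter carries, matching AddRecordsTool's pass-through above:

import json

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("<app_id>", "<app_secret>")  # placeholder credentials
# One object per row; keys must match the table's field names (values here are illustrative).
records = json.dumps(
    [{"多行文本": "文本内容", "单选": "选项 1", "日期": 1674206443000}],
    ensure_ascii=False,
)
# table_name may be None when table_id is supplied; user_id_type defaults to "open_id".
res = client.add_records("<app_token>", "<table_id>", None, records, "open_id")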
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py index 639644e7f0e3ea..f074acc5ff709e 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base.py @@ -1,43 +1,18 @@ -import json -from typing import Any, Union - -import httpx +from typing import Any from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest class CreateBaseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - name = tool_parameters.get('name', '') - folder_token = tool_parameters.get('folder_token', '') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) - params = {} - payload = { - "name": name, - "folder_token": folder_token - } + name = tool_parameters.get("name") + folder_token = tool_parameters.get("folder_token") - try: - res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to create base, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to create base. {}".format(e)) + res = client.create_base(name, folder_token) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base.yaml b/api/core/tools/provider/builtin/feishu_base/tools/create_base.yaml index 76c76a916d4951..3ec91a90e7f0b6 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base.yaml +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_base.yaml @@ -6,32 +6,21 @@ identity: zh_Hans: 创建多维表格 description: human: - en_US: Create base + en_US: Create Multidimensional Table in Specified Directory zh_Hans: 在指定目录下创建多维表格 - llm: A tool for create a multidimensional table in the specified directory. + llm: A tool for creating a multidimensional table in a specified directory. (在指定目录下创建多维表格) parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - name: name type: string required: false label: en_US: name - zh_Hans: name + zh_Hans: 多维表格 App 名字 human_description: - en_US: Base App Name - zh_Hans: 多维表格App名字 - llm_description: Base App Name + en_US: | + Name of the multidimensional table App. 
Example value: "A new multidimensional table". + zh_Hans: 多维表格 App 名字,示例值:"一篇新的多维表格"。 + llm_description: 多维表格 App 名字,示例值:"一篇新的多维表格"。 form: llm - name: folder_token @@ -39,9 +28,15 @@ parameters: required: false label: en_US: folder_token - zh_Hans: 多维表格App归属文件夹 + zh_Hans: 多维表格 App 归属文件夹 human_description: - en_US: Base App home folder. The default is empty, indicating that Base will be created in the cloud space root directory. - zh_Hans: 多维表格App归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。 - llm_description: Base App home folder. The default is empty, indicating that Base will be created in the cloud space root directory. + en_US: | + Folder where the multidimensional table App belongs. Default is empty, meaning the table will be created in the root directory of the cloud space. Example values: Fa3sfoAgDlMZCcdcJy1cDFg8nJc or https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc. + The folder_token must be an existing folder and supports inputting folder token or folder URL. + zh_Hans: | + 多维表格 App 归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。示例值: Fa3sfoAgDlMZCcdcJy1cDFg8nJc 或者 https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc。 + folder_token 必须是已存在的文件夹,支持输入文件夹 token 或者文件夹 URL。 + llm_description: | + 多维表格 App 归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。示例值: Fa3sfoAgDlMZCcdcJy1cDFg8nJc 或者 https://svi136aogf123.feishu.cn/drive/folder/Fa3sfoAgDlMZCcdcJy1cDFg8nJc。 + folder_token 必须是已存在的文件夹,支持输入文件夹 token 或者文件夹 URL。 form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py deleted file mode 100644 index e9062e8730f9ac..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.py +++ /dev/null @@ -1,52 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class CreateBaseTableTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') - - name = tool_parameters.get('name', '') - - fields = tool_parameters.get('fields', '') - if not fields: - return self.create_text_message('Invalid parameter fields') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } - - params = {} - payload = { - "table": { - "name": name, - "fields": json.loads(fields) - } - } - - try: - res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to create base table, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to create base table. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml b/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml deleted file mode 100644 index 48c46bec14f448..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/create_base_table.yaml +++ /dev/null @@ -1,106 +0,0 @@ -identity: - name: create_base_table - author: Doug Lea - label: - en_US: Create Base Table - zh_Hans: 多维表格新增一个数据表 -description: - human: - en_US: Create base table - zh_Hans: | - 多维表格新增一个数据表,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table/create - llm: A tool for add a new data table to the multidimensional table. -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: name - type: string - required: false - label: - en_US: name - zh_Hans: name - human_description: - en_US: Multidimensional table data table name - zh_Hans: 多维表格数据表名称 - llm_description: Multidimensional table data table name - form: llm - - - name: fields - type: string - required: true - label: - en_US: fields - zh_Hans: fields - human_description: - en_US: Initial fields of the data table - zh_Hans: | - 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。 - field_name:字段名; - type: 字段类型;可选值有 - 1:多行文本 - 2:数字 - 3:单选 - 4:多选 - 5:日期 - 7:复选框 - 11:人员 - 13:电话号码 - 15:超链接 - 17:附件 - 18:单向关联 - 20:公式 - 21:双向关联 - 22:地理位置 - 23:群组 - 1001:创建时间 - 1002:最后更新时间 - 1003:创建人 - 1004:修改人 - 1005:自动编号 - llm_description: | - 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。 - field_name:字段名; - type: 字段类型;可选值有 - 1:多行文本 - 2:数字 - 3:单选 - 4:多选 - 5:日期 - 7:复选框 - 11:人员 - 13:电话号码 - 15:超链接 - 17:附件 - 18:单向关联 - 20:公式 - 21:双向关联 - 22:地理位置 - 23:群组 - 1001:创建时间 - 1002:最后更新时间 - 1003:创建人 - 1004:修改人 - 1005:自动编号 - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_table.py b/api/core/tools/provider/builtin/feishu_base/tools/create_table.py new file mode 100644 index 00000000000000..81f2617545969b --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_table.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateTableTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_name = tool_parameters.get("table_name") + default_view_name = tool_parameters.get("default_view_name") + fields = tool_parameters.get("fields") + + res = 
client.create_table(app_token, table_name, default_view_name, fields) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml b/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml new file mode 100644 index 00000000000000..8b1007b9a53166 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/create_table.yaml @@ -0,0 +1,61 @@ +identity: + name: create_table + author: Doug Lea + label: + en_US: Create Table + zh_Hans: 新增数据表 +description: + human: + en_US: Add a Data Table to Multidimensional Table + zh_Hans: 在多维表格中新增一个数据表 + llm: A tool for adding a data table to a multidimensional table. (在多维表格中新增一个数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_name + type: string + required: true + label: + en_US: Table Name + zh_Hans: 数据表名称 + human_description: + en_US: | + The name of the data table, length range: 1 character to 100 characters. + zh_Hans: 数据表名称,长度范围:1 字符 ~ 100 字符。 + llm_description: 数据表名称,长度范围:1 字符 ~ 100 字符。 + form: llm + + - name: default_view_name + type: string + required: false + label: + en_US: Default View Name + zh_Hans: 默认表格视图的名称 + human_description: + en_US: The name of the default table view, defaults to "Table" if not filled. + zh_Hans: 默认表格视图的名称,不填则默认为"表格"。 + llm_description: 默认表格视图的名称,不填则默认为"表格"。 + form: llm + + - name: fields + type: string + required: true + label: + en_US: Initial Fields + zh_Hans: 初始字段 + human_description: + en_US: | + Initial fields of the data table, format: [ { "field_name": "Multi-line Text","type": 1 },{ "field_name": "Number","type": 2 },{ "field_name": "Single Select","type": 3 },{ "field_name": "Multiple Select","type": 4 },{ "field_name": "Date","type": 5 } ]. 
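A sketch of assembling such an initial-fields payload and calling the new create_table flow; tokens are placeholders, field names are illustrative, and the numeric type codes follow the format example above:

import json

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("<app_id>", "<app_secret>")  # placeholder credentials
# Type codes from the format above: 1 = multi-line text, 2 = number, 3 = single select, 5 = date.
fields = json.dumps(
    [
        {"field_name": "任务名称", "type": 1},
        {"field_name": "数量", "type": 2},
        {"field_name": "状态", "type": 3},
        {"field_name": "截止日期", "type": 5},
    ],
    ensure_ascii=False,
)
# Arguments mirror CreateTableTool above: app_token, table_name, default_view_name, fields.
res = client.create_table("<app_token>", "任务表", "表格", fields)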
For field details, refer to: https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + zh_Hans: 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。字段详情参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + llm_description: 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。字段详情参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py deleted file mode 100644 index aa13aad6fac287..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.py +++ /dev/null @@ -1,52 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class DeleteBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/batch_delete" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') - - table_id = tool_parameters.get('table_id', '') - if not table_id: - return self.create_text_message('Invalid parameter table_id') - - record_ids = tool_parameters.get('record_ids', '') - if not record_ids: - return self.create_text_message('Invalid parameter record_ids') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } - - params = {} - payload = { - "records": json.loads(record_ids) - } - - try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to delete base records, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to delete base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml deleted file mode 100644 index 595b2870298af9..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_records.yaml +++ /dev/null @@ -1,60 +0,0 @@ -identity: - name: delete_base_records - author: Doug Lea - label: - en_US: Delete Base Records - zh_Hans: 在多维表格数据表中删除多条记录 -description: - human: - en_US: Delete base records - zh_Hans: | - 该接口用于删除多维表格数据表中的多条记录,单次调用中最多删除 500 条记录。 - llm: A tool for delete multiple records in a multidimensional table data table, up to 500 records can be deleted in a single call. 
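Since the batch endpoint caps a single call at 500 record IDs, larger deletions have to be chunked. A minimal sketch against the new client, assuming delete_records takes a JSON array of IDs like the record_ids parameter the replacement tool exposes:

import json

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("<app_id>", "<app_secret>")  # placeholder credentials


def delete_in_batches(app_token: str, table_id: str, record_ids: list[str]) -> None:
    # The API deletes at most 500 records per call, so slice larger lists.
    for i in range(0, len(record_ids), 500):
        batch = json.dumps(record_ids[i : i + 500])
        client.delete_records(app_token, table_id, None, batch)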
-parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: record_ids - type: string - required: true - label: - en_US: record_ids - zh_Hans: record_ids - human_description: - en_US: A list of multiple record IDs to be deleted, for example ["recwNXzPQv","recpCsf4ME"] - zh_Hans: 待删除的多条记录id列表,示例为 ["recwNXzPQv","recpCsf4ME"] - llm_description: A list of multiple record IDs to be deleted, for example ["recwNXzPQv","recpCsf4ME"] - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py deleted file mode 100644 index c4280ebc21eaed..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.py +++ /dev/null @@ -1,47 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class DeleteBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/batch_delete" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') - - table_ids = tool_parameters.get('table_ids', '') - if not table_ids: - return self.create_text_message('Invalid parameter table_ids') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } - - params = {} - payload = { - "table_ids": json.loads(table_ids) - } - - try: - res = httpx.post(url.format(app_token=app_token), headers=headers, params=params, json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to delete base tables, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to delete base tables. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml deleted file mode 100644 index 5d72814363d86f..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/delete_base_tables.yaml +++ /dev/null @@ -1,48 +0,0 @@ -identity: - name: delete_base_tables - author: Doug Lea - label: - en_US: Delete Base Tables - zh_Hans: 删除多维表格中的数据表 -description: - human: - en_US: Delete base tables - zh_Hans: | - 删除多维表格中的数据表 - llm: A tool for deleting a data table in a multidimensional table -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_ids - type: string - required: true - label: - en_US: table_ids - zh_Hans: table_ids - human_description: - en_US: The ID list of the data tables to be deleted. Currently, a maximum of 50 data tables can be deleted at a time. The example is ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] - zh_Hans: 待删除数据表的id列表,当前一次操作最多支持50个数据表,示例为 ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] - llm_description: The ID list of the data tables to be deleted. Currently, a maximum of 50 data tables can be deleted at a time. The example is ["tbl1TkhyTWDkSoZ3","tblsRc9GRRXKqhvW"] - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py new file mode 100644 index 00000000000000..c896a2c81b97f8 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class DeleteRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + record_ids = tool_parameters.get("record_ids") + + res = client.delete_records(app_token, table_id, table_name, record_ids) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml new file mode 100644 index 00000000000000..c30ebd630ce9d8 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_records.yaml @@ -0,0 +1,86 @@ +identity: + name: delete_records + author: Doug Lea + label: + en_US: Delete Records + zh_Hans: 删除多条记录 +description: + human: + en_US: Delete Multiple Records from Multidimensional Table + zh_Hans: 删除多维表格数据表中的多条记录 + llm: A tool for deleting multiple records from a multidimensional table. 
(删除多维表格数据表中的多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the data table. Either table_id or table_name must be provided; they cannot both be empty. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the data table. Either table_name or table_id must be provided; they cannot both be empty. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: record_ids + type: string + required: true + label: + en_US: Record IDs + zh_Hans: 记录 ID 列表 + human_description: + en_US: | + List of IDs for the records to be deleted, example value: ["recwNXzPQv"]. + zh_Hans: 删除的多条记录 ID 列表,示例值:["recwNXzPQv"]。 + llm_description: 删除的多条记录 ID 列表,示例值:["recwNXzPQv"]。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py new file mode 100644 index 00000000000000..f732a16da6f697 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class DeleteTablesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_ids = tool_parameters.get("table_ids") + table_names = tool_parameters.get("table_names") + + res = client.delete_tables(app_token, table_ids, table_names) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml new file mode 100644 index 00000000000000..498126eae53302 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/delete_tables.yaml @@ -0,0 +1,49 @@ +identity: + name: delete_tables + author: Doug Lea + label: + en_US: Delete Tables + zh_Hans: 删除数据表 +description: + human: + en_US: Batch Delete Data Tables from Multidimensional Table + zh_Hans: 批量删除多维表格中的数据表 + llm: A tool for batch deleting data tables from a multidimensional table. (批量删除多维表格中的数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_ids + type: string + required: false + label: + en_US: Table IDs + zh_Hans: 数据表 ID + human_description: + en_US: | + IDs of the tables to be deleted. Each operation supports deleting up to 50 tables. Example: ["tbl1TkhyTWDkSoZ3"]. Ensure that either table_ids or table_names is not empty. + zh_Hans: 待删除的数据表的 ID,每次操作最多支持删除 50 个数据表。示例值:["tbl1TkhyTWDkSoZ3"]。请确保 table_ids 和 table_names 至少有一个不为空。 + llm_description: 待删除的数据表的 ID,每次操作最多支持删除 50 个数据表。示例值:["tbl1TkhyTWDkSoZ3"]。请确保 table_ids 和 table_names 至少有一个不为空。 + form: llm + + - name: table_names + type: string + required: false + label: + en_US: Table Names + zh_Hans: 数据表名称 + human_description: + en_US: | + Names of the tables to be deleted. Each operation supports deleting up to 50 tables. Example: ["Table1", "Table2"]. Ensure that either table_names or table_ids is not empty. 
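Usage sketch for the ID-or-name choice described above; tokens are placeholders, and the positional arguments follow DeleteTablesTool (app_token, table_ids, table_names):

import json

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("<app_id>", "<app_secret>")  # placeholder credentials
# Address tables by name when their IDs are unknown; at least one of the two
# lists must be non-empty, and each call handles at most 50 tables.
table_names = json.dumps(["数据表1", "数据表2"], ensure_ascii=False)
res = client.delete_tables("<app_token>", None, table_names)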
+ zh_Hans: 待删除的数据表的名称,每次操作最多支持删除 50 个数据表。示例值:["数据表1", "数据表2"]。请确保 table_names 和 table_ids 至少有一个不为空。 + llm_description: 待删除的数据表的名称,每次操作最多支持删除 50 个数据表。示例值:["数据表1", "数据表2"]。请确保 table_names 和 table_ids 至少有一个不为空。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py index de70f2ed9359dc..a74e9be288bc17 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.py @@ -1,38 +1,17 @@ -import json -from typing import Any, Union - -import httpx +from typing import Any from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest class GetBaseInfoTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } + app_token = tool_parameters.get("app_token") - try: - res = httpx.get(url.format(app_token=app_token), headers=headers, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to get base info, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to get base info. {}".format(e)) + res = client.get_base_info(app_token) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.yaml b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.yaml index de0868901834ee..eb0e7a26c06a55 100644 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.yaml +++ b/api/core/tools/provider/builtin/feishu_base/tools/get_base_info.yaml @@ -6,49 +6,18 @@ identity: zh_Hans: 获取多维表格元数据 description: human: - en_US: Get base info - zh_Hans: | - 获取多维表格元数据,响应体如下: - { - "code": 0, - "msg": "success", - "data": { - "app": { - "app_token": "appbcbWCzen6D8dezhoCH2RpMAh", - "name": "mybase", - "revision": 1, - "is_advanced": false, - "time_zone": "Asia/Beijing" - } - } - } - app_token: 多维表格的 app_token; - name: 多维表格的名字; - revision: 多维表格的版本号; - is_advanced: 多维表格是否开启了高级权限。取值包括:(true-表示开启了高级权限,false-表示关闭了高级权限); - time_zone: 文档时区; - llm: A tool to get Base Metadata, imported parameter is Unique Device Identifier app_token of Base, app_token is required. + en_US: Get Metadata Information of Specified Multidimensional Table + zh_Hans: 获取指定多维表格的元数据信息 + llm: A tool for getting metadata information of a specified multidimensional table. 
(获取指定多维表格的元数据信息) parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - name: app_token type: string required: true label: en_US: app_token - zh_Hans: 多维表格 + zh_Hans: app_token human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py deleted file mode 100644 index 88507bda60090f..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.py +++ /dev/null @@ -1,50 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class GetTenantAccessTokenTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal" - - app_id = tool_parameters.get('app_id', '') - if not app_id: - return self.create_text_message('Invalid parameter app_id') - - app_secret = tool_parameters.get('app_secret', '') - if not app_secret: - return self.create_text_message('Invalid parameter app_secret') - - headers = { - 'Content-Type': 'application/json', - } - params = {} - payload = { - "app_id": app_id, - "app_secret": app_secret - } - - """ - { - "code": 0, - "msg": "ok", - "tenant_access_token": "t-caecc734c2e3328a62489fe0648c4b98779515d3", - "expire": 7200 - } - """ - try: - res = httpx.post(url, headers=headers, params=params, json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to get tenant access token, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to get tenant access token. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml b/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml deleted file mode 100644 index 88acc27e06eca1..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/get_tenant_access_token.yaml +++ /dev/null @@ -1,39 +0,0 @@ -identity: - name: get_tenant_access_token - author: Doug Lea - label: - en_US: Get Tenant Access Token - zh_Hans: 获取飞书自建应用的 tenant_access_token -description: - human: - en_US: Get tenant access token - zh_Hans: | - 获取飞书自建应用的 tenant_access_token,响应体示例: - {"code":0,"msg":"ok","tenant_access_token":"t-caecc734c2e3328a62489fe0648c4b98779515d3","expire":7200} - tenant_access_token: 租户访问凭证; - expire: tenant_access_token 的过期时间,单位为秒; - llm: A tool for obtaining a tenant access token. The input parameters must include app_id and app_secret. 
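With this tool removed, token handling moves behind the provider's client. For reference, a sketch of fetching the token and caching it until shortly before the expire deadline, based on the endpoint and response shape shown in the deleted source above; this helper is hypothetical and not part of the diff:

import time

import httpx

_cache: dict = {"token": None, "expires_at": 0.0}


def tenant_access_token(app_id: str, app_secret: str) -> str:
    # Reuse the cached token until a minute before Feishu's expire deadline.
    if _cache["token"] and time.time() < _cache["expires_at"]:
        return _cache["token"]
    url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
    data = httpx.post(url, json={"app_id": app_id, "app_secret": app_secret}, timeout=30).json()
    _cache["token"] = data["tenant_access_token"]
    # "expire" is in seconds (7200 in the sample response above).
    _cache["expires_at"] = time.time() + data.get("expire", 7200) - 60
    return _cache["token"]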
-parameters: - - name: app_id - type: string - required: true - label: - en_US: app_id - zh_Hans: 应用唯一标识 - human_description: - en_US: app_id is the unique identifier of the Lark Open Platform application - zh_Hans: app_id 是飞书开放平台应用的唯一标识 - llm_description: app_id is the unique identifier of the Lark Open Platform application - form: llm - - - name: app_secret - type: secret-input - required: true - label: - en_US: app_secret - zh_Hans: 应用秘钥 - human_description: - en_US: app_secret is the secret key of the application - zh_Hans: app_secret 是应用的秘钥 - llm_description: app_secret is the secret key of the application - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py deleted file mode 100644 index 2a4229f137d7fd..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.py +++ /dev/null @@ -1,61 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class ListBaseRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/search" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') - - table_id = tool_parameters.get('table_id', '') - if not table_id: - return self.create_text_message('Invalid parameter table_id') - - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') - sort_condition = tool_parameters.get('sort_condition', '') - filter_condition = tool_parameters.get('filter_condition', '') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } - - params = { - "page_token": page_token, - "page_size": page_size, - } - - payload = { - "automatic_fields": True - } - if sort_condition: - payload["sort"] = json.loads(sort_condition) - if filter_condition: - payload["filter"] = json.loads(filter_condition) - - try: - res = httpx.post(url.format(app_token=app_token, table_id=table_id), headers=headers, params=params, - json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to list base records, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to list base records. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml deleted file mode 100644 index 8647c880a60024..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_records.yaml +++ /dev/null @@ -1,108 +0,0 @@ -identity: - name: list_base_records - author: Doug Lea - label: - en_US: List Base Records - zh_Hans: 查询多维表格数据表中的现有记录 -description: - human: - en_US: List base records - zh_Hans: | - 查询多维表格数据表中的现有记录,单次最多查询 500 行记录,支持分页获取。 - llm: Query existing records in a multidimensional table data table. 
A maximum of 500 rows of records can be queried at a time, and paging retrieval is supported. -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: page_token - type: string - required: false - label: - en_US: page_token - zh_Hans: 分页标记 - human_description: - en_US: Pagination mark. If it is not filled in the first request, it means to traverse from the beginning. - zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历。 - llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 - form: llm - - - name: page_size - type: number - required: false - default: 20 - label: - en_US: page_size - zh_Hans: 分页大小 - human_description: - en_US: paging size - zh_Hans: 分页大小,默认值为 20,最大值为 100。 - llm_description: The default value of paging size is 20 and the maximum value is 100. - form: llm - - - name: sort_condition - type: string - required: false - label: - en_US: sort_condition - zh_Hans: 排序条件 - human_description: - en_US: sort condition - zh_Hans: | - 排序条件,格式为:[{"field_name":"多行文本","desc":true}]。 - field_name: 字段名称; - desc: 是否倒序排序; - llm_description: | - Sorting conditions, the format is: [{"field_name":"multi-line text","desc":true}]. - form: llm - - - name: filter_condition - type: string - required: false - label: - en_US: filter_condition - zh_Hans: 筛选条件 - human_description: - en_US: filter condition - zh_Hans: | - 筛选条件,格式为:{"conjunction":"and","conditions":[{"field_name":"字段1","operator":"is","value":["文本内容"]}]}。 - conjunction:条件逻辑连接词; - conditions:筛选条件集合; - field_name:筛选条件的左值,值为字段的名称; - operator:条件运算符; - value:目标值; - llm_description: | - The format of the filter condition is: {"conjunction":"and","conditions":[{"field_name":"Field 1","operator":"is","value":["text content"]}]}. 
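Because the deleted list_base_records tool ran json.loads on sort_condition and filter_condition before sending them, callers had to supply JSON strings in exactly the formats described above; a sketch of building them (the Chinese field names are the illustrative ones from the docs):

import json

sort_condition = json.dumps([{"field_name": "多行文本", "desc": True}])
filter_condition = json.dumps({
    "conjunction": "and",
    "conditions": [
        {"field_name": "字段1", "operator": "is", "value": ["文本内容"]},
    ],
})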
- form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py deleted file mode 100644 index 6d82490eb3235f..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.py +++ /dev/null @@ -1,46 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class ListBaseTablesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') - - page_token = tool_parameters.get('page_token', '') - page_size = tool_parameters.get('page_size', '') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } - - params = { - "page_token": page_token, - "page_size": page_size, - } - - try: - res = httpx.get(url.format(app_token=app_token), headers=headers, params=params, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to list base tables, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to list base tables. {}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml deleted file mode 100644 index 9887124a28823a..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/list_base_tables.yaml +++ /dev/null @@ -1,65 +0,0 @@ -identity: - name: list_base_tables - author: Doug Lea - label: - en_US: List Base Tables - zh_Hans: 根据 app_token 获取多维表格下的所有数据表 -description: - human: - en_US: List base tables - zh_Hans: | - 根据 app_token 获取多维表格下的所有数据表 - llm: A tool for getting all data tables under a multidimensional table based on app_token. -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: page_token - type: string - required: false - label: - en_US: page_token - zh_Hans: 分页标记 - human_description: - en_US: Pagination mark. If it is not filled in the first request, it means to traverse from the beginning. - zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历。 - llm_description: | - Pagination token. If it is not filled in the first request, it means to start traversal from the beginning. - If there are more items in the pagination query result, a new page_token will be returned at the same time. 
- The page_token can be used to obtain the query result in the next traversal. - 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 - form: llm - - - name: page_size - type: number - required: false - default: 20 - label: - en_US: page_size - zh_Hans: 分页大小 - human_description: - en_US: paging size - zh_Hans: 分页大小,默认值为 20,最大值为 100。 - llm_description: The default value of paging size is 20 and the maximum value is 100. - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py new file mode 100644 index 00000000000000..c7768a496debce --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ListTablesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + page_token = tool_parameters.get("page_token") + page_size = tool_parameters.get("page_size", 20) + + res = client.list_tables(app_token, page_token, page_size) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml new file mode 100644 index 00000000000000..7571519039bd24 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/list_tables.yaml @@ -0,0 +1,50 @@ +identity: + name: list_tables + author: Doug Lea + label: + en_US: List Tables + zh_Hans: 列出数据表 +description: + human: + en_US: Get All Data Tables under Multidimensional Table + zh_Hans: 获取多维表格下的所有数据表 + llm: A tool for getting all data tables under a multidimensional table. (获取多维表格下的所有数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: | + Page size, default value: 20, maximum value: 100. + zh_Hans: 分页大小,默认值:20,最大值:100。 + llm_description: 分页大小,默认值:20,最大值:100。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: | + Page token, leave empty for the first request to start from the beginning; a new page_token will be returned if there are more items in the paginated query results, which can be used for the next traversal. Example value: "tblsRc9GRRXKqhvW". 
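A sketch of paging through list_tables with the new client, following the page_token contract described above; the keys read from the response ("items", "has_more", "page_token") are assumptions about FeishuRequest's payload shape, which this diff does not show:

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")  # placeholder credentials
app_token = "appbcbWCzen6D8dezhoCH2RpMAh"        # example token from the old docs above

tables, page_token = [], None
while True:
    res = client.list_tables(app_token, page_token, 20)  # None page_token: first page
    tables.extend(res.get("items", []))                  # assumed response keys
    if not res.get("has_more"):
        break
    page_token = res.get("page_token")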
+ zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py deleted file mode 100644 index bb4bd6c3a6c531..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.py +++ /dev/null @@ -1,47 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class ReadBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') - - table_id = tool_parameters.get('table_id', '') - if not table_id: - return self.create_text_message('Invalid parameter table_id') - - record_id = tool_parameters.get('record_id', '') - if not record_id: - return self.create_text_message('Invalid parameter record_id') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } - - try: - res = httpx.get(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to read base record, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to read base record. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml b/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml deleted file mode 100644 index 400e9a1021f2db..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/read_base_record.yaml +++ /dev/null @@ -1,60 +0,0 @@ -identity: - name: read_base_record - author: Doug Lea - label: - en_US: Read Base Record - zh_Hans: 根据 record_id 的值检索多维表格数据表的记录 -description: - human: - en_US: Read base record - zh_Hans: | - 根据 record_id 的值检索多维表格数据表的记录 - llm: Retrieve records from a multidimensional table based on the value of record_id -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: record_id - type: string - required: true - label: - en_US: record_id - zh_Hans: 单条记录的 id - human_description: - en_US: The id of a single record - zh_Hans: 单条记录的 id - llm_description: The id of a single record - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_records.py b/api/core/tools/provider/builtin/feishu_base/tools/read_records.py new file mode 100644 index 00000000000000..46f3df4ff040f3 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_records.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ReadRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + record_ids = tool_parameters.get("record_ids") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_records(app_token, table_id, table_name, record_ids, user_id_type) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml new file mode 100644 index 00000000000000..911e667cfc90ad --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/read_records.yaml @@ -0,0 +1,86 @@ +identity: + name: read_records + author: Doug Lea + label: + en_US: Read Records + zh_Hans: 批量获取记录 +description: + human: + en_US: Batch Retrieve Records from Multidimensional Table + zh_Hans: 批量获取多维表格数据表中的记录信息 + llm: A tool for batch retrieving records from a multidimensional table, supporting up to 100 records per call. 
(批量获取多维表格数据表中的记录信息,单次调用最多支持查询 100 条记录) + +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: record_ids + type: string + required: true + label: + en_US: record_ids + zh_Hans: 记录 ID 列表 + human_description: + en_US: List of record IDs, which can be obtained by calling the "Query Records API". + zh_Hans: 记录 ID 列表,可以通过调用"查询记录接口"获取。 + llm_description: 记录 ID 列表,可以通过调用"查询记录接口"获取。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
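A sketch of the batch read, with arguments in the order read_records.py above passes them; record_ids is serialized as a JSON string to match the yaml's string type, though the exact encoding FeishuRequest expects is an assumption:

import json

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")  # placeholder credentials
record_ids = json.dumps(["recupK4f4RM5RX"])      # IDs come from a prior search
res = client.read_records(
    "appbcbWCzen6D8dezhoCH2RpMAh",  # app_token
    "tblE8a2fmBIEflaE",             # table_id (or None, with table_name set instead)
    None,                           # table_name
    record_ids,
    "open_id",                      # user_id_type
)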
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.py b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py new file mode 100644 index 00000000000000..c959496735e747 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/search_records.py @@ -0,0 +1,39 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SearchRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + view_id = tool_parameters.get("view_id") + field_names = tool_parameters.get("field_names") + sort = tool_parameters.get("sort") + filters = tool_parameters.get("filter") + page_token = tool_parameters.get("page_token") + automatic_fields = tool_parameters.get("automatic_fields", False) + user_id_type = tool_parameters.get("user_id_type", "open_id") + page_size = tool_parameters.get("page_size", 20) + + res = client.search_record( + app_token, + table_id, + table_name, + view_id, + field_names, + sort, + filters, + page_token, + automatic_fields, + user_id_type, + page_size, + ) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml new file mode 100644 index 00000000000000..decf76d53ed928 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/search_records.yaml @@ -0,0 +1,163 @@ +identity: + name: search_records + author: Doug Lea + label: + en_US: Search Records + zh_Hans: 查询记录 +description: + human: + en_US: Query records in a multidimensional table, up to 500 rows per query. + zh_Hans: 查询多维表格数据表中的记录,单次最多查询 500 行记录。 + llm: A tool for querying records in a multidimensional table, up to 500 rows per query. (查询多维表格数据表中的记录,单次最多查询 500 行记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. 
+ zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: view_id + type: string + required: false + label: + en_US: view_id + zh_Hans: 视图唯一标识 + human_description: + en_US: | + Unique identifier for a view in a multidimensional table. It can be found in the URL's query parameter with the key 'view'. For example: https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx. + zh_Hans: 多维表格中视图的唯一标识,可在多维表格的 URL 地址栏中找到,query 参数中 key 为 view 的部分。例如:https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx。 + llm_description: 多维表格中视图的唯一标识,可在多维表格的 URL 地址栏中找到,query 参数中 key 为 view 的部分。例如:https://svi136aogf123.feishu.cn/base/KWC8bYsYXahYqGsTtqectNn9n3e?table=tblE8a2fmBIEflaE&view=vewlkAVpRx。 + form: llm + + - name: field_names + type: string + required: false + label: + en_US: field_names + zh_Hans: 字段名称 + human_description: + en_US: | + Field names to specify which fields to include in the returned records. Example value: ["Field1", "Field2"]. + zh_Hans: 字段名称,用于指定本次查询返回记录中包含的字段。示例值:["字段1","字段2"]。 + llm_description: 字段名称,用于指定本次查询返回记录中包含的字段。示例值:["字段1","字段2"]。 + form: llm + + - name: sort + type: string + required: false + label: + en_US: sort + zh_Hans: 排序条件 + human_description: + en_US: | + Sorting conditions, for example: [{"field_name":"Multiline Text","desc":true}]. + zh_Hans: 排序条件,例如:[{"field_name":"多行文本","desc":true}]。 + llm_description: 排序条件,例如:[{"field_name":"多行文本","desc":true}]。 + form: llm + + - name: filter + type: string + required: false + label: + en_US: filter + zh_Hans: 筛选条件 + human_description: + en_US: Object containing filter information. For details on how to fill in the filter, refer to the record filter parameter guide (https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide). + zh_Hans: 包含条件筛选信息的对象。了解如何填写 filter,参考记录筛选参数填写指南(https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide)。 + llm_description: 包含条件筛选信息的对象。了解如何填写 filter,参考记录筛选参数填写指南(https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide)。 + form: llm + + - name: automatic_fields + type: boolean + required: false + label: + en_US: automatic_fields + zh_Hans: automatic_fields + human_description: + en_US: Whether to return automatically calculated fields. Default is false, meaning they are not returned. + zh_Hans: 是否返回自动计算的字段。默认为 false,表示不返回。 + llm_description: 是否返回自动计算的字段。默认为 false,表示不返回。 + form: form + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: | + Page size, default value: 20, maximum value: 500. 
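A sketch of a search, with the eleven arguments in the order search_records.py above passes them; field_names and sort are JSON strings matching the yaml examples, and the literal IDs are the example values from the docs:

import json

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")  # placeholder credentials
res = client.search_record(
    "appbcbWCzen6D8dezhoCH2RpMAh",                          # app_token
    "tblE8a2fmBIEflaE",                                     # table_id
    None,                                                   # table_name
    None,                                                   # view_id
    json.dumps(["字段1", "字段2"]),                          # field_names
    json.dumps([{"field_name": "多行文本", "desc": True}]),  # sort
    None,                                                   # filter
    None,                                                   # page_token: first page
    False,                                                  # automatic_fields
    "open_id",                                              # user_id_type
    100,                                                    # page_size (max 500)
)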
+ zh_Hans: 分页大小,默认值:20,最大值:500。 + llm_description: 分页大小,默认值:20,最大值:500。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: | + Page token, leave empty for the first request to start from the beginning; a new page_token will be returned if there are more items in the paginated query results, which can be used for the next traversal. Example value: "tblsRc9GRRXKqhvW". + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py deleted file mode 100644 index 6551053ce22535..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.py +++ /dev/null @@ -1,56 +0,0 @@ -import json -from typing import Any, Union - -import httpx - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.tool.builtin_tool import BuiltinTool - - -class UpdateBaseRecordTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - url = "https://open.feishu.cn/open-apis/bitable/v1/apps/{app_token}/tables/{table_id}/records/{record_id}" - - access_token = tool_parameters.get('Authorization', '') - if not access_token: - return self.create_text_message('Invalid parameter access_token') - - app_token = tool_parameters.get('app_token', '') - if not app_token: - return self.create_text_message('Invalid parameter app_token') - - table_id = tool_parameters.get('table_id', '') - if not table_id: - return self.create_text_message('Invalid parameter table_id') - - record_id = tool_parameters.get('record_id', '') - if not record_id: - return self.create_text_message('Invalid parameter record_id') - - fields = tool_parameters.get('fields', '') - if not fields: - return self.create_text_message('Invalid parameter fields') - - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - } - - params = {} - payload = { - "fields": json.loads(fields) - } - - try: - res = httpx.put(url.format(app_token=app_token, table_id=table_id, record_id=record_id), headers=headers, - params=params, json=payload, timeout=30) - res_json = res.json() - if res.is_success: - return self.create_text_message(text=json.dumps(res_json)) - else: - return self.create_text_message( - f"Failed to update base record, status code: {res.status_code}, response: {res.text}") - except Exception as e: - return self.create_text_message("Failed to update base record. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml b/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml deleted file mode 100644 index 788798c4b3b40e..00000000000000 --- a/api/core/tools/provider/builtin/feishu_base/tools/update_base_record.yaml +++ /dev/null @@ -1,78 +0,0 @@ -identity: - name: update_base_record - author: Doug Lea - label: - en_US: Update Base Record - zh_Hans: 更新多维表格数据表中的一条记录 -description: - human: - en_US: Update base record - zh_Hans: | - 更新多维表格数据表中的一条记录,详细请参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-record/update - llm: Update a record in a multidimensional table data table -parameters: - - name: Authorization - type: string - required: true - label: - en_US: token - zh_Hans: 凭证 - human_description: - en_US: API access token parameter, tenant_access_token or user_access_token - zh_Hans: API 的访问凭证参数,tenant_access_token 或 user_access_token - llm_description: API access token parameter, tenant_access_token or user_access_token - form: llm - - - name: app_token - type: string - required: true - label: - en_US: app_token - zh_Hans: 多维表格 - human_description: - en_US: bitable app token - zh_Hans: 多维表格的唯一标识符 app_token - llm_description: bitable app token - form: llm - - - name: table_id - type: string - required: true - label: - en_US: table_id - zh_Hans: 多维表格的数据表 - human_description: - en_US: bitable table id - zh_Hans: 多维表格数据表的唯一标识符 table_id - llm_description: bitable table id - form: llm - - - name: record_id - type: string - required: true - label: - en_US: record_id - zh_Hans: 单条记录的 id - human_description: - en_US: The id of a single record - zh_Hans: 单条记录的 id - llm_description: The id of a single record - form: llm - - - name: fields - type: string - required: true - label: - en_US: fields - zh_Hans: 数据表的列字段内容 - human_description: - en_US: The fields of a multidimensional table data table, that is, the columns of the data table. 
- zh_Hans: | - 要更新一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} - 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 - 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure - llm_description: | - 要更新一行多维表格记录,字段结构拼接如下:{"多行文本":"多行文本内容","单选":"选项1","多选":["选项1","选项2"],"复选框":true,"人员":[{"id":"ou_2910013f1e6456f16a0ce75ede950a0a"}],"群组":[{"id":"oc_cd07f55f14d6f4a4f1b51504e7e97f48"}],"电话号码":"13026162666"} - 当前接口支持的字段类型为:多行文本、单选、条码、多选、日期、人员、附件、复选框、超链接、数字、单向关联、双向关联、电话号码、地理位置。 - 不同类型字段的数据结构请参考数据结构概述:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure - form: llm diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.py b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py new file mode 100644 index 00000000000000..a7b036387500b0 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class UpdateRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + records = tool_parameters.get("records") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.update_records(app_token, table_id, table_name, records, user_id_type) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml b/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml new file mode 100644 index 00000000000000..68117e71367892 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_base/tools/update_records.yaml @@ -0,0 +1,91 @@ +identity: + name: update_records + author: Doug Lea + label: + en_US: Update Records + zh_Hans: 更新多条记录 +description: + human: + en_US: Update Multiple Records in Multidimensional Table + zh_Hans: 更新多维表格数据表中的多条记录 + llm: A tool for updating multiple records in a multidimensional table. (更新多维表格数据表中的多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. 
+ zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: records + type: string + required: true + label: + en_US: records + zh_Hans: 记录列表 + human_description: + en_US: | + List of records to be updated in this request. Example value: [{"fields":{"multi-line-text":"text content","single_select":"option 1","date":1674206443000},"record_id":"recupK4f4RM5RX"}]. + For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure). + zh_Hans: | + 本次请求将要更新的记录列表,示例值:[{"fields":{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000},"record_id":"recupK4f4RM5RX"}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + llm_description: | + 本次请求将要更新的记录列表,示例值:[{"fields":{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000},"record_id":"recupK4f4RM5RX"}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
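A sketch of a batch update, matching update_records.py above; the records payload follows the example value documented for the records parameter, serialized to a JSON string:

import json

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")  # placeholder credentials
records = json.dumps([
    {
        "record_id": "recupK4f4RM5RX",
        "fields": {"多行文本": "文本内容", "单选": "选项 1", "日期": 1674206443000},
    },
])
res = client.update_records(
    "appbcbWCzen6D8dezhoCH2RpMAh",  # app_token
    "tblE8a2fmBIEflaE",             # table_id (or None, with table_name set instead)
    None,                           # table_name
    records,
    "open_id",                      # user_id_type
)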
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_calendar/_assets/icon.png b/api/core/tools/provider/builtin/feishu_calendar/_assets/icon.png new file mode 100644 index 00000000000000..2a934747a98c66 Binary files /dev/null and b/api/core/tools/provider/builtin/feishu_calendar/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/feishu_calendar/feishu_calendar.py b/api/core/tools/provider/builtin/feishu_calendar/feishu_calendar.py new file mode 100644 index 00000000000000..a46a9fa9e80cab --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/feishu_calendar.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import auth + + +class FeishuCalendarProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + auth(credentials) diff --git a/api/core/tools/provider/builtin/feishu_calendar/feishu_calendar.yaml b/api/core/tools/provider/builtin/feishu_calendar/feishu_calendar.yaml new file mode 100644 index 00000000000000..db5bab5c1081d9 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/feishu_calendar.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: feishu_calendar + label: + en_US: Feishu Calendar + zh_Hans: 飞书日历 + description: + en_US: | + Feishu calendar, requires the following permissions: calendar:calendar:read、calendar:calendar、contact:user.id:readonly. + zh_Hans: | + 飞书日历,需要开通以下权限: calendar:calendar:read、calendar:calendar、contact:user.id:readonly。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.larkoffice.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py new file mode 100644 index 00000000000000..8f83aea5abbe3d --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class AddEventAttendeesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + event_id = tool_parameters.get("event_id") + attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email") + need_notification = tool_parameters.get("need_notification", True) + + res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.yaml 
b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.yaml new file mode 100644 index 00000000000000..b7744499b07344 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/add_event_attendees.yaml @@ -0,0 +1,54 @@ +identity: + name: add_event_attendees + author: Doug Lea + label: + en_US: Add Event Attendees + zh_Hans: 添加日程参会人 +description: + human: + en_US: Add Event Attendees + zh_Hans: 添加日程参会人 + llm: A tool for adding attendees to events in Feishu. (在飞书中添加日程参会人) +parameters: + - name: event_id + type: string + required: true + label: + en_US: Event ID + zh_Hans: 日程 ID + human_description: + en_US: | + The ID of the event, which will be returned when the event is created. For example: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0. + zh_Hans: | + 创建日程时会返回日程 ID。例如: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0。 + llm_description: | + 日程 ID,创建日程时会返回日程 ID。例如: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0。 + form: llm + + - name: need_notification + type: boolean + required: false + default: true + label: + en_US: Need Notification + zh_Hans: 是否需要通知 + human_description: + en_US: | + Whether to send a Bot notification to attendees. true: send, false: do not send. + zh_Hans: | + 是否给参与人发送 Bot 通知,true: 发送,false: 不发送。 + llm_description: | + 是否给参与人发送 Bot 通知,true: 发送,false: 不发送。 + form: form + + - name: attendee_phone_or_email + type: string + required: true + label: + en_US: Attendee Phone or Email + zh_Hans: 参会人电话或邮箱 + human_description: + en_US: The list of attendee emails or phone numbers, separated by commas. + zh_Hans: 日程参会人邮箱或者手机号列表,使用逗号分隔。 + llm_description: 日程参会人邮箱或者手机号列表,使用逗号分隔。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/create_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/create_event.py new file mode 100644 index 00000000000000..8820bebdbed922 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/create_event.py @@ -0,0 +1,26 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateEventTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + summary = tool_parameters.get("summary") + description = tool_parameters.get("description") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + attendee_ability = tool_parameters.get("attendee_ability") + need_notification = tool_parameters.get("need_notification", True) + auto_record = tool_parameters.get("auto_record", False) + + res = client.create_event( + summary, description, start_time, end_time, attendee_ability, need_notification, auto_record + ) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/create_event.yaml b/api/core/tools/provider/builtin/feishu_calendar/tools/create_event.yaml new file mode 100644 index 00000000000000..f0784221ce7965 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/create_event.yaml @@ -0,0 +1,119 @@ +identity: + name: create_event + author: Doug Lea + label: + en_US: Create Event + zh_Hans: 创建日程 +description: + human: + en_US: Create Event + zh_Hans: 创建日程 + llm: A tool for creating events in Feishu.(创建飞书日程) +parameters: + - 
name: summary + type: string + required: false + label: + en_US: Summary + zh_Hans: 日程标题 + human_description: + en_US: The title of the event. If not filled, the event title will display (No Subject). + zh_Hans: 日程标题,若不填则日程标题显示 (无主题)。 + llm_description: 日程标题,若不填则日程标题显示 (无主题)。 + form: llm + + - name: description + type: string + required: false + label: + en_US: Description + zh_Hans: 日程描述 + human_description: + en_US: The description of the event. + zh_Hans: 日程描述。 + llm_description: 日程描述。 + form: llm + + - name: need_notification + type: boolean + required: false + default: true + label: + en_US: Need Notification + zh_Hans: 是否发送通知 + human_description: + en_US: | + Whether to send a bot message when the event is created, true: send, false: do not send. + zh_Hans: 创建日程时是否发送 bot 消息,true:发送,false:不发送。 + llm_description: 创建日程时是否发送 bot 消息,true:发送,false:不发送。 + form: form + + - name: start_time + type: string + required: true + label: + en_US: Start Time + zh_Hans: 开始时间 + human_description: + en_US: | + The start time of the event, format: 2006-01-02 15:04:05. + zh_Hans: 日程开始时间,格式:2006-01-02 15:04:05。 + llm_description: 日程开始时间,格式:2006-01-02 15:04:05。 + form: llm + + - name: end_time + type: string + required: true + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: | + The end time of the event, format: 2006-01-02 15:04:05. + zh_Hans: 日程结束时间,格式:2006-01-02 15:04:05。 + llm_description: 日程结束时间,格式:2006-01-02 15:04:05。 + form: llm + + - name: attendee_ability + type: select + required: false + options: + - value: none + label: + en_US: none + zh_Hans: 无 + - value: can_see_others + label: + en_US: can_see_others + zh_Hans: 可以查看参与人列表 + - value: can_invite_others + label: + en_US: can_invite_others + zh_Hans: 可以邀请其它参与人 + - value: can_modify_event + label: + en_US: can_modify_event + zh_Hans: 可以编辑日程 + default: "none" + label: + en_US: attendee_ability + zh_Hans: 参会人权限 + human_description: + en_US: Attendee ability, optional values are none, can_see_others, can_invite_others, can_modify_event, with a default value of none. + zh_Hans: 参会人权限,可选值有无、可以查看参与人列表、可以邀请其它参与人、可以编辑日程,默认值为无。 + llm_description: 参会人权限,可选值有无、可以查看参与人列表、可以邀请其它参与人、可以编辑日程,默认值为无。 + form: form + + - name: auto_record + type: boolean + required: false + default: false + label: + en_US: Auto Record + zh_Hans: 自动录制 + human_description: + en_US: | + Whether to enable automatic recording, true: enabled, automatically record when the meeting starts; false: not enabled. 
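A sketch of creating an event, with arguments in the order create_event.py above passes them and timestamps in the documented 2006-01-02 15:04:05 layout; the summary, description, and times are illustrative:

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")  # placeholder credentials
res = client.create_event(
    "Weekly sync",             # summary; shown as (无主题) if omitted
    "Roadmap review",          # description
    "2024-10-01 10:00:00",     # start_time
    "2024-10-01 11:00:00",     # end_time
    "can_see_others",          # attendee_ability
    True,                      # need_notification
    False,                     # auto_record
)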
+ zh_Hans: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。 + llm_description: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py new file mode 100644 index 00000000000000..144889692f9055 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class DeleteEventTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + event_id = tool_parameters.get("event_id") + need_notification = tool_parameters.get("need_notification", True) + + res = client.delete_event(event_id, need_notification) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.yaml b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.yaml new file mode 100644 index 00000000000000..54fdb04acc3371 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/delete_event.yaml @@ -0,0 +1,38 @@ +identity: + name: delete_event + author: Doug Lea + label: + en_US: Delete Event + zh_Hans: 删除日程 +description: + human: + en_US: Delete Event + zh_Hans: 删除日程 + llm: A tool for deleting events in Feishu.(在飞书中删除日程) +parameters: + - name: event_id + type: string + required: true + label: + en_US: Event ID + zh_Hans: 日程 ID + human_description: + en_US: | + The ID of the event, for example: e8b9791c-39ae-4908-8ad8-66b13159b9fb_0. + zh_Hans: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。 + llm_description: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。 + form: llm + + - name: need_notification + type: boolean + required: false + default: true + label: + en_US: Need Notification + zh_Hans: 是否需要通知 + human_description: + en_US: | + Indicates whether to send bot notifications to event participants upon deletion. true: send, false: do not send. 
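A sketch of the attendee and deletion calls, matching add_event_attendees.py and delete_event.py above; the event id is the example value from the docs, and the email and phone number are illustrative:

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")      # placeholder credentials
event_id = "fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0"  # example id from the docs above
client.add_event_attendees(event_id, "alice@example.com,13026162666", True)  # comma-separated list
client.delete_event(event_id, True)  # need_notification=True: notify attendees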
+ zh_Hans: 删除日程是否给日程参与人发送 bot 通知,true:发送,false:不发送。 + llm_description: 删除日程是否给日程参与人发送 bot 通知,true:发送,false:不发送。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py new file mode 100644 index 00000000000000..a2cd5a8b17d0af --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.py @@ -0,0 +1,18 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class GetPrimaryCalendarTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.get_primary_calendar(user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.yaml b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.yaml new file mode 100644 index 00000000000000..3440c85d4a9733 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/get_primary_calendar.yaml @@ -0,0 +1,37 @@ +identity: + name: get_primary_calendar + author: Doug Lea + label: + en_US: Get Primary Calendar + zh_Hans: 查询主日历信息 +description: + human: + en_US: Get Primary Calendar + zh_Hans: 查询主日历信息 + llm: A tool for querying primary calendar information in Feishu.(在飞书中查询主日历信息) +parameters: + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
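A sketch of the primary calendar lookup, matching get_primary_calendar.py above; the only parameter is the user ID type:

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")  # placeholder credentials
res = client.get_primary_calendar("open_id")     # user_id_type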
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py new file mode 100644 index 00000000000000..8815b4c9c871cd --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ListEventsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + page_token = tool_parameters.get("page_token") + page_size = tool_parameters.get("page_size") + + res = client.list_events(start_time, end_time, page_token, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.yaml b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.yaml new file mode 100644 index 00000000000000..5f0155a2465866 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/list_events.yaml @@ -0,0 +1,62 @@ +identity: + name: list_events + author: Doug Lea + label: + en_US: List Events + zh_Hans: 获取日程列表 +description: + human: + en_US: List Events + zh_Hans: 获取日程列表 + llm: A tool for listing events in Feishu.(在飞书中获取日程列表) +parameters: + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 开始时间 + human_description: + en_US: | + The start time, defaults to 0:00 of the current day if not provided, format: 2006-01-02 15:04:05. + zh_Hans: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。 + llm_description: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: | + The end time, defaults to 23:59 of the current day if not provided, format: 2006-01-02 15:04:05. + zh_Hans: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。 + llm_description: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。 + form: llm + + - name: page_size + type: number + required: false + default: 50 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 50, and the value range is [50,1000]. + zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 50,取值范围为 [50,1000]。 + llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 50,取值范围为 [50,1000]。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal. 
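A sketch of paging through a day's events with list_events; page_size must stay within the documented [50,1000] range, and the response keys are assumed as in the list_tables sketch earlier:

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")  # placeholder credentials
events, page_token = [], None
while True:
    res = client.list_events("2024-10-01 00:00:00", "2024-10-01 23:59:59",
                             page_token, 50)
    events.extend(res.get("items", []))  # assumed response keys
    if not res.get("has_more"):
        break
    page_token = res.get("page_token")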
+ zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/search_events.py b/api/core/tools/provider/builtin/feishu_calendar/tools/search_events.py new file mode 100644 index 00000000000000..dc365205a4cffa --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/search_events.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SearchEventsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + query = tool_parameters.get("query") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + page_token = tool_parameters.get("page_token") + user_id_type = tool_parameters.get("user_id_type", "open_id") + page_size = tool_parameters.get("page_size", 20) + + res = client.search_events(query, start_time, end_time, page_token, user_id_type, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/search_events.yaml b/api/core/tools/provider/builtin/feishu_calendar/tools/search_events.yaml new file mode 100644 index 00000000000000..bd60a07b5b5341 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/search_events.yaml @@ -0,0 +1,100 @@ +identity: + name: search_events + author: Doug Lea + label: + en_US: Search Events + zh_Hans: 搜索日程 +description: + human: + en_US: Search Events + zh_Hans: 搜索日程 + llm: A tool for searching events in Feishu.(在飞书中搜索日程) +parameters: + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 搜索关键字 + human_description: + en_US: The search keyword used for fuzzy searching event names, with a maximum input of 200 characters. + zh_Hans: 用于模糊查询日程名称的搜索关键字,最大输入 200 字符。 + llm_description: 用于模糊查询日程名称的搜索关键字,最大输入 200 字符。 + form: llm + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 开始时间 + human_description: + en_US: | + The start time, defaults to 0:00 of the current day if not provided, format: 2006-01-02 15:04:05. + zh_Hans: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。 + llm_description: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: | + The end time, defaults to 23:59 of the current day if not provided, format: 2006-01-02 15:04:05. 
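A sketch of a keyword search over a date window, with arguments in the order search_events.py above passes them; per the yaml, the query does fuzzy matching on event names and is capped at 200 characters:

from core.tools.utils.feishu_api_utils import FeishuRequest

client = FeishuRequest("cli_xxx", "app-secret")  # placeholder credentials
res = client.search_events(
    "roadmap",                # query (illustrative keyword)
    "2024-10-01 00:00:00",    # start_time
    "2024-10-07 23:59:59",    # end_time
    None,                     # page_token: first page
    "open_id",                # user_id_type
    20,                       # page_size (range [10,100])
)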
+ zh_Hans: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。 + llm_description: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。 + form: llm + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [10,100]. + zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [10,100]。 + llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [10,100]。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py new file mode 100644 index 00000000000000..85bcb1d3f63847 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.py @@ -0,0 +1,24 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class UpdateEventTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + event_id = tool_parameters.get("event_id") + summary = tool_parameters.get("summary") + description = tool_parameters.get("description") + need_notification = tool_parameters.get("need_notification", True) + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + auto_record = tool_parameters.get("auto_record", False) + + res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.yaml b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.yaml new file mode 100644 index 00000000000000..4d60dbf8c8e1b0 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_calendar/tools/update_event.yaml @@ -0,0 +1,100 @@ +identity: + name: update_event + author: Doug Lea + label: + en_US: Update Event + zh_Hans: 更新日程 +description: + human: + en_US: Update Event + zh_Hans: 更新日程 + llm: A tool for updating events in Feishu.(更新飞书中的日程) +parameters: + - name: event_id + type: string + required: true + label: + en_US: Event ID + zh_Hans: 日程 ID + human_description: + en_US: | + The ID of the event, for example: e8b9791c-39ae-4908-8ad8-66b13159b9fb_0. + zh_Hans: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。 + llm_description: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。 + form: llm + + - name: summary + type: string + required: false + label: + en_US: Summary + zh_Hans: 日程标题 + human_description: + en_US: The title of the event. 
+ zh_Hans: 日程标题。 + llm_description: 日程标题。 + form: llm + + - name: description + type: string + required: false + label: + en_US: Description + zh_Hans: 日程描述 + human_description: + en_US: The description of the event. + zh_Hans: 日程描述。 + llm_description: 日程描述。 + form: llm + + - name: need_notification + type: boolean + required: false + label: + en_US: Need Notification + zh_Hans: 是否发送通知 + human_description: + en_US: | + Whether to send a bot message when the event is updated, true: send, false: do not send. + zh_Hans: 更新日程时是否发送 bot 消息,true:发送,false:不发送。 + llm_description: 更新日程时是否发送 bot 消息,true:发送,false:不发送。 + form: form + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 开始时间 + human_description: + en_US: | + The start time of the event, format: 2006-01-02 15:04:05. + zh_Hans: 日程开始时间,格式:2006-01-02 15:04:05。 + llm_description: 日程开始时间,格式:2006-01-02 15:04:05。 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: | + The end time of the event, format: 2006-01-02 15:04:05. + zh_Hans: 日程结束时间,格式:2006-01-02 15:04:05。 + llm_description: 日程结束时间,格式:2006-01-02 15:04:05。 + form: llm + + - name: auto_record + type: boolean + required: false + label: + en_US: Auto Record + zh_Hans: 自动录制 + human_description: + en_US: | + Whether to enable automatic recording, true: enabled, automatically record when the meeting starts; false: not enabled. + zh_Hans: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。 + llm_description: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg b/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg new file mode 100644 index 00000000000000..5a0a6416b3db32 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/_assets/icon.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.py b/api/core/tools/provider/builtin/feishu_document/feishu_document.py new file mode 100644 index 00000000000000..217ae52082b82c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import auth + + +class FeishuDocumentProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + auth(credentials) diff --git a/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml b/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml new file mode 100644 index 00000000000000..8f9afa6149445c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/feishu_document.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: feishu_document + label: + en_US: Lark Cloud Document + zh_Hans: 飞书云文档 + description: + en_US: | + Lark cloud document, requires the following permissions: docx:document、drive:drive、docs:document.content:read. 
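Every provider in this PR validates credentials the same way: _validate_credentials hands the dict to the shared auth helper, so bad credentials fail at provider setup rather than at the first tool call. Roughly, with placeholder values (the concrete exception auth raises lives in feishu_api_utils and is not shown in this diff):

    from core.tools.utils.feishu_api_utils import auth

    credentials = {"app_id": "cli_xxx", "app_secret": "app_secret_xxx"}  # placeholders
    try:
        auth(credentials)  # presumably exchanges app_id/app_secret for a token to verify them
    except Exception as exc:  # concrete error type defined in feishu_api_utils
        print(f"Feishu credentials rejected: {exc}")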
+ zh_Hans: | + 飞书云文档,需要开通以下权限: docx:document、drive:drive、docs:document.content:read。 + icon: icon.svg + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.larkoffice.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.py b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py new file mode 100644 index 00000000000000..090a0828e89bbf --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateDocumentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + title = tool_parameters.get("title") + content = tool_parameters.get("content") + folder_token = tool_parameters.get("folder_token") + + res = client.create_document(title, content, folder_token) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml b/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml new file mode 100644 index 00000000000000..85382e9d8e8d1f --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/create_document.yaml @@ -0,0 +1,48 @@ +identity: + name: create_document + author: Doug Lea + label: + en_US: Create Lark document + zh_Hans: 创建飞书文档 +description: + human: + en_US: Create Lark document + zh_Hans: 创建飞书文档,支持创建空文档和带内容的文档,支持 markdown 语法创建。应用需要开启机器人能力(https://open.feishu.cn/document/faq/trouble-shooting/how-to-enable-bot-ability)。 + llm: A tool for creating Feishu documents. +parameters: + - name: title + type: string + required: false + label: + en_US: Document title + zh_Hans: 文档标题 + human_description: + en_US: Document title, only supports plain text content. + zh_Hans: 文档标题,只支持纯文本内容。 + llm_description: 文档标题,只支持纯文本内容,可以为空。 + form: llm + + - name: content + type: string + required: false + label: + en_US: Document content + zh_Hans: 文档内容 + human_description: + en_US: Document content, supports markdown syntax, can be empty. + zh_Hans: 文档内容,支持 markdown 语法,可以为空。 + llm_description: 文档内容,支持 markdown 语法,可以为空。 + form: llm + + - name: folder_token + type: string + required: false + label: + en_US: folder_token + zh_Hans: 文档所在文件夹的 Token + human_description: + en_US: | + The token of the folder where the document is located. If it is not passed or is empty, it means the root directory. 
For Example: https://svi136aogf123.feishu.cn/drive/folder/JgR9fiG9AlPt8EdsSNpcGjIInbf + zh_Hans: 文档所在文件夹的 Token,不传或传空表示根目录。例如:https://svi136aogf123.feishu.cn/drive/folder/JgR9fiG9AlPt8EdsSNpcGjIInbf。 + llm_description: 文档所在文件夹的 Token,不传或传空表示根目录。例如:https://svi136aogf123.feishu.cn/drive/folder/JgR9fiG9AlPt8EdsSNpcGjIInbf。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py new file mode 100644 index 00000000000000..e67a017facc8d4 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class GetDocumentRawContentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get("document_id") + mode = tool_parameters.get("mode", "markdown") + lang = tool_parameters.get("lang", "0") + + res = client.get_document_content(document_id, mode, lang) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml new file mode 100644 index 00000000000000..15e827cde91ee6 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/get_document_content.yaml @@ -0,0 +1,70 @@ +identity: + name: get_document_content + author: Doug Lea + label: + en_US: Get Document Content + zh_Hans: 获取飞书云文档的内容 +description: + human: + en_US: Get document content + zh_Hans: 获取飞书云文档的内容 + llm: A tool for retrieving content from Feishu cloud documents. +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique identifier for a Feishu document. You can also input the document's URL. + zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: mode + type: select + required: false + options: + - value: text + label: + en_US: text + zh_Hans: text + - value: markdown + label: + en_US: markdown + zh_Hans: markdown + default: "markdown" + label: + en_US: mode + zh_Hans: 文档返回格式 + human_description: + en_US: Format of the document return, optional values are text, markdown, can be empty, default is markdown. + zh_Hans: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。 + llm_description: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。 + form: form + + - name: lang + type: select + required: false + options: + - value: "0" + label: + en_US: User's default name + zh_Hans: 用户的默认名称 + - value: "1" + label: + en_US: User's English name + zh_Hans: 用户的英文名称 + default: "0" + label: + en_US: lang + zh_Hans: 指定@用户的语言 + human_description: + en_US: | + Specifies the language for MentionUser, optional values are [0, 1]. 0: User's default name, 1: User's English name, default is 0. 
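Combining the two select parameters above, fetching a document as Markdown with default mention names is a single client call; per the document_id description, a full document URL is also accepted. A sketch with a placeholder document ID:

    from core.tools.utils.feishu_api_utils import FeishuRequest

    client = FeishuRequest("cli_xxx", "app_secret_xxx")  # placeholder credentials
    # mode="markdown" and lang="0" mirror the YAML defaults above
    res = client.get_document_content("doxcnXXXXXXXXXXXXXXXX", "markdown", "0")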
+ zh_Hans: | + 指定返回的 MentionUser,即@用户的语言,可选值有 [0,1]。0: 该用户的默认名称,1: 该用户的英文名称,默认值为 0。 + llm_description: | + 指定返回的 MentionUser,即@用户的语言,可选值有 [0,1]。0: 该用户的默认名称,1: 该用户的英文名称,默认值为 0。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py new file mode 100644 index 00000000000000..dd57c6870d0ba9 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ListDocumentBlockTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get("document_id") + page_token = tool_parameters.get("page_token", "") + user_id_type = tool_parameters.get("user_id_type", "open_id") + page_size = tool_parameters.get("page_size", 500) + + res = client.list_document_blocks(document_id, page_token, user_id_type, page_size) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml new file mode 100644 index 00000000000000..d4fab96c1f9601 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/list_document_blocks.yaml @@ -0,0 +1,74 @@ +identity: + name: list_document_blocks + author: Doug Lea + label: + en_US: List Document Blocks + zh_Hans: 获取飞书文档所有块 +description: + human: + en_US: List document blocks + zh_Hans: 获取飞书文档所有块的富文本内容并分页返回 + llm: A tool to get all blocks of Feishu documents +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique identifier for a Feishu document. You can also input the document's URL. + zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: page_size + type: number + required: false + default: 500 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: Paging size, the default and maximum value is 500. + zh_Hans: 分页大小, 默认值和最大值为 500。 + llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: Pagination token used to navigate through query results, allowing retrieval of additional items in subsequent requests. 
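Block listing is paginated like the calendar tools, just with a larger ceiling (page_size up to 500). The same assumed has_more / page_token envelope applies; a condensed sketch:

    from core.tools.utils.feishu_api_utils import FeishuRequest

    client = FeishuRequest("cli_xxx", "app_secret_xxx")  # placeholder credentials
    blocks, page_token = [], ""
    while True:
        res = client.list_document_blocks("doxcnXXXXXXXXXXXXXXXX", page_token, "open_id", 500)
        data = res.get("data", {})             # assumed response envelope
        blocks.extend(data.get("items", []))   # assumed items key
        if not data.get("has_more"):
            break
        page_token = data.get("page_token")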
+ zh_Hans: 分页标记,用于分页查询结果,以便下次遍历时获取更多项。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.py b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py new file mode 100644 index 00000000000000..59f08f53dc68de --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class WriteDocumentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + document_id = tool_parameters.get("document_id") + content = tool_parameters.get("content") + position = tool_parameters.get("position", "end") + + res = client.write_document(document_id, content, position) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml new file mode 100644 index 00000000000000..de70f4e7726a28 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_document/tools/write_document.yaml @@ -0,0 +1,57 @@ +identity: + name: write_document + author: Doug Lea + label: + en_US: Write Document + zh_Hans: 在飞书文档中新增内容 +description: + human: + en_US: Adding new content to Lark documents + zh_Hans: 在飞书文档中新增内容 + llm: A tool for adding new content to Lark documents. +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: 飞书文档的唯一标识 + human_description: + en_US: Unique identifier for a Feishu document. You can also input the document's URL. + zh_Hans: 飞书文档的唯一标识,支持输入文档的 URL。 + llm_description: 飞书文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: content + type: string + required: true + label: + en_US: Plain text or Markdown content + zh_Hans: 纯文本或 Markdown 内容 + human_description: + en_US: Plain text or Markdown content. Note that embedded tables in the document should not have merged cells. + zh_Hans: 纯文本或 Markdown 内容。注意文档的内嵌套表格不允许有单元格合并。 + llm_description: 纯文本或 Markdown 内容,注意文档的内嵌套表格不允许有单元格合并。 + form: llm + + - name: position + type: select + required: false + options: + - value: start + label: + en_US: document start + zh_Hans: 文档开始 + - value: end + label: + en_US: document end + zh_Hans: 文档结束 + default: "end" + label: + en_US: position + zh_Hans: 内容添加位置 + human_description: + en_US: Content insertion position, optional values are start, end. 'start' means adding content at the beginning of the document; 'end' means adding content at the end of the document. The default value is end.
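position is the only behavioural switch in the write_document tool above: appending Markdown to the end of an existing document reduces to one call. A sketch with placeholder IDs:

    from core.tools.utils.feishu_api_utils import FeishuRequest

    client = FeishuRequest("cli_xxx", "app_secret_xxx")  # placeholder credentials
    # write_document(document_id, content, position); "end" matches the YAML default
    res = client.write_document("doxcnXXXXXXXXXXXXXXXX", "## Meeting notes\n- first item", "end")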
+ zh_Hans: 内容添加位置,可选值有 start、end。start 表示在文档开头添加内容;end 表示在文档结尾添加内容,默认值为 end。 + llm_description: 内容添加位置,可选值有 start、end。start 表示在文档开头添加内容;end 表示在文档结尾添加内容,默认值为 end。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg b/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg new file mode 100644 index 00000000000000..222a1571f9bbbb --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/_assets/icon.svg @@ -0,0 +1,19 @@ + + + + diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.py b/api/core/tools/provider/builtin/feishu_message/feishu_message.py new file mode 100644 index 00000000000000..a3b54737691c9c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import auth + + +class FeishuMessageProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + auth(credentials) diff --git a/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml b/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml new file mode 100644 index 00000000000000..56683ec1680f40 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/feishu_message.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: feishu_message + label: + en_US: Lark Message + zh_Hans: 飞书消息 + description: + en_US: | + Lark message, requires the following permissions: im:message、im:message.group_msg. + zh_Hans: | + 飞书消息,需要开通以下权限: im:message、im:message.group_msg。 + icon: icon.svg + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.larkoffice.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_message/tools/get_chat_messages.py b/api/core/tools/provider/builtin/feishu_message/tools/get_chat_messages.py new file mode 100644 index 00000000000000..7eb29230b2ceb0 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/get_chat_messages.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class GetChatMessagesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + container_id = tool_parameters.get("container_id") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + page_token = tool_parameters.get("page_token") + sort_type = tool_parameters.get("sort_type", "ByCreateTimeAsc") + page_size = tool_parameters.get("page_size", 20) + + res = client.get_chat_messages(container_id, start_time, end_time, page_token, sort_type, page_size) + + return self.create_json_message(res) diff --git 
a/api/core/tools/provider/builtin/feishu_message/tools/get_chat_messages.yaml b/api/core/tools/provider/builtin/feishu_message/tools/get_chat_messages.yaml new file mode 100644 index 00000000000000..984c9120e8cd96 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/get_chat_messages.yaml @@ -0,0 +1,96 @@ +identity: + name: get_chat_messages + author: Doug Lea + label: + en_US: Get Chat Messages + zh_Hans: 获取指定单聊、群聊的消息历史 +description: + human: + en_US: Get Chat Messages + zh_Hans: 获取指定单聊、群聊的消息历史 + llm: A tool for getting chat messages from specific one-on-one chats or group chats.(获取指定单聊、群聊的消息历史) +parameters: + - name: container_id + type: string + required: true + label: + en_US: Container Id + zh_Hans: 群聊或单聊的 ID + human_description: + en_US: The ID of the group chat or single chat. Refer to the group ID description for how to obtain it. https://open.feishu.cn/document/server-docs/group/chat/chat-id-description + zh_Hans: 群聊或单聊的 ID,获取方式参见群 ID 说明。https://open.feishu.cn/document/server-docs/group/chat/chat-id-description + llm_description: 群聊或单聊的 ID,获取方式参见群 ID 说明。https://open.feishu.cn/document/server-docs/group/chat/chat-id-description + form: llm + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 起始时间 + human_description: + en_US: The start time for querying historical messages, formatted as "2006-01-02 15:04:05". + zh_Hans: 待查询历史信息的起始时间,格式为 "2006-01-02 15:04:05"。 + llm_description: 待查询历史信息的起始时间,格式为 "2006-01-02 15:04:05"。 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: The end time for querying historical messages, formatted as "2006-01-02 15:04:05". + zh_Hans: 待查询历史信息的结束时间,格式为 "2006-01-02 15:04:05"。 + llm_description: 待查询历史信息的结束时间,格式为 "2006-01-02 15:04:05"。 + form: llm + + - name: sort_type + type: select + required: false + options: + - value: ByCreateTimeAsc + label: + en_US: ByCreateTimeAsc + zh_Hans: ByCreateTimeAsc + - value: ByCreateTimeDesc + label: + en_US: ByCreateTimeDesc + zh_Hans: ByCreateTimeDesc + default: "ByCreateTimeAsc" + label: + en_US: Sort Type + zh_Hans: 排序方式 + human_description: + en_US: | + The message sorting method. Optional values are ByCreateTimeAsc: sorted in ascending order by message creation time; ByCreateTimeDesc: sorted in descending order by message creation time. The default value is ByCreateTimeAsc. Note: When using page_token for pagination requests, the sorting method (sort_type) is consistent with the first request and cannot be changed midway. + zh_Hans: | + 消息排序方式,可选值有 ByCreateTimeAsc:按消息创建时间升序排列;ByCreateTimeDesc:按消息创建时间降序排列。默认值为:ByCreateTimeAsc。注意:使用 page_token 分页请求时,排序方式(sort_type)均与第一次请求一致,不支持中途改换排序方式。 + llm_description: 消息排序方式,可选值有 ByCreateTimeAsc:按消息创建时间升序排列;ByCreateTimeDesc:按消息创建时间降序排列。默认值为:ByCreateTimeAsc。注意:使用 page_token 分页请求时,排序方式(sort_type)均与第一次请求一致,不支持中途改换排序方式。 + form: form + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [1,50]. + zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。 + llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. 
Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_message/tools/get_thread_messages.py b/api/core/tools/provider/builtin/feishu_message/tools/get_thread_messages.py new file mode 100644 index 00000000000000..3b14f46e0048a8 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/get_thread_messages.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class GetThreadMessagesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + container_id = tool_parameters.get("container_id") + page_token = tool_parameters.get("page_token") + sort_type = tool_parameters.get("sort_type", "ByCreateTimeAsc") + page_size = tool_parameters.get("page_size", 20) + + res = client.get_thread_messages(container_id, page_token, sort_type, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/get_thread_messages.yaml b/api/core/tools/provider/builtin/feishu_message/tools/get_thread_messages.yaml new file mode 100644 index 00000000000000..85a138292f6203 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/get_thread_messages.yaml @@ -0,0 +1,72 @@ +identity: + name: get_thread_messages + author: Doug Lea + label: + en_US: Get Thread Messages + zh_Hans: 获取指定话题的消息历史 +description: + human: + en_US: Get Thread Messages + zh_Hans: 获取指定话题的消息历史 + llm: A tool for getting chat messages from specific threads.(获取指定话题的消息历史) +parameters: + - name: container_id + type: string + required: true + label: + en_US: Thread Id + zh_Hans: 话题 ID + human_description: + en_US: The ID of the thread. Refer to the thread overview on how to obtain the thread_id. https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/thread-introduction + zh_Hans: 话题 ID,获取方式参见话题概述的如何获取 thread_id 章节。https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/thread-introduction + llm_description: 话题 ID,获取方式参见话题概述的如何获取 thread_id 章节。https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/thread-introduction + form: llm + + - name: sort_type + type: select + required: false + options: + - value: ByCreateTimeAsc + label: + en_US: ByCreateTimeAsc + zh_Hans: ByCreateTimeAsc + - value: ByCreateTimeDesc + label: + en_US: ByCreateTimeDesc + zh_Hans: ByCreateTimeDesc + default: "ByCreateTimeAsc" + label: + en_US: Sort Type + zh_Hans: 排序方式 + human_description: + en_US: | + The message sorting method. Optional values are ByCreateTimeAsc: sorted in ascending order by message creation time; ByCreateTimeDesc: sorted in descending order by message creation time. The default value is ByCreateTimeAsc.
Note: When using page_token for pagination requests, the sorting method (sort_type) is consistent with the first request and cannot be changed midway. + zh_Hans: | + 消息排序方式,可选值有 ByCreateTimeAsc:按消息创建时间升序排列;ByCreateTimeDesc:按消息创建时间降序排列。默认值为:ByCreateTimeAsc。注意:使用 page_token 分页请求时,排序方式(sort_type)均与第一次请求一致,不支持中途改换排序方式。 + llm_description: 消息排序方式,可选值有 ByCreateTimeAsc:按消息创建时间升序排列;ByCreateTimeDesc:按消息创建时间降序排列。默认值为:ByCreateTimeAsc。注意:使用 page_token 分页请求时,排序方式(sort_type)均与第一次请求一致,不支持中途改换排序方式。 + form: form + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [1,50]. + zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。 + llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py new file mode 100644 index 00000000000000..1dd315d0e293a0 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SendBotMessageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + receive_id_type = tool_parameters.get("receive_id_type") + receive_id = tool_parameters.get("receive_id") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") + + res = client.send_bot_message(receive_id_type, receive_id, msg_type, content) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml new file mode 100644 index 00000000000000..4f7f65a8a74fc0 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_bot_message.yaml @@ -0,0 +1,125 @@ +identity: + name: send_bot_message + author: Doug Lea + label: + en_US: Send Bot Message + zh_Hans: 发送飞书应用消息 +description: + human: + en_US: Send bot message + zh_Hans: 发送飞书应用消息 + llm: A tool for sending Feishu application messages. +parameters: + - name: receive_id + type: string + required: true + label: + en_US: receive_id + zh_Hans: 消息接收者的 ID + human_description: + en_US: The ID of the message receiver, the ID type is consistent with the value of the query parameter receive_id_type. 
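Of the two history tools above, only get_chat_messages takes a time window; both share the caveat that sort_type must stay constant across every page of one traversal. A sketch of a first-page request for a group chat, with placeholder IDs:

    from core.tools.utils.feishu_api_utils import FeishuRequest

    client = FeishuRequest("cli_xxx", "app_secret_xxx")  # placeholder credentials
    res = client.get_chat_messages(
        "oc_xxxxxxxxxxxxxxxx",    # chat_id placeholder; see the chat-id-description link above
        "2024-01-01 00:00:00",
        "2024-01-01 23:59:59",
        None,                     # page_token: empty on the first request
        "ByCreateTimeAsc",        # keep identical on follow-up pages
        20,
    )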
+ zh_Hans: 消息接收者的 ID,ID 类型与查询参数 receive_id_type 的取值一致。 + llm_description: 消息接收者的 ID,ID 类型与查询参数 receive_id_type 的取值一致。 + form: llm + + - name: receive_id_type + type: select + required: true + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + - value: email + label: + en_US: email + zh_Hans: email + - value: chat_id + label: + en_US: chat_id + zh_Hans: chat_id + label: + en_US: receive_id_type + zh_Hans: 消息接收者的 ID 类型 + human_description: + en_US: The ID type of the message receiver, optional values are open_id, union_id, user_id, email, chat_id, with a default value of open_id. + zh_Hans: 消息接收者的 ID 类型,可选值有 open_id、union_id、user_id、email、chat_id,默认值为 open_id。 + llm_description: 消息接收者的 ID 类型,可选值有 open_id、union_id、user_id、email、chat_id,默认值为 open_id。 + form: form + + - name: msg_type + type: select + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: interactive + label: + en_US: interactive + zh_Hans: 卡片 + - value: post + label: + en_US: post + zh_Hans: 富文本 + - value: image + label: + en_US: image + zh_Hans: 图片 + - value: file + label: + en_US: file + zh_Hans: 文件 + - value: audio + label: + en_US: audio + zh_Hans: 语音 + - value: media + label: + en_US: media + zh_Hans: 视频 + - value: sticker + label: + en_US: sticker + zh_Hans: 表情包 + - value: share_chat + label: + en_US: share_chat + zh_Hans: 分享群名片 + - value: share_user + label: + en_US: share_user + zh_Hans: 分享个人名片 + - value: system + label: + en_US: system + zh_Hans: 系统消息 + label: + en_US: msg_type + zh_Hans: 消息类型 + human_description: + en_US: Message type. Optional values are text, post, image, file, audio, media, sticker, interactive, share_chat, share_user, system. For detailed introduction of different message types, refer to the message content(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json). + zh_Hans: 消息类型。可选值有:text、post、image、file、audio、media、sticker、interactive、share_chat、share_user、system。不同消息类型的详细介绍,参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + llm_description: 消息类型。可选值有:text、post、image、file、audio、media、sticker、interactive、share_chat、share_user、system。不同消息类型的详细介绍,参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + form: form + + - name: content + type: string + required: true + label: + en_US: content + zh_Hans: 消息内容 + human_description: + en_US: Message content, a JSON structure serialized string. The value of this parameter corresponds to msg_type. For example, if msg_type is text, this parameter needs to pass in text type content. To understand the format and usage limitations of different message types, refer to the message content(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json). 
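As the content description below spells out, content is a JSON-serialized string whose shape depends on msg_type; for a text message the linked docs use a {"text": ...} object, so json.dumps is the safe way to build it. A sketch with placeholder IDs:

    import json

    from core.tools.utils.feishu_api_utils import FeishuRequest

    client = FeishuRequest("cli_xxx", "app_secret_xxx")  # placeholder credentials
    content = json.dumps({"text": "Build finished"})     # text-type payload per the linked docs
    res = client.send_bot_message("open_id", "ou_xxxxxxxxxxxxxxxx", "text", content)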
+ zh_Hans: 消息内容,JSON 结构序列化后的字符串。该参数的取值与 msg_type 对应,例如 msg_type 取值为 text,则该参数需要传入文本类型的内容。了解不同类型的消息内容格式、使用限制,可参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + llm_description: 消息内容,JSON 结构序列化后的字符串。该参数的取值与 msg_type 对应,例如 msg_type 取值为 text,则该参数需要传入文本类型的内容。了解不同类型的消息内容格式、使用限制,可参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py new file mode 100644 index 00000000000000..44e70e0a15b64d --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class SendWebhookMessageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + webhook = tool_parameters.get("webhook") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") + + res = client.send_webhook_message(webhook, msg_type, content) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml new file mode 100644 index 00000000000000..eeeae8b29cd935 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_message/tools/send_webhook_message.yaml @@ -0,0 +1,68 @@ +identity: + name: send_webhook_message + author: Doug Lea + label: + en_US: Send Webhook Message + zh_Hans: 使用自定义机器人发送飞书消息 +description: + human: + en_US: Send webhook message + zh_Hans: 使用自定义机器人发送飞书消息 + llm: A tool for sending Lark messages using a custom robot. +parameters: + - name: webhook + type: string + required: true + label: + en_US: webhook + zh_Hans: webhook + human_description: + en_US: | + The address of the webhook, the format of the webhook address corresponding to the bot is as follows: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx. For details, please refer to: Feishu Custom Bot Usage Guide(https://open.larkoffice.com/document/client-docs/bot-v3/add-custom-bot) + zh_Hans: | + webhook 的地址,机器人对应的 webhook 地址格式如下: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx,详情可参考: 飞书自定义机器人使用指南(https://open.larkoffice.com/document/client-docs/bot-v3/add-custom-bot) + llm_description: | + webhook 的地址,机器人对应的 webhook 地址格式如下: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx,详情可参考: 飞书自定义机器人使用指南(https://open.larkoffice.com/document/client-docs/bot-v3/add-custom-bot) + form: llm + + - name: msg_type + type: select + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: interactive + label: + en_US: interactive + zh_Hans: 卡片 + - value: image + label: + en_US: image + zh_Hans: 图片 + - value: share_chat + label: + en_US: share_chat + zh_Hans: 分享群名片 + label: + en_US: msg_type + zh_Hans: 消息类型 + human_description: + en_US: Message type. Optional values are text, image, interactive, share_chat. 
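send_webhook_message reuses the same serialized-content convention but targets a custom bot's webhook URL, which is the only secret involved, so it deserves credential-grade handling. A sketch (note that SendWebhookMessageTool above still constructs FeishuRequest from app credentials even though the webhook path presumably does not need them):

    import json

    from core.tools.utils.feishu_api_utils import FeishuRequest

    client = FeishuRequest("cli_xxx", "app_secret_xxx")  # placeholders; likely unused for webhooks
    webhook = "https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx"
    res = client.send_webhook_message(webhook, "text", json.dumps({"text": "Deploy done"}))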
For detailed introduction of different message types, refer to the message content(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json). + zh_Hans: 消息类型。可选值有:text、image、interactive、share_chat。不同消息类型的详细介绍,参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + llm_description: 消息类型。可选值有:text、image、interactive、share_chat。不同消息类型的详细介绍,参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + form: form + + + - name: content + type: string + required: true + label: + en_US: content + zh_Hans: 消息内容 + human_description: + en_US: Message content, a JSON structure serialized string. The value of this parameter corresponds to msg_type. For example, if msg_type is text, this parameter needs to pass in text type content. To understand the format and usage limitations of different message types, refer to the message content(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json). + zh_Hans: 消息内容,JSON 结构序列化后的字符串。该参数的取值与 msg_type 对应,例如 msg_type 取值为 text,则该参数需要传入文本类型的内容。了解不同类型的消息内容格式、使用限制,可参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + llm_description: 消息内容,JSON 结构序列化后的字符串。该参数的取值与 msg_type 对应,例如 msg_type 取值为 text,则该参数需要传入文本类型的内容。了解不同类型的消息内容格式、使用限制,可参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/_assets/icon.png b/api/core/tools/provider/builtin/feishu_spreadsheet/_assets/icon.png new file mode 100644 index 00000000000000..258b361261d4e3 Binary files /dev/null and b/api/core/tools/provider/builtin/feishu_spreadsheet/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/feishu_spreadsheet.py b/api/core/tools/provider/builtin/feishu_spreadsheet/feishu_spreadsheet.py new file mode 100644 index 00000000000000..a3b54737691c9c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/feishu_spreadsheet.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import auth + + +class FeishuSpreadsheetProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + auth(credentials) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/feishu_spreadsheet.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/feishu_spreadsheet.yaml new file mode 100644 index 00000000000000..29e448d730f745 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/feishu_spreadsheet.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: feishu_spreadsheet + label: + en_US: Feishu Spreadsheet + zh_Hans: 飞书电子表格 + description: + en_US: | + Feishu Spreadsheet, requires the following permissions: sheets:spreadsheet.
+ zh_Hans: | + 飞书电子表格,需要开通以下权限: sheets:spreadsheet。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.larkoffice.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_cols.py b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_cols.py new file mode 100644 index 00000000000000..44d062f9bdded2 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_cols.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class AddColsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + length = tool_parameters.get("length") + values = tool_parameters.get("values") + + res = client.add_cols(spreadsheet_token, sheet_id, sheet_name, length, values) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_cols.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_cols.yaml new file mode 100644 index 00000000000000..b73335f405c20c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_cols.yaml @@ -0,0 +1,72 @@ +identity: + name: add_cols + author: Doug Lea + label: + en_US: Add Cols + zh_Hans: 新增多列至工作表最后 +description: + human: + en_US: Add Cols + zh_Hans: 新增多列至工作表最后 + llm: A tool for adding multiple columns to the end of a spreadsheet. (新增多列至工作表最后) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: length + type: number + required: true + label: + en_US: length + zh_Hans: 要增加的列数 + human_description: + en_US: Number of columns to add, range (0-5000]. 
+ zh_Hans: 要增加的列数,范围(0-5000]。 + llm_description: 要增加的列数,范围(0-5000]。 + form: form + + - name: values + type: string + required: false + label: + en_US: values + zh_Hans: 新增列的单元格内容 + human_description: + en_US: | + Content of the new columns, array of objects in string format, each array represents a row of table data, format like: [ [ "ID","Name","Age" ],[ 1,"Zhang San",10 ],[ 2,"Li Si",11 ] ]. + zh_Hans: 新增列的单元格内容,数组对象字符串,每个数组一行表格数据,格式:[["编号","姓名","年龄"],[1,"张三",10],[2,"李四",11]]。 + llm_description: 新增列的单元格内容,数组对象字符串,每个数组一行表格数据,格式:[["编号","姓名","年龄"],[1,"张三",10],[2,"李四",11]]。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_rows.py b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_rows.py new file mode 100644 index 00000000000000..3a85b7b46ccb93 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_rows.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class AddRowsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + length = tool_parameters.get("length") + values = tool_parameters.get("values") + + res = client.add_rows(spreadsheet_token, sheet_id, sheet_name, length, values) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_rows.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_rows.yaml new file mode 100644 index 00000000000000..6bce305b9825ec --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/add_rows.yaml @@ -0,0 +1,72 @@ +identity: + name: add_rows + author: Doug Lea + label: + en_US: Add Rows + zh_Hans: 新增多行至工作表最后 +description: + human: + en_US: Add Rows + zh_Hans: 新增多行至工作表最后 + llm: A tool for adding multiple rows to the end of a spreadsheet. (新增多行至工作表最后) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: length + type: number + required: true + label: + en_US: length + zh_Hans: 要增加行数 + human_description: + en_US: Number of rows to add, range (0-5000]. 
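For both add_cols and add_rows, values is a string holding a JSON array of row arrays, so building it with json.dumps beats hand-concatenation. The exact relationship between length and the rows inside values is not pinned down by this diff; the sketch below assumes they should agree:

    import json

    from core.tools.utils.feishu_api_utils import FeishuRequest

    client = FeishuRequest("cli_xxx", "app_secret_xxx")  # placeholder credentials
    rows = [["ID", "Name", "Age"], [1, "Zhang San", 10], [2, "Li Si", 11]]  # format from the YAML above
    # sheet_id is None here, so sheet_name identifies the sheet (one of the two is required)
    res = client.add_rows("shtcnXXXXXXXXXXXXXXXX", None, "Sheet1", len(rows), json.dumps(rows))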
+ zh_Hans: 要增加行数,范围(0-5000]。 + llm_description: 要增加行数,范围(0-5000]。 + form: form + + - name: values + type: string + required: false + label: + en_US: values + zh_Hans: 新增行的表格内容 + human_description: + en_US: | + Content of the new rows, array of objects in string format, each array represents a row of table data, format like: [ [ "ID","Name","Age" ],[ 1,"Zhang San",10 ],[ 2,"Li Si",11 ] ]. + zh_Hans: 新增行的表格内容,数组对象字符串,每个数组一行表格数据,格式,如:[["编号","姓名","年龄"],[1,"张三",10],[2,"李四",11]]。 + llm_description: 新增行的表格内容,数组对象字符串,每个数组一行表格数据,格式,如:[["编号","姓名","年龄"],[1,"张三",10],[2,"李四",11]]。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/create_spreadsheet.py b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/create_spreadsheet.py new file mode 100644 index 00000000000000..647364fab0a966 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/create_spreadsheet.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateSpreadsheetTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + title = tool_parameters.get("title") + folder_token = tool_parameters.get("folder_token") + + res = client.create_spreadsheet(title, folder_token) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/create_spreadsheet.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/create_spreadsheet.yaml new file mode 100644 index 00000000000000..931310e63172d4 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/create_spreadsheet.yaml @@ -0,0 +1,35 @@ +identity: + name: create_spreadsheet + author: Doug Lea + label: + en_US: Create Spreadsheet + zh_Hans: 创建电子表格 +description: + human: + en_US: Create Spreadsheet + zh_Hans: 创建电子表格 + llm: A tool for creating spreadsheets. 
(创建电子表格) +parameters: + - name: title + type: string + required: false + label: + en_US: Spreadsheet Title + zh_Hans: 电子表格标题 + human_description: + en_US: The title of the spreadsheet + zh_Hans: 电子表格的标题 + llm_description: 电子表格的标题 + form: llm + + - name: folder_token + type: string + required: false + label: + en_US: Folder Token + zh_Hans: 文件夹 token + human_description: + en_US: The token of the folder, supports folder URL input, e.g., https://bytedance.larkoffice.com/drive/folder/CxHEf4DCSlNkL2dUTCJcPRgentg + zh_Hans: 文件夹 token,支持文件夹 URL 输入,如:https://bytedance.larkoffice.com/drive/folder/CxHEf4DCSlNkL2dUTCJcPRgentg + llm_description: 文件夹 token,支持文件夹 URL 输入,如:https://bytedance.larkoffice.com/drive/folder/CxHEf4DCSlNkL2dUTCJcPRgentg + form: llm diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/get_spreadsheet.py b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/get_spreadsheet.py new file mode 100644 index 00000000000000..dda8c59daffabf --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/get_spreadsheet.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class GetSpreadsheetTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.get_spreadsheet(spreadsheet_token, user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/get_spreadsheet.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/get_spreadsheet.yaml new file mode 100644 index 00000000000000..c519938617ba8c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/get_spreadsheet.yaml @@ -0,0 +1,49 @@ +identity: + name: get_spreadsheet + author: Doug Lea + label: + en_US: Get Spreadsheet + zh_Hans: 获取电子表格信息 +description: + human: + en_US: Get Spreadsheet + zh_Hans: 获取电子表格信息 + llm: A tool for getting information from spreadsheets. (获取电子表格信息) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: Spreadsheet Token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 URL。 + llm_description: 电子表格 token,支持输入电子表格 URL。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
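Chaining create_spreadsheet and get_spreadsheet is a natural smoke test for this provider; both title and folder_token are optional at creation. A sketch, with the path to the new token in the creation response assumed from the Feishu API (it is not shown in this diff):

    from core.tools.utils.feishu_api_utils import FeishuRequest

    client = FeishuRequest("cli_xxx", "app_secret_xxx")  # placeholder credentials
    created = client.create_spreadsheet("Q3 Budget", None)  # None folder_token: default location (assumed)
    token = created.get("data", {}).get("spreadsheet", {}).get("spreadsheet_token")  # assumed path
    info = client.get_spreadsheet(token, "open_id")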
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/list_spreadsheet_sheets.py b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/list_spreadsheet_sheets.py new file mode 100644 index 00000000000000..98497791c0fa1e --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/list_spreadsheet_sheets.py @@ -0,0 +1,18 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ListSpreadsheetSheetsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + + res = client.list_spreadsheet_sheets(spreadsheet_token) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/list_spreadsheet_sheets.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/list_spreadsheet_sheets.yaml new file mode 100644 index 00000000000000..c6a7ef45d46589 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/list_spreadsheet_sheets.yaml @@ -0,0 +1,23 @@ +identity: + name: list_spreadsheet_sheets + author: Doug Lea + label: + en_US: List Spreadsheet Sheets + zh_Hans: 列出电子表格所有工作表 +description: + human: + en_US: List Spreadsheet Sheets + zh_Hans: 列出电子表格所有工作表 + llm: A tool for listing all sheets in a spreadsheet. (列出电子表格所有工作表) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: Spreadsheet Token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. 
+ zh_Hans: 电子表格 token,支持输入电子表格 URL。 + llm_description: 电子表格 token,支持输入电子表格 URL。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_cols.py b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_cols.py new file mode 100644 index 00000000000000..ebe3f619d091d1 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_cols.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ReadColsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + start_col = tool_parameters.get("start_col") + num_cols = tool_parameters.get("num_cols") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_cols(spreadsheet_token, sheet_id, sheet_name, start_col, num_cols, user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_cols.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_cols.yaml new file mode 100644 index 00000000000000..34da74592d5898 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_cols.yaml @@ -0,0 +1,97 @@ +identity: + name: read_cols + author: Doug Lea + label: + en_US: Read Cols + zh_Hans: 读取工作表列数据 +description: + human: + en_US: Read Cols + zh_Hans: 读取工作表列数据 + llm: A tool for reading column data from a spreadsheet. (读取工作表列数据) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: start_col + type: number + required: false + label: + en_US: start_col + zh_Hans: 起始列号 + human_description: + en_US: Starting column number, starting from 1. + zh_Hans: 起始列号,从 1 开始。 + llm_description: 起始列号,从 1 开始。 + form: form + + - name: num_cols + type: number + required: true + label: + en_US: num_cols + zh_Hans: 读取列数 + human_description: + en_US: Number of columns to read. + zh_Hans: 读取列数 + llm_description: 读取列数 + form: form diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_rows.py b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_rows.py new file mode 100644 index 00000000000000..86b91b104b7029 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_rows.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ReadRowsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + start_row = tool_parameters.get("start_row") + num_rows = tool_parameters.get("num_rows") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_rows(spreadsheet_token, sheet_id, sheet_name, start_row, num_rows, user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_rows.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_rows.yaml new file mode 100644 index 00000000000000..5dfa8d58354125 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_rows.yaml @@ -0,0 +1,97 @@ +identity: + name: read_rows + author: Doug Lea + label: + en_US: Read Rows + zh_Hans: 读取工作表行数据 +description: + human: + en_US: Read Rows + zh_Hans: 读取工作表行数据 + llm: A tool for reading row data from a spreadsheet. (读取工作表行数据) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. 
+ zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: start_row + type: number + required: false + label: + en_US: start_row + zh_Hans: 起始行号 + human_description: + en_US: Starting row number, starting from 1. + zh_Hans: 起始行号,从 1 开始。 + llm_description: 起始行号,从 1 开始。 + form: form + + - name: num_rows + type: number + required: true + label: + en_US: num_rows + zh_Hans: 读取行数 + human_description: + en_US: Number of rows to read. + zh_Hans: 读取行数 + llm_description: 读取行数 + form: form diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_table.py b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_table.py new file mode 100644 index 00000000000000..ddd607d87838f4 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_table.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class ReadTableTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + num_range = tool_parameters.get("range") + query = tool_parameters.get("query") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_table(spreadsheet_token, sheet_id, sheet_name, num_range, query, user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_table.yaml b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_table.yaml new file mode 100644 index 00000000000000..10534436d66e7a --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_spreadsheet/tools/read_table.yaml @@ -0,0 +1,122 @@ +identity: + name: read_table + author: Doug Lea + label: + en_US: Read Table + zh_Hans: 自定义读取电子表格行列数据 +description: + human: + en_US: Read Table + zh_Hans: 自定义读取电子表格行列数据 + llm: A tool for custom reading of row and column data from a spreadsheet. (自定义读取电子表格行列数据) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled.
+ zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: start_row + type: number + required: false + label: + en_US: start_row + zh_Hans: 起始行号 + human_description: + en_US: Starting row number, starting from 1. + zh_Hans: 起始行号,从 1 开始。 + llm_description: 起始行号,从 1 开始。 + form: form + + - name: num_rows + type: number + required: false + label: + en_US: num_rows + zh_Hans: 读取行数 + human_description: + en_US: Number of rows to read. + zh_Hans: 读取行数 + llm_description: 读取行数 + form: form + + - name: range + type: string + required: false + label: + en_US: range + zh_Hans: 取数范围 + human_description: + en_US: | + Data range, format like: A1:B2, can be empty when query=all. + zh_Hans: 取数范围,格式如:A1:B2,query=all 时可为空。 + llm_description: 取数范围,格式如:A1:B2,query=all 时可为空。 + form: llm + + - name: query + type: string + required: false + label: + en_US: query + zh_Hans: 查询 + human_description: + en_US: Pass "all" to query all data in the table, but no more than 100 columns. + zh_Hans: 传 all,表示查询表格所有数据,但最多查询 100 列数据。 + llm_description: 传 all,表示查询表格所有数据,但最多查询 100 列数据。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_task/_assets/icon.png b/api/core/tools/provider/builtin/feishu_task/_assets/icon.png new file mode 100644 index 00000000000000..3485be0d0fbd85 Binary files /dev/null and b/api/core/tools/provider/builtin/feishu_task/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/feishu_task/feishu_task.py b/api/core/tools/provider/builtin/feishu_task/feishu_task.py new file mode 100644 index 00000000000000..6df05968d8f176 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/feishu_task.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import auth + + +class FeishuTaskProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + auth(credentials) diff --git a/api/core/tools/provider/builtin/feishu_task/feishu_task.yaml b/api/core/tools/provider/builtin/feishu_task/feishu_task.yaml new file mode 100644 index 00000000000000..88736f79a02e87 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/feishu_task.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: feishu_task + label: + en_US: Feishu Task + zh_Hans: 飞书任务 + description: + en_US: | + Feishu Task, requires the following permissions: task:task:write、contact:user.id:readonly. 
+ zh_Hans: | + 飞书任务,需要开通以下权限: task:task:write、contact:user.id:readonly。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.larkoffice.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_task/tools/add_members.py b/api/core/tools/provider/builtin/feishu_task/tools/add_members.py new file mode 100644 index 00000000000000..e58ed22e0f4797 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/tools/add_members.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class AddMembersTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + task_guid = tool_parameters.get("task_guid") + member_phone_or_email = tool_parameters.get("member_phone_or_email") + member_role = tool_parameters.get("member_role", "follower") + + res = client.add_members(task_guid, member_phone_or_email, member_role) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_task/tools/add_members.yaml b/api/core/tools/provider/builtin/feishu_task/tools/add_members.yaml new file mode 100644 index 00000000000000..063c0f7f04956c --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/tools/add_members.yaml @@ -0,0 +1,58 @@ +identity: + name: add_members + author: Doug Lea + label: + en_US: Add Members + zh_Hans: 添加任务成员 +description: + human: + en_US: Add Members + zh_Hans: 添加任务成员 + llm: A tool for adding members to a Feishu task.(添加任务成员) +parameters: + - name: task_guid + type: string + required: true + label: + en_US: Task GUID + zh_Hans: 任务 GUID + human_description: + en_US: | + The GUID of the task to be added, supports passing either the Task ID or the Task link URL. Example of Task ID: 8b5425ec-9f2a-43bd-a3ab-01912f50282b; Example of Task link URL: https://applink.feishu-pre.net/client/todo/detail?guid=8c6bf822-e4da-449a-b82a-dc44020f9be9&suite_entity_num=t21587362 + zh_Hans: 要添加的任务的 GUID,支持传任务 ID 和任务链接 URL。任务 ID 示例:8b5425ec-9f2a-43bd-a3ab-01912f50282b;任务链接 URL 示例:https://applink.feishu-pre.net/client/todo/detail?guid=8c6bf822-e4da-449a-b82a-dc44020f9be9&suite_entity_num=t21587362 + llm_description: 要添加的任务的 GUID,支持传任务 ID 和任务链接 URL。任务 ID 示例:8b5425ec-9f2a-43bd-a3ab-01912f50282b;任务链接 URL 示例:https://applink.feishu-pre.net/client/todo/detail?guid=8c6bf822-e4da-449a-b82a-dc44020f9be9&suite_entity_num=t21587362 + form: llm + + - name: member_phone_or_email + type: string + required: true + label: + en_US: Task Member Phone Or Email + zh_Hans: 任务成员的电话或邮箱 + human_description: + en_US: A list of member emails or phone numbers, separated by commas. 
+ zh_Hans: 任务成员邮箱或者手机号列表,使用逗号分隔。 + llm_description: 任务成员邮箱或者手机号列表,使用逗号分隔。 + form: llm + + - name: member_role + type: select + required: true + options: + - value: assignee + label: + en_US: assignee + zh_Hans: 负责人 + - value: follower + label: + en_US: follower + zh_Hans: 关注人 + default: "follower" + label: + en_US: member_role + zh_Hans: 成员的角色 + human_description: + en_US: Member role, optional values are "assignee" (responsible person) and "follower" (observer), with a default value of "follower". + zh_Hans: 成员的角色,可选值有 "assignee"(负责人)和 "follower"(关注人),默认值为 "follower"。 + llm_description: 成员的角色,可选值有 "assignee"(负责人)和 "follower"(关注人),默认值为 "follower"。 + form: form diff --git a/api/core/tools/provider/builtin/feishu_task/tools/create_task.py b/api/core/tools/provider/builtin/feishu_task/tools/create_task.py new file mode 100644 index 00000000000000..96cdcd71f6d2ec --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/tools/create_task.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class CreateTaskTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + summary = tool_parameters.get("summary") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + completed_time = tool_parameters.get("completed_time") + description = tool_parameters.get("description") + + res = client.create_task(summary, start_time, end_time, completed_time, description) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_task/tools/create_task.yaml b/api/core/tools/provider/builtin/feishu_task/tools/create_task.yaml new file mode 100644 index 00000000000000..7eb4af168bf740 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/tools/create_task.yaml @@ -0,0 +1,74 @@ +identity: + name: create_task + author: Doug Lea + label: + en_US: Create Task + zh_Hans: 创建飞书任务 +description: + human: + en_US: Create Feishu Task + zh_Hans: 创建飞书任务 + llm: A tool for creating tasks in Feishu.(创建飞书任务) +parameters: + - name: summary + type: string + required: true + label: + en_US: Task Title + zh_Hans: 任务标题 + human_description: + en_US: The title of the task. + zh_Hans: 任务标题 + llm_description: 任务标题 + form: llm + + - name: description + type: string + required: false + label: + en_US: Task Description + zh_Hans: 任务备注 + human_description: + en_US: The description or notes for the task.
+ zh_Hans: 任务备注 + llm_description: 任务备注 + form: llm + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 任务开始时间 + human_description: + en_US: | + The start time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务开始时间,格式为:2006-01-02 15:04:05 + llm_description: 任务开始时间,格式为:2006-01-02 15:04:05 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 任务结束时间 + human_description: + en_US: | + The end time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务结束时间,格式为:2006-01-02 15:04:05 + llm_description: 任务结束时间,格式为:2006-01-02 15:04:05 + form: llm + + - name: completed_time + type: string + required: false + label: + en_US: Completed Time + zh_Hans: 任务完成时间 + human_description: + en_US: | + The completion time of the task, in the format: 2006-01-02 15:04:05. Leave empty to create an incomplete task; fill in a specific time to create a completed task. + zh_Hans: 任务完成时间,格式为:2006-01-02 15:04:05,不填写表示创建一个未完成任务;填写一个具体的时间表示创建一个已完成任务。 + llm_description: 任务完成时间,格式为:2006-01-02 15:04:05,不填写表示创建一个未完成任务;填写一个具体的时间表示创建一个已完成任务。 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_task/tools/delete_task.py b/api/core/tools/provider/builtin/feishu_task/tools/delete_task.py new file mode 100644 index 00000000000000..dee036fee5203a --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/tools/delete_task.py @@ -0,0 +1,18 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class DeleteTaskTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + task_guid = tool_parameters.get("task_guid") + + res = client.delete_task(task_guid) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_task/tools/delete_task.yaml b/api/core/tools/provider/builtin/feishu_task/tools/delete_task.yaml new file mode 100644 index 00000000000000..d3f97413676624 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/tools/delete_task.yaml @@ -0,0 +1,24 @@ +identity: + name: delete_task + author: Doug Lea + label: + en_US: Delete Task + zh_Hans: 删除飞书任务 +description: + human: + en_US: Delete Task + zh_Hans: 删除飞书任务 + llm: A tool for deleting tasks in Feishu.(删除飞书任务) +parameters: + - name: task_guid + type: string + required: true + label: + en_US: Task GUID + zh_Hans: 任务 GUID + human_description: + en_US: | + The GUID of the task to be deleted, supports passing either the Task ID or the Task link URL.
Example of Task ID: 8b5425ec-9f2a-43bd-a3ab-01912f50282b; Example of Task link URL: https://applink.feishu-pre.net/client/todo/detail?guid=8c6bf822-e4da-449a-b82a-dc44020f9be9&suite_entity_num=t21587362 + zh_Hans: 要删除的任务的 GUID,支持传任务 ID 和任务链接 URL。任务 ID 示例:8b5425ec-9f2a-43bd-a3ab-01912f50282b;任务链接 URL 示例:https://applink.feishu-pre.net/client/todo/detail?guid=8c6bf822-e4da-449a-b82a-dc44020f9be9&suite_entity_num=t21587362 + llm_description: 要删除的任务的 GUID,支持传任务 ID 和任务链接 URL。任务 ID 示例:8b5425ec-9f2a-43bd-a3ab-01912f50282b;任务链接 URL 示例:https://applink.feishu-pre.net/client/todo/detail?guid=8c6bf822-e4da-449a-b82a-dc44020f9be9&suite_entity_num=t21587362 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_task/tools/update_task.py b/api/core/tools/provider/builtin/feishu_task/tools/update_task.py new file mode 100644 index 00000000000000..4a48cd283abf1d --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/tools/update_task.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class UpdateTaskTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + task_guid = tool_parameters.get("task_guid") + summary = tool_parameters.get("summary") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + completed_time = tool_parameters.get("completed_time") + description = tool_parameters.get("description") + + res = client.update_task(task_guid, summary, start_time, end_time, completed_time, description) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_task/tools/update_task.yaml b/api/core/tools/provider/builtin/feishu_task/tools/update_task.yaml new file mode 100644 index 00000000000000..83c9bcb1c443ac --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_task/tools/update_task.yaml @@ -0,0 +1,89 @@ +identity: + name: update_task + author: Doug Lea + label: + en_US: Update Task + zh_Hans: 更新飞书任务 +description: + human: + en_US: Update Feishu Task + zh_Hans: 更新飞书任务 + llm: A tool for updating tasks in Feishu.(更新飞书任务) +parameters: + - name: task_guid + type: string + required: true + label: + en_US: Task GUID + zh_Hans: 任务 GUID + human_description: + en_US: | + The task ID, supports inputting either the Task ID or the Task link URL. Example of Task ID: 42cad8a0-f8c8-4344-9be2-d1d7e8e91b64; Example of Task link URL: https://applink.feishu-pre.net/client/todo/detail?guid=42cad8a0-f8c8-4344-9be2-d1d7e8e91b64&suite_entity_num=t21700217 + zh_Hans: | + 任务ID,支持传入任务 ID 和任务链接 URL。任务 ID 示例: 42cad8a0-f8c8-4344-9be2-d1d7e8e91b64;任务链接 URL 示例: https://applink.feishu-pre.net/client/todo/detail?guid=42cad8a0-f8c8-4344-9be2-d1d7e8e91b64&suite_entity_num=t21700217 + llm_description: | + 任务ID,支持传入任务 ID 和任务链接 URL。任务 ID 示例: 42cad8a0-f8c8-4344-9be2-d1d7e8e91b64;任务链接 URL 示例: https://applink.feishu-pre.net/client/todo/detail?guid=42cad8a0-f8c8-4344-9be2-d1d7e8e91b64&suite_entity_num=t21700217 + form: llm + + - name: summary + type: string + required: true + label: + en_US: Task Title + zh_Hans: 任务标题 + human_description: + en_US: The title of the task. 
+ zh_Hans: 任务标题 + llm_description: 任务标题 + form: llm + + - name: description + type: string + required: false + label: + en_US: Task Description + zh_Hans: 任务备注 + human_description: + en_US: The description or notes for the task. + zh_Hans: 任务备注 + llm_description: 任务备注 + form: llm + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 任务开始时间 + human_description: + en_US: | + The start time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务开始时间,格式为:2006-01-02 15:04:05 + llm_description: 任务开始时间,格式为:2006-01-02 15:04:05 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 任务结束时间 + human_description: + en_US: | + The end time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务结束时间,格式为:2006-01-02 15:04:05 + llm_description: 任务结束时间,格式为:2006-01-02 15:04:05 + form: llm + + - name: completed_time + type: string + required: false + label: + en_US: Completed Time + zh_Hans: 任务完成时间 + human_description: + en_US: | + The completion time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务完成时间,格式为:2006-01-02 15:04:05 + llm_description: 任务完成时间,格式为:2006-01-02 15:04:05 + form: llm diff --git a/api/core/tools/provider/builtin/feishu_wiki/_assets/icon.png b/api/core/tools/provider/builtin/feishu_wiki/_assets/icon.png new file mode 100644 index 00000000000000..878672c9ae5a51 Binary files /dev/null and b/api/core/tools/provider/builtin/feishu_wiki/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/feishu_wiki/feishu_wiki.py b/api/core/tools/provider/builtin/feishu_wiki/feishu_wiki.py new file mode 100644 index 00000000000000..6c5fccb1a31d0d --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_wiki/feishu_wiki.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.feishu_api_utils import auth + + +class FeishuWikiProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + auth(credentials) diff --git a/api/core/tools/provider/builtin/feishu_wiki/feishu_wiki.yaml b/api/core/tools/provider/builtin/feishu_wiki/feishu_wiki.yaml new file mode 100644 index 00000000000000..1fb5f71cbc5169 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_wiki/feishu_wiki.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: feishu_wiki + label: + en_US: Feishu Wiki + zh_Hans: 飞书知识库 + description: + en_US: | + Feishu Wiki, requires the following permissions: wiki:wiki:readonly. 
+ zh_Hans: | + 飞书知识库,需要开通以下权限: wiki:wiki:readonly。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your feishu app id + zh_Hans: 请输入你的飞书 app id + help: + en_US: Get your app_id and app_secret from Feishu + zh_Hans: 从飞书获取您的 app_id 和 app_secret + url: https://open.larkoffice.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的飞书 app secret diff --git a/api/core/tools/provider/builtin/feishu_wiki/tools/get_wiki_nodes.py b/api/core/tools/provider/builtin/feishu_wiki/tools/get_wiki_nodes.py new file mode 100644 index 00000000000000..374b4c9a7d1492 --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_wiki/tools/get_wiki_nodes.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.feishu_api_utils import FeishuRequest + + +class GetWikiNodesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = FeishuRequest(app_id, app_secret) + + space_id = tool_parameters.get("space_id") + parent_node_token = tool_parameters.get("parent_node_token") + page_token = tool_parameters.get("page_token") + page_size = tool_parameters.get("page_size") + + res = client.get_wiki_nodes(space_id, parent_node_token, page_token, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/feishu_wiki/tools/get_wiki_nodes.yaml b/api/core/tools/provider/builtin/feishu_wiki/tools/get_wiki_nodes.yaml new file mode 100644 index 00000000000000..74d51e7bcbc32a --- /dev/null +++ b/api/core/tools/provider/builtin/feishu_wiki/tools/get_wiki_nodes.yaml @@ -0,0 +1,63 @@ +identity: + name: get_wiki_nodes + author: Doug Lea + label: + en_US: Get Wiki Nodes + zh_Hans: 获取知识空间子节点列表 +description: + human: + en_US: | + Get the list of child nodes in Wiki, make sure the app/bot is a member of the wiki space. See How to add an app as a wiki base administrator (member). https://open.feishu.cn/document/server-docs/docs/wiki-v2/wiki-qa + zh_Hans: | + 获取知识库全部子节点列表,请确保应用/机器人为知识空间成员。参阅如何将应用添加为知识库管理员(成员)。https://open.feishu.cn/document/server-docs/docs/wiki-v2/wiki-qa + llm: A tool for getting all sub-nodes of a knowledge base.(获取知识空间子节点列表) +parameters: + - name: space_id + type: string + required: true + label: + en_US: Space Id + zh_Hans: 知识空间 ID + human_description: + en_US: | + The ID of the knowledge space. Supports space link URL, for example: https://svi136aogf123.feishu.cn/wiki/settings/7166950623940706332 + zh_Hans: 知识空间 ID,支持空间链接 URL,例如:https://svi136aogf123.feishu.cn/wiki/settings/7166950623940706332 + llm_description: 知识空间 ID,支持空间链接 URL,例如:https://svi136aogf123.feishu.cn/wiki/settings/7166950623940706332 + form: llm + + - name: page_size + type: number + required: false + default: 10 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The size of each page, with a maximum value of 50. + zh_Hans: 分页大小,最大值 50。 + llm_description: 分页大小,最大值 50。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. 
Leave empty for the first request to start from the beginning; if the paginated query result has more items, a new page_token will be returned, which can be used to get the next set of results. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm + + - name: parent_node_token + type: string + required: false + label: + en_US: Parent Node Token + zh_Hans: 父节点 token + human_description: + en_US: The token of the parent node. + zh_Hans: 父节点 token + llm_description: 父节点 token + form: llm diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl.py b/api/core/tools/provider/builtin/firecrawl/firecrawl.py index 24dc35759d8e6d..01455d7206f185 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl.py @@ -7,15 +7,8 @@ class FirecrawlProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the ScrapeTool, only scraping title for minimize content - ScrapeTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', - tool_parameters={ - "url": "https://google.com", - "onlyIncludeTags": 'title' - } + ScrapeTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={"url": "https://google.com", "onlyIncludeTags": "title"} ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py index 3b3f78731b3de4..d9fb6f04bcfa75 100644 --- a/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py +++ b/api/core/tools/provider/builtin/firecrawl/firecrawl_appx.py @@ -13,85 +13,83 @@ class FirecrawlApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.firecrawl.dev' + self.base_url = base_url or "https://api.firecrawl.dev" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self, idempotency_key: str | None = None): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} if idempotency_key: - headers['Idempotency-Key'] = idempotency_key + headers["Idempotency-Key"] = idempotency_key return headers def _request( - self, - method: str, - url: str, - data: Mapping[str, Any] | None = None, - headers: Mapping[str, str] | None = None, - retries: int = 3, - backoff_factor: float = 0.3, + self, + method: str, + url: str, + data: Mapping[str, Any] | None = None, + headers: Mapping[str, str] | None = None, + retries: int = 3, + backoff_factor: float = 0.3, ) -> Mapping[str, Any] | None: if not headers: headers = self._prepare_headers() for i in range(retries): try: response = requests.request(method, url, json=data, headers=headers) - response.raise_for_status() return response.json() - except requests.exceptions.RequestException as e: + except requests.exceptions.RequestException: if i < retries - 1: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None def scrape_url(self, url: str, **kwargs): - endpoint = f'{self.base_url}/v0/scrape' - data = {'url': url, **kwargs} + 
endpoint = f"{self.base_url}/v1/scrape" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: raise HTTPError("Failed to scrape URL after multiple retries") return response - def search(self, query: str, **kwargs): - endpoint = f'{self.base_url}/v0/search' - data = {'query': query, **kwargs} + def map(self, url: str, **kwargs): + endpoint = f"{self.base_url}/v1/map" + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data) + response = self._request("POST", endpoint, data) if response is None: - raise HTTPError("Failed to perform search after multiple retries") + raise HTTPError("Failed to perform map after multiple retries") return response def crawl_url( - self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs + self, url: str, wait: bool = True, poll_interval: int = 5, idempotency_key: str | None = None, **kwargs ): - endpoint = f'{self.base_url}/v0/crawl' + endpoint = f"{self.base_url}/v1/crawl" headers = self._prepare_headers(idempotency_key) - data = {'url': url, **kwargs} + data = {"url": url, **kwargs} logger.debug(f"Sent request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate crawl after multiple retries") - job_id: str = response['jobId'] + elif response.get("success") == False: + raise HTTPError(f'Failed to crawl: {response.get("error")}') + job_id: str = response["id"] if wait: return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) return response def check_crawl_status(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/status/{job_id}' - response = self._request('GET', endpoint) + endpoint = f"{self.base_url}/v1/crawl/{job_id}" + response = self._request("GET", endpoint) if response is None: raise HTTPError(f"Failed to check status for job {job_id} after multiple retries") return response def cancel_crawl_job(self, job_id: str): - endpoint = f'{self.base_url}/v0/crawl/cancel/{job_id}' - response = self._request('DELETE', endpoint) + endpoint = f"{self.base_url}/v1/crawl/{job_id}" + response = self._request("DELETE", endpoint) if response is None: raise HTTPError(f"Failed to cancel job {job_id} after multiple retries") return response @@ -99,9 +97,9 @@ def cancel_crawl_job(self, job_id: str): def _monitor_job_status(self, job_id: str, poll_interval: int): while True: status = self.check_crawl_status(job_id) - if status['status'] == 'completed': + if status["status"] == "completed": return status - elif status['status'] == 'failed': + elif status["status"] == "failed": raise HTTPError(f'Job {job_id} failed: {status["error"]}') time.sleep(poll_interval) @@ -109,7 +107,7 @@ def _monitor_job_status(self, job_id: str, poll_interval: int): def get_array_params(tool_parameters: dict[str, Any], key): param = tool_parameters.get(key) if param: - return param.split(',') + return param.split(",") def get_json_params(tool_parameters: dict[str, Any], key): @@ -119,6 +117,6 @@ def get_json_params(tool_parameters: dict[str, Any], key): # support both single quotes and double quotes param = param.replace("'", '"') param = json.loads(param) - except: + except Exception: raise ValueError(f"Invalid {key} format.") return param 
diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py index 08c40a4064c511..15ab510c6c889c 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.py @@ -8,41 +8,38 @@ class CrawlTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: """ - the crawlerOptions and pageOptions comes from doc here: + the api doc: https://docs.firecrawl.dev/api-reference/endpoint/crawl """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - crawlerOptions = {} - pageOptions = {} - - wait_for_results = tool_parameters.get('wait_for_results', True) - - crawlerOptions['excludes'] = get_array_params(tool_parameters, 'excludes') - crawlerOptions['includes'] = get_array_params(tool_parameters, 'includes') - crawlerOptions['returnOnlyUrls'] = tool_parameters.get('returnOnlyUrls', False) - crawlerOptions['maxDepth'] = tool_parameters.get('maxDepth') - crawlerOptions['mode'] = tool_parameters.get('mode') - crawlerOptions['ignoreSitemap'] = tool_parameters.get('ignoreSitemap', False) - crawlerOptions['limit'] = tool_parameters.get('limit', 5) - crawlerOptions['allowBackwardCrawling'] = tool_parameters.get('allowBackwardCrawling', False) - crawlerOptions['allowExternalContentLinks'] = tool_parameters.get('allowExternalContentLinks', False) - - pageOptions['headers'] = get_json_params(tool_parameters, 'headers') - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') - pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) - pageOptions['screenshot'] = tool_parameters.get('screenshot', False) - pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) - - crawl_result = app.crawl_url( - url=tool_parameters['url'], - wait=wait_for_results, - crawlerOptions=crawlerOptions, - pageOptions=pageOptions + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] ) + scrapeOptions = {} + payload = {} + + wait_for_results = tool_parameters.get("wait_for_results", True) + + payload["excludePaths"] = get_array_params(tool_parameters, "excludePaths") + payload["includePaths"] = get_array_params(tool_parameters, "includePaths") + payload["maxDepth"] = tool_parameters.get("maxDepth") + payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", False) + payload["limit"] = tool_parameters.get("limit", 5) + payload["allowBackwardLinks"] = tool_parameters.get("allowBackwardLinks", False) + payload["allowExternalLinks"] = tool_parameters.get("allowExternalLinks", False) + payload["webhook"] = tool_parameters.get("webhook") + + scrapeOptions["formats"] = get_array_params(tool_parameters, "formats") + scrapeOptions["headers"] = get_json_params(tool_parameters, "headers") + scrapeOptions["includeTags"] = get_array_params(tool_parameters, "includeTags") + scrapeOptions["excludeTags"] = get_array_params(tool_parameters, "excludeTags") + scrapeOptions["onlyMainContent"] = tool_parameters.get("onlyMainContent", 
False) + scrapeOptions["waitFor"] = tool_parameters.get("waitFor", 0) + scrapeOptions = {k: v for k, v in scrapeOptions.items() if v not in (None, "")} + payload["scrapeOptions"] = scrapeOptions or None + + payload = {k: v for k, v in payload.items() if v not in (None, "")} + + crawl_result = app.crawl_url(url=tool_parameters["url"], wait=wait_for_results, **payload) + return self.create_json_message(crawl_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml index 0c5399f973c970..0d7dbcac20ea16 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl.yaml @@ -31,44 +31,33 @@ parameters: en_US: If you choose not to wait, it will directly return a job ID. You can use this job ID to check the crawling results or cancel the crawling task, which is usually very useful for a large-scale crawling task. zh_Hans: 如果选择不等待,则会直接返回一个job_id,可以通过job_id查询爬取结果或取消爬取任务,这通常对于一个大型爬取任务来说非常有用。 form: form -############## Crawl Options ####################### - - name: includes +############## Payload ####################### + - name: excludePaths type: string - required: false label: - en_US: URL patterns to include - zh_Hans: 要包含的URL模式 + en_US: URL patterns to exclude + zh_Hans: 要排除的URL模式 placeholder: en_US: Use commas to separate multiple tags zh_Hans: 多个标签时使用半角逗号分隔 human_description: en_US: | - Only pages matching these patterns will be crawled. Example: blog/*, about/* - zh_Hans: 只有与这些模式匹配的页面才会被爬取。示例:blog/*, about/* + Pages matching these patterns will be skipped. Example: blog/*, about/* + zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* form: form - - name: excludes + - name: includePaths type: string + required: false label: - en_US: URL patterns to exclude - zh_Hans: 要排除的URL模式 + en_US: URL patterns to include + zh_Hans: 要包含的URL模式 placeholder: en_US: Use commas to separate multiple tags zh_Hans: 多个标签时使用半角逗号分隔 human_description: en_US: | - Pages matching these patterns will be skipped. Example: blog/*, about/* - zh_Hans: 匹配这些模式的页面将被跳过。示例:blog/*, about/* - form: form - - name: returnOnlyUrls - type: boolean - default: false - label: - en_US: return Only Urls - zh_Hans: 仅返回URL - human_description: - en_US: | - If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents. - zh_Hans: 只返回爬取到的网页链接,而不是网页内容本身。 + Only pages matching these patterns will be crawled. Example: blog/*, about/* + zh_Hans: 只有与这些模式匹配的页面才会被爬取。示例:blog/*, about/* form: form - name: maxDepth type: number @@ -80,27 +69,10 @@ parameters: zh_Hans: 相对于输入的URL,爬取的最大深度。maxDepth为0时,仅抓取输入的URL。maxDepth为1时,抓取输入的URL以及所有一级深层页面。maxDepth为2时,抓取输入的URL以及所有两级深层页面。更高值遵循相同模式。 form: form min: 0 - - name: mode - type: select - required: false - form: form - options: - - value: default - label: - en_US: default - - value: fast - label: - en_US: fast - default: default - label: - en_US: Crawl Mode - zh_Hans: 爬取模式 - human_description: - en_US: The crawling mode to use. Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites. 
- zh_Hans: 使用fast模式将不会使用其站点地图,比普通模式快4倍,但是可能不够准确,也不适用于大量js渲染的网站。 + default: 2 - name: ignoreSitemap type: boolean - default: false + default: true label: en_US: ignore Sitemap zh_Hans: 忽略站点地图 @@ -120,7 +92,7 @@ parameters: form: form min: 1 default: 5 - - name: allowBackwardCrawling + - name: allowBackwardLinks type: boolean default: false label: @@ -130,7 +102,7 @@ parameters: en_US: Enables the crawler to navigate from a specific URL to previously linked pages. For instance, from 'example.com/product/123' back to 'example.com/product' zh_Hans: 使爬虫能够从特定URL导航到之前链接的页面。例如,从'example.com/product/123'返回到'example.com/product' form: form - - name: allowExternalContentLinks + - name: allowExternalLinks type: boolean default: false label: @@ -140,7 +112,30 @@ parameters: en_US: Allows the crawler to follow links to external websites. zh_Hans: form: form -############## Page Options ####################### + - name: webhook + type: string + label: + en_US: webhook + human_description: + en_US: | + The URL to send the webhook to. This will trigger for crawl started (crawl.started) ,every page crawled (crawl.page) and when the crawl is completed (crawl.completed or crawl.failed). The response will be the same as the /scrape endpoint. + zh_Hans: 发送Webhook的URL。这将在开始爬取(crawl.started)、每爬取一个页面(crawl.page)以及爬取完成(crawl.completed或crawl.failed)时触发。响应将与/scrape端点相同。 + form: form +############## Scrape Options ####################### + - name: formats + type: string + label: + en_US: Formats + zh_Hans: 结果的格式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 + human_description: + en_US: | + Formats to include in the output. Available options: markdown, html, rawHtml, links, screenshot + zh_Hans: | + 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot + form: form - name: headers type: string label: @@ -155,30 +150,10 @@ parameters: en_US: Please enter an object that can be serialized in JSON zh_Hans: 请输入可以json序列化的对象 form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form - - name: onlyIncludeTags + - name: includeTags type: string label: - en_US: only Include Tags + en_US: Include Tags zh_Hans: 仅抓取这些标签 placeholder: en_US: Use commas to separate multiple tags @@ -189,20 +164,10 @@ parameters: zh_Hans: | 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer form: form - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. 
- zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: removeTags + - name: excludeTags type: string label: - en_US: remove Tags + en_US: Exclude Tags zh_Hans: 要移除这些标签 human_description: en_US: | @@ -213,25 +178,15 @@ parameters: en_US: Use commas to separate multiple tags zh_Hans: 多个标签时使用半角逗号分隔 form: form - - name: replaceAllPathsWithAbsolutePaths - type: boolean - default: false - label: - en_US: All AbsolutePaths - zh_Hans: 使用绝对路径 - human_description: - en_US: Replace all relative paths with absolute paths for images and links. - zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 - form: form - - name: screenshot + - name: onlyMainContent type: boolean default: false label: - en_US: screenshot - zh_Hans: 截图 + en_US: only Main Content + zh_Hans: 仅抓取主要内容 human_description: - en_US: Include a screenshot of the top of the page that you are scraping. - zh_Hans: 提供正在抓取的页面的顶部的截图。 + en_US: Only return the main content of the page excluding headers, navs, footers, etc. + zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 form: form - name: waitFor type: number diff --git a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py index fa6c1f87ee2c42..0d2486c7ca4426 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/crawl_job.py @@ -7,14 +7,15 @@ class CrawlJobTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - operation = tool_parameters.get('operation', 'get') - if operation == 'get': - result = app.check_crawl_status(job_id=tool_parameters['job_id']) - elif operation == 'cancel': - result = app.cancel_crawl_job(job_id=tool_parameters['job_id']) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + operation = tool_parameters.get("operation", "get") + if operation == "get": + result = app.check_crawl_status(job_id=tool_parameters["job_id"]) + elif operation == "cancel": + result = app.cancel_crawl_job(job_id=tool_parameters["job_id"]) else: - raise ValueError(f'Invalid operation: {operation}') + raise ValueError(f"Invalid operation: {operation}") return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/map.py b/api/core/tools/provider/builtin/firecrawl/tools/map.py new file mode 100644 index 00000000000000..bdfb5faeb8e2c9 --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/map.py @@ -0,0 +1,25 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp +from core.tools.tool.builtin_tool import BuiltinTool + + +class MapTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + """ + the api doc: + https://docs.firecrawl.dev/api-reference/endpoint/map + """ + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + payload = {} + payload["search"] = tool_parameters.get("search") + payload["ignoreSitemap"] = tool_parameters.get("ignoreSitemap", True) + payload["includeSubdomains"] = tool_parameters.get("includeSubdomains", False) + payload["limit"] = tool_parameters.get("limit", 5000) + + map_result = app.map(url=tool_parameters["url"], 
**payload) + + return self.create_json_message(map_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/map.yaml b/api/core/tools/provider/builtin/firecrawl/tools/map.yaml new file mode 100644 index 00000000000000..9913756983370a --- /dev/null +++ b/api/core/tools/provider/builtin/firecrawl/tools/map.yaml @@ -0,0 +1,59 @@ +identity: + name: map + author: hjlarry + label: + en_US: Map + zh_Hans: 地图式快爬 +description: + human: + en_US: Input a website and get all the urls on the website - extremely fast + zh_Hans: 输入一个网站,快速获取网站上的所有网址。 + llm: Input a website and get all the urls on the website - extremely fast +parameters: + - name: url + type: string + required: true + label: + en_US: Start URL + zh_Hans: 起始URL + human_description: + en_US: The base URL to start crawling from. + zh_Hans: 要爬取网站的起始URL。 + llm_description: The URL of the website that needs to be crawled. This is a required parameter. + form: llm + - name: search + type: string + label: + en_US: search + zh_Hans: 搜索查询 + human_description: + en_US: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied. + zh_Hans: 用于映射的搜索查询。在Alpha阶段,搜索功能的“智能”部分限制为最多100个搜索结果。然而,如果地图找到了更多结果,则不施加任何限制。 + llm_description: Search query to use for mapping. During the Alpha phase, the 'smart' part of the search functionality is limited to 100 search results. However, if map finds more results, there is no limit applied. + form: llm +############## Page Options ####################### + - name: ignoreSitemap + type: boolean + default: true + label: + en_US: ignore Sitemap + zh_Hans: 忽略站点地图 + human_description: + en_US: Ignore the website sitemap when crawling. + zh_Hans: 爬取时忽略网站站点地图。 + form: form + - name: includeSubdomains + type: boolean + default: false + label: + en_US: include Subdomains + zh_Hans: 包含子域名 + form: form + - name: limit + type: number + min: 0 + default: 5000 + label: + en_US: Maximum results + zh_Hans: 最大结果数量 + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py index 91412da548a0b6..f00a9b31ce8c2c 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.py +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.py @@ -6,34 +6,34 @@ class ScrapeTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: """ - the pageOptions and extractorOptions comes from doc here: + the api doc: https://docs.firecrawl.dev/api-reference/endpoint/scrape """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - - pageOptions = {} - extractorOptions = {} - - pageOptions['headers'] = get_json_params(tool_parameters, 'headers') - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - pageOptions['onlyIncludeTags'] = get_array_params(tool_parameters, 'onlyIncludeTags') - pageOptions['removeTags'] = get_array_params(tool_parameters, 'removeTags') - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['replaceAllPathsWithAbsolutePaths'] = tool_parameters.get('replaceAllPathsWithAbsolutePaths', False) - pageOptions['screenshot'] =
tool_parameters.get('screenshot', False) - pageOptions['waitFor'] = tool_parameters.get('waitFor', 0) - - extractorOptions['mode'] = tool_parameters.get('mode', '') - extractorOptions['extractionPrompt'] = tool_parameters.get('extractionPrompt', '') - extractorOptions['extractionSchema'] = get_json_params(tool_parameters, 'extractionSchema') - - crawl_result = app.scrape_url(url=tool_parameters['url'], - pageOptions=pageOptions, - extractorOptions=extractorOptions) - - return self.create_json_message(crawl_result) + app = FirecrawlApp( + api_key=self.runtime.credentials["firecrawl_api_key"], base_url=self.runtime.credentials["base_url"] + ) + + payload = {} + extract = {} + + payload["formats"] = get_array_params(tool_parameters, "formats") + payload["onlyMainContent"] = tool_parameters.get("onlyMainContent", True) + payload["includeTags"] = get_array_params(tool_parameters, "includeTags") + payload["excludeTags"] = get_array_params(tool_parameters, "excludeTags") + payload["headers"] = get_json_params(tool_parameters, "headers") + payload["waitFor"] = tool_parameters.get("waitFor", 0) + payload["timeout"] = tool_parameters.get("timeout", 30000) + + extract["schema"] = get_json_params(tool_parameters, "schema") + extract["systemPrompt"] = tool_parameters.get("systemPrompt") + extract["prompt"] = tool_parameters.get("prompt") + extract = {k: v for k, v in extract.items() if v not in (None, "")} + payload["extract"] = extract or None + + payload = {k: v for k, v in payload.items() if v not in (None, "")} + + crawl_result = app.scrape_url(url=tool_parameters["url"], **payload) + markdown_result = crawl_result.get("data", {}).get("markdown", "") + return [self.create_text_message(markdown_result), self.create_json_message(crawl_result)] diff --git a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml index 598429de5e027e..8f1f1348a459ca 100644 --- a/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml +++ b/api/core/tools/provider/builtin/firecrawl/tools/scrape.yaml @@ -6,8 +6,8 @@ identity: zh_Hans: 单页面抓取 description: human: - en_US: Extract data from a single URL. - zh_Hans: 从单个URL抓取数据。 + en_US: Turn any url into clean data. + zh_Hans: 将任何网址转换为干净的数据。 llm: This tool is designed to scrape URL and output the content in Markdown format. parameters: - name: url @@ -21,45 +21,35 @@ parameters: zh_Hans: 要抓取并提取数据的网站URL。 llm_description: The URL of the website that needs to be crawled. This is a required parameter. form: llm -############## Page Options ####################### - - name: headers +############## Payload ####################### + - name: formats type: string label: - en_US: headers - zh_Hans: 请求头 + en_US: Formats + zh_Hans: 结果的格式 + placeholder: + en_US: Use commas to separate multiple tags + zh_Hans: 多个标签时使用半角逗号分隔 human_description: en_US: | - Headers to send with the request. Can be used to send cookies, user-agent, etc. Example: {"cookies": "testcookies"} + Formats to include in the output. 
Available options: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage zh_Hans: | - 随请求发送的头部。可以用来发送cookies、用户代理等。示例:{"cookies": "testcookies"} - placeholder: - en_US: Please enter an object that can be serialized in JSON - zh_Hans: 请输入可以json序列化的对象 + 输出中应包含的格式。可以填入: markdown, html, rawHtml, links, screenshot, extract, screenshot@fullPage form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml + - name: onlyMainContent type: boolean default: false label: - en_US: include Raw Html - zh_Hans: 包含原始HTML + en_US: only Main Content + zh_Hans: 仅抓取主要内容 human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. - zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 + en_US: Only return the main content of the page excluding headers, navs, footers, etc. + zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 form: form - - name: onlyIncludeTags + - name: includeTags type: string label: - en_US: only Include Tags + en_US: Include Tags zh_Hans: 仅抓取这些标签 placeholder: en_US: Use commas to separate multiple tags @@ -70,20 +60,10 @@ parameters: zh_Hans: | 仅在最终输出中包含HTML页面的这些标签,可以通过标签名、类或ID来设定,使用逗号分隔值。示例:script, .ad, #footer form: form - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. - zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: removeTags + - name: excludeTags type: string label: - en_US: remove Tags + en_US: Exclude Tags zh_Hans: 要移除这些标签 human_description: en_US: | @@ -94,29 +74,24 @@ parameters: en_US: Use commas to separate multiple tags zh_Hans: 多个标签时使用半角逗号分隔 form: form - - name: replaceAllPathsWithAbsolutePaths - type: boolean - default: false - label: - en_US: All AbsolutePaths - zh_Hans: 使用绝对路径 - human_description: - en_US: Replace all relative paths with absolute paths for images and links. - zh_Hans: 将所有图片和链接的相对路径替换为绝对路径。 - form: form - - name: screenshot - type: boolean - default: false + - name: headers + type: string label: - en_US: screenshot - zh_Hans: 截图 + en_US: headers + zh_Hans: 请求头 human_description: - en_US: Include a screenshot of the top of the page that you are scraping. - zh_Hans: 提供正在抓取的页面的顶部的截图。 + en_US: | + Headers to send with the request. Can be used to send cookies, user-agent, etc. Example: {"cookies": "testcookies"} + zh_Hans: | + 随请求发送的头部。可以用来发送cookies、用户代理等。示例:{"cookies": "testcookies"} + placeholder: + en_US: Please enter an object that can be serialized in JSON + zh_Hans: 请输入可以json序列化的对象 form: form - name: waitFor type: number min: 0 + default: 0 label: en_US: wait For zh_Hans: 等待时间 @@ -124,57 +99,54 @@ parameters: en_US: Wait x amount of milliseconds for the page to load to fetch content. 
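One detail of the rewritten payload construction above is easy to miss: values are pruned by comparing against (None, "") rather than by truthiness, so meaningful falsy inputs survive, and an empty extract object collapses to None and disappears entirely. A minimal standalone sketch of the idiom, with hypothetical values; only the two dict comprehensions mirror the diff:

```python
# Sketch of the pruning idiom in the rewritten ScrapeTool. Testing against
# (None, "") keeps meaningful falsy values such as waitFor=0, while an
# empty extract dict collapses to None and is dropped by the outer pass.
extract = {"schema": None, "systemPrompt": "", "prompt": "List the authors"}
extract = {k: v for k, v in extract.items() if v not in (None, "")}

payload = {
    "formats": ["markdown"],  # hypothetical user input
    "includeTags": None,      # parameter left unset
    "waitFor": 0,             # falsy but meaningful, so it survives pruning
    "extract": extract or None,
}
payload = {k: v for k, v in payload.items() if v not in (None, "")}

assert payload == {
    "formats": ["markdown"],
    "waitFor": 0,
    "extract": {"prompt": "List the authors"},
}
```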
zh_Hans: 等待x毫秒以使页面加载并获取内容。 form: form -############## Extractor Options ####################### - - name: mode - type: select - options: - - value: markdown - label: - en_US: markdown - - value: llm-extraction - label: - en_US: llm-extraction - - value: llm-extraction-from-raw-html - label: - en_US: llm-extraction-from-raw-html - - value: llm-extraction-from-markdown - label: - en_US: llm-extraction-from-markdown - label: - en_US: Extractor Mode - zh_Hans: 提取模式 - human_description: - en_US: | - The extraction mode to use. 'markdown': Returns the scraped markdown content, does not perform LLM extraction. 'llm-extraction': Extracts information from the cleaned and parsed content using LLM. - zh_Hans: 使用的提取模式。“markdown”:返回抓取的markdown内容,不执行LLM提取。“llm-extractioin”:使用LLM按Extractor Schema从内容中提取信息。 - form: form - - name: extractionPrompt - type: string + - name: timeout + type: number + min: 0 + default: 30000 label: - en_US: Extractor Prompt - zh_Hans: 提取时的提示词 + en_US: Timeout human_description: - en_US: A prompt describing what information to extract from the page, applicable for LLM extraction modes. - zh_Hans: 当使用LLM提取模式时,用于给LLM描述提取规则。 + en_US: Timeout in milliseconds for the request. + zh_Hans: 请求的超时时间(以毫秒为单位)。 form: form - - name: extractionSchema +############## Extractor Options ####################### + - name: schema type: string label: en_US: Extractor Schema zh_Hans: 提取时的结构 placeholder: en_US: Please enter an object that can be serialized in JSON + zh_Hans: 请输入可以json序列化的对象 human_description: en_US: | - The schema for the data to be extracted, required only for LLM extraction modes. Example: { + The schema for the data to be extracted. Example: { "type": "object", "properties": {"company_mission": {"type": "string"}}, "required": ["company_mission"] } zh_Hans: | - 当使用LLM提取模式时,使用该结构去提取,示例:{ + 使用该结构去提取,示例:{ "type": "object", "properties": {"company_mission": {"type": "string"}}, "required": ["company_mission"] } form: form + - name: systemPrompt + type: string + label: + en_US: Extractor System Prompt + zh_Hans: 提取时的系统提示词 + human_description: + en_US: The system prompt to use for the extraction. + zh_Hans: 用于提取的系统提示。 + form: form + - name: prompt + type: string + label: + en_US: Extractor Prompt + zh_Hans: 提取时的提示词 + human_description: + en_US: The prompt to use for the extraction without a schema. 
+ zh_Hans: 用于无schema时提取的提示词 + form: form diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.py b/api/core/tools/provider/builtin/firecrawl/tools/search.py deleted file mode 100644 index e2b2ac6b4dddb6..00000000000000 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.provider.builtin.firecrawl.firecrawl_appx import FirecrawlApp -from core.tools.tool.builtin_tool import BuiltinTool - - -class SearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: - """ - the pageOptions and searchOptions comes from doc here: - https://docs.firecrawl.dev/api-reference/endpoint/search - """ - app = FirecrawlApp(api_key=self.runtime.credentials['firecrawl_api_key'], - base_url=self.runtime.credentials['base_url']) - pageOptions = {} - pageOptions['onlyMainContent'] = tool_parameters.get('onlyMainContent', False) - pageOptions['fetchPageContent'] = tool_parameters.get('fetchPageContent', True) - pageOptions['includeHtml'] = tool_parameters.get('includeHtml', False) - pageOptions['includeRawHtml'] = tool_parameters.get('includeRawHtml', False) - searchOptions = {'limit': tool_parameters.get('limit')} - search_result = app.search( - query=tool_parameters['keyword'], - pageOptions=pageOptions, - searchOptions=searchOptions - ) - - return self.create_json_message(search_result) diff --git a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml b/api/core/tools/provider/builtin/firecrawl/tools/search.yaml deleted file mode 100644 index 29df0cfaaaf412..00000000000000 --- a/api/core/tools/provider/builtin/firecrawl/tools/search.yaml +++ /dev/null @@ -1,75 +0,0 @@ -identity: - name: search - author: ahasasjeb - label: - en_US: Search - zh_Hans: 搜索 -description: - human: - en_US: Search, and output in Markdown format - zh_Hans: 搜索,并且以Markdown格式输出 - llm: This tool can perform online searches and convert the results to Markdown format. -parameters: - - name: keyword - type: string - required: true - label: - en_US: keyword - zh_Hans: 关键词 - human_description: - en_US: Input keywords to use Firecrawl API for search. - zh_Hans: 输入关键词即可使用Firecrawl API进行搜索。 - llm_description: Efficiently extract keywords from user text. - form: llm -############## Page Options ####################### - - name: onlyMainContent - type: boolean - default: false - label: - en_US: only Main Content - zh_Hans: 仅抓取主要内容 - human_description: - en_US: Only return the main content of the page excluding headers, navs, footers, etc. - zh_Hans: 只返回页面的主要内容,不包括头部、导航栏、尾部等。 - form: form - - name: fetchPageContent - type: boolean - default: true - label: - en_US: fetch Page Content - zh_Hans: 抓取页面内容 - human_description: - en_US: Fetch the content of each page. If false, defaults to a basic fast serp API. - zh_Hans: 获取每个页面的内容。如果为否,则使用基本的快速搜索结果页面API。 - form: form - - name: includeHtml - type: boolean - default: false - label: - en_US: include Html - zh_Hans: 包含HTML - human_description: - en_US: Include the HTML version of the content on page. Will output a html key in the response. - zh_Hans: 返回中包含一个HTML版本的内容,将以html键返回。 - form: form - - name: includeRawHtml - type: boolean - default: false - label: - en_US: include Raw Html - zh_Hans: 包含原始HTML - human_description: - en_US: Include the raw HTML content of the page. Will output a rawHtml key in the response. 
- zh_Hans: 返回中包含一个原始HTML版本的内容,将以rawHtml键返回。 - form: form -############## Search Options ####################### - - name: limit - type: number - min: 0 - label: - en_US: Maximum results - zh_Hans: 最大结果数量 - human_description: - en_US: Maximum number of results. Max is 20 during beta. - zh_Hans: 最大结果数量。在测试阶段,最大为20。 - form: form diff --git a/api/core/tools/provider/builtin/gaode/gaode.py b/api/core/tools/provider/builtin/gaode/gaode.py index b55d93e07b0d2b..49a8e537fb9070 100644 --- a/api/core/tools/provider/builtin/gaode/gaode.py +++ b/api/core/tools/provider/builtin/gaode/gaode.py @@ -9,17 +9,19 @@ class GaodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'api_key' not in credentials or not credentials.get('api_key'): + if "api_key" not in credentials or not credentials.get("api_key"): raise ToolProviderCredentialValidationError("Gaode API key is required.") try: - response = requests.get(url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}" - "".format(address=urllib.parse.quote('广东省广州市天河区广州塔'), - apikey=credentials.get('api_key'))) - if response.status_code == 200 and (response.json()).get('info') == 'OK': + response = requests.get( + url="https://restapi.amap.com/v3/geocode/geo?address={address}&key={apikey}".format( + address=urllib.parse.quote("广东省广州市天河区广州塔"), apikey=credentials.get("api_key") + ) + ) + if response.status_code == 200 and (response.json()).get("info") == "OK": pass else: - raise ToolProviderCredentialValidationError((response.json()).get('info')) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: raise ToolProviderCredentialValidationError("Gaode API Key is invalid. {}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py index efd11cedce4238..ea06e2ce611cbc 100644 --- a/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py +++ b/api/core/tools/provider/builtin/gaode/tools/gaode_weather.py @@ -8,50 +8,57 @@ class GaodeRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - city = tool_parameters.get('city', '') + city = tool_parameters.get("city", "") if not city: - return self.create_text_message('Please tell me your city') + return self.create_text_message("Please tell me your city") - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("Gaode API key is required.") try: s = requests.session() - api_domain = 'https://restapi.amap.com/v3' - city_response = s.request(method='GET', headers={"Content-Type": "application/json; charset=utf-8"}, - url="{url}/config/district?keywords={keywords}" - "&subdistrict=0&extensions=base&key={apikey}" - "".format(url=api_domain, keywords=city, - apikey=self.runtime.credentials.get('api_key'))) + api_domain = "https://restapi.amap.com/v3" + city_response = s.request( + method="GET", + headers={"Content-Type": "application/json; charset=utf-8"}, + url="{url}/config/district?keywords={keywords}&subdistrict=0&extensions=base&key={apikey}".format( + 
url=api_domain, keywords=city, apikey=self.runtime.credentials.get("api_key") + ), + ) City_data = city_response.json() - if city_response.status_code == 200 and City_data.get('info') == 'OK': - if len(City_data.get('districts')) > 0: - CityCode = City_data['districts'][0]['adcode'] - weatherInfo_response = s.request(method='GET', - url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" - "".format(url=api_domain, citycode=CityCode, - apikey=self.runtime.credentials.get('api_key'))) + if city_response.status_code == 200 and City_data.get("info") == "OK": + if len(City_data.get("districts")) > 0: + CityCode = City_data["districts"][0]["adcode"] + weatherInfo_response = s.request( + method="GET", + url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" + "".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")), + ) weatherInfo_data = weatherInfo_response.json() - if weatherInfo_response.status_code == 200 and weatherInfo_data.get('info') == 'OK': + if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK": contents = [] - if len(weatherInfo_data.get('forecasts')) > 0: - for item in weatherInfo_data['forecasts'][0]['casts']: + if len(weatherInfo_data.get("forecasts")) > 0: + for item in weatherInfo_data["forecasts"][0]["casts"]: content = {} - content['date'] = item.get('date') - content['week'] = item.get('week') - content['dayweather'] = item.get('dayweather') - content['daytemp_float'] = item.get('daytemp_float') - content['daywind'] = item.get('daywind') - content['nightweather'] = item.get('nightweather') - content['nighttemp_float'] = item.get('nighttemp_float') + content["date"] = item.get("date") + content["week"] = item.get("week") + content["dayweather"] = item.get("dayweather") + content["daytemp_float"] = item.get("daytemp_float") + content["daywind"] = item.get("daywind") + content["nightweather"] = item.get("nightweather") + content["nighttemp_float"] = item.get("nighttemp_float") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) s.close() - return self.create_text_message(f'No weather information for {city} was found.') + return self.create_text_message(f"No weather information for {city} was found.") except Exception as e: return self.create_text_message("Gaode API Key and Api Version is invalid. 
{}".format(e)) diff --git a/api/core/tools/provider/builtin/getimgai/getimgai.py b/api/core/tools/provider/builtin/getimgai/getimgai.py index c81d5fa333cd5d..bbd07d120fd0ea 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai.py @@ -7,16 +7,13 @@ class GetImgAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: # Example validation using the text2image tool - Text2ImageTool().fork_tool_runtime( - runtime={"credentials": credentials} - ).invoke( - user_id='', + Text2ImageTool().fork_tool_runtime(runtime={"credentials": credentials}).invoke( + user_id="", tool_parameters={ "prompt": "A fire egg", "response_format": "url", "style": "photorealism", - } + }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py index e28c57649cac4c..0e95a5f654505f 100644 --- a/api/core/tools/provider/builtin/getimgai/getimgai_appx.py +++ b/api/core/tools/provider/builtin/getimgai/getimgai_appx.py @@ -8,18 +8,16 @@ logger = logging.getLogger(__name__) + class GetImgAIApp: def __init__(self, api_key: str | None = None, base_url: str | None = None): self.api_key = api_key - self.base_url = base_url or 'https://api.getimg.ai/v1' + self.base_url = base_url or "https://api.getimg.ai/v1" if not self.api_key: raise ValueError("API key is required") def _prepare_headers(self): - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return headers def _request( @@ -38,22 +36,20 @@ def _request( return response.json() except requests.exceptions.RequestException as e: if i < retries - 1 and isinstance(e, HTTPError) and e.response.status_code >= 500: - time.sleep(backoff_factor * (2 ** i)) + time.sleep(backoff_factor * (2**i)) else: raise return None - def text2image( - self, mode: str, **kwargs - ): - data = kwargs['params'] - if not data.get('prompt'): + def text2image(self, mode: str, **kwargs): + data = kwargs["params"] + if not data.get("prompt"): raise ValueError("Prompt is required") - endpoint = f'{self.base_url}/{mode}/text-to-image' + endpoint = f"{self.base_url}/{mode}/text-to-image" headers = self._prepare_headers() logger.debug(f"Send request to {endpoint=} body={data}") - response = self._request('POST', endpoint, data, headers) + response = self._request("POST", endpoint, data, headers) if response is None: raise HTTPError("Failed to initiate getimg.ai after multiple retries") return response diff --git a/api/core/tools/provider/builtin/getimgai/tools/text2image.py b/api/core/tools/provider/builtin/getimgai/tools/text2image.py index dad7314479a89d..c556749552c8ef 100644 --- a/api/core/tools/provider/builtin/getimgai/tools/text2image.py +++ b/api/core/tools/provider/builtin/getimgai/tools/text2image.py @@ -7,28 +7,28 @@ class Text2ImageTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - app = GetImgAIApp(api_key=self.runtime.credentials['getimg_api_key'], base_url=self.runtime.credentials['base_url']) + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + app = GetImgAIApp( + 
api_key=self.runtime.credentials["getimg_api_key"], base_url=self.runtime.credentials["base_url"] + ) options = { - 'style': tool_parameters.get('style'), - 'prompt': tool_parameters.get('prompt'), - 'aspect_ratio': tool_parameters.get('aspect_ratio'), - 'output_format': tool_parameters.get('output_format', 'jpeg'), - 'response_format': tool_parameters.get('response_format', 'url'), - 'width': tool_parameters.get('width'), - 'height': tool_parameters.get('height'), - 'steps': tool_parameters.get('steps'), - 'negative_prompt': tool_parameters.get('negative_prompt'), - 'prompt_2': tool_parameters.get('prompt_2'), + "style": tool_parameters.get("style"), + "prompt": tool_parameters.get("prompt"), + "aspect_ratio": tool_parameters.get("aspect_ratio"), + "output_format": tool_parameters.get("output_format", "jpeg"), + "response_format": tool_parameters.get("response_format", "url"), + "width": tool_parameters.get("width"), + "height": tool_parameters.get("height"), + "steps": tool_parameters.get("steps"), + "negative_prompt": tool_parameters.get("negative_prompt"), + "prompt_2": tool_parameters.get("prompt_2"), } options = {k: v for k, v in options.items() if v} - text2image_result = app.text2image( - mode=tool_parameters.get('mode', 'essential-v2'), - params=options, - wait=True - ) + text2image_result = app.text2image(mode=tool_parameters.get("mode", "essential-v2"), params=options, wait=True) if not isinstance(text2image_result, str): text2image_result = json.dumps(text2image_result, ensure_ascii=False, indent=4) diff --git a/api/core/tools/provider/builtin/gitee_ai/_assets/icon.svg b/api/core/tools/provider/builtin/gitee_ai/_assets/icon.svg new file mode 100644 index 00000000000000..6dd75d1a6b5b44 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/_assets/icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/tools/provider/builtin/gitee_ai/gitee_ai.py b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.py new file mode 100644 index 00000000000000..151cafec14b2b7 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.py @@ -0,0 +1,17 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class GiteeAIProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + url = "https://ai.gitee.com/api/base/account/me" + headers = { + "accept": "application/json", + "authorization": f"Bearer {credentials.get('api_key')}", + } + + response = requests.get(url, headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError("GiteeAI API key is invalid") diff --git a/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml new file mode 100644 index 00000000000000..2e18f8a7fca56a --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml @@ -0,0 +1,22 @@ +identity: + author: Gitee AI + name: gitee_ai + label: + en_US: Gitee AI + zh_Hans: Gitee AI + description: + en_US: Quickly experience large models and take the lead in exploring the open-source AI world + zh_Hans: 快速体验大模型,领先探索 AI 开源世界 + icon: icon.svg + tags: + - image +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: API Key + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + url: https://ai.gitee.com/dashboard/settings/tokens diff --git a/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py
b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py new file mode 100644 index 00000000000000..14291d17294472 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py @@ -0,0 +1,33 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GiteeAITool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['api_key']}", + } + + payload = { + "inputs": tool_parameters.get("inputs"), + "width": tool_parameters.get("width", 720), + "height": tool_parameters.get("height", 720), + } + model = tool_parameters.get("model", "Kolors") + url = f"https://ai.gitee.com/api/serverless/{model}/text-to-image" + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response: {response.text}") + + # The returned image is base64 and needs to be marked as an image + result = [self.create_blob_message(blob=response.content, meta={"mime_type": "image/jpeg"})] + + return result diff --git a/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.yaml b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.yaml new file mode 100644 index 00000000000000..5e03f9abe9dfe4 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.yaml @@ -0,0 +1,72 @@ +identity: + name: text to image + author: gitee_ai + label: + en_US: text to image + icon: icon.svg +description: + human: + en_US: Generate images using a variety of popular models + llm: This tool is used to generate images from text. +parameters: + - name: model + type: select + required: true + options: + - value: flux-1-schnell + label: + en_US: flux-1-schnell + - value: Kolors + label: + en_US: Kolors + - value: stable-diffusion-3-medium + label: + en_US: stable-diffusion-3-medium + - value: stable-diffusion-xl-base-1.0 + label: + en_US: stable-diffusion-xl-base-1.0 + - value: stable-diffusion-v1-4 + label: + en_US: stable-diffusion-v1-4 + default: Kolors + label: + en_US: Choose Image Model + zh_Hans: 选择生成图片的模型 + form: form + - name: inputs + type: string + required: true + label: + en_US: Input Text + zh_Hans: 输入文本 + human_description: + en_US: The text input used to generate the image. + zh_Hans: 用于生成图片的输入文本。 + llm_description: This text input will be used to generate the image. + form: llm + - name: width + type: number + required: true + default: 720 + min: 1 + max: 1024 + label: + en_US: Image Width + zh_Hans: 图片宽度 + human_description: + en_US: The width of the generated image. + zh_Hans: 生成图片的宽度。 + form: form + - name: height + type: number + required: true + default: 720 + min: 1 + max: 1024 + label: + en_US: Image Height + zh_Hans: 图片高度 + human_description: + en_US: The height of the generated image.
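For orientation, here is roughly the request the new Gitee AI tool issues, as a standalone sketch. The token and prompt are placeholders, and the response handling mirrors the tool's blob message above, which treats the body as raw image bytes:

```python
import requests

# Hypothetical standalone call against the serverless endpoint the new
# GiteeAITool wraps; the token and prompt values are placeholders.
api_key = "your-gitee-ai-token"
model = "Kolors"

response = requests.post(
    f"https://ai.gitee.com/api/serverless/{model}/text-to-image",
    headers={
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    },
    json={"inputs": "a watercolor fox", "width": 720, "height": 720},
)
response.raise_for_status()

# The tool forwards the body as an image/jpeg blob message, so a
# standalone caller can simply persist the bytes.
with open("fox.jpg", "wb") as f:
    f.write(response.content)
```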
+ zh_Hans: 生成图片的高度。 + form: form diff --git a/api/core/tools/provider/builtin/github/github.py b/api/core/tools/provider/builtin/github/github.py index 9275504208cbc9..87a34ac3e806ea 100644 --- a/api/core/tools/provider/builtin/github/github.py +++ b/api/core/tools/provider/builtin/github/github.py @@ -4,28 +4,28 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController -class GihubProvider(BuiltinToolProviderController): +class GithubProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Github API Access Tokens is required.") - if 'api_version' not in credentials or not credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in credentials or not credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = credentials.get('api_version') + api_version = credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } response = requests.get( - url="https://api.github.com/search/users?q={account}".format(account='charli117'), - headers=headers) + url="https://api.github.com/search/users?q={account}".format(account="charli117"), headers=headers + ) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Github API Key and Api Version is invalid. 
{}".format(e)) except Exception as e: diff --git a/api/core/tools/provider/builtin/github/tools/github_repositories.py b/api/core/tools/provider/builtin/github/tools/github_repositories.py index a2f1e07fd49d7a..32f9922e651785 100644 --- a/api/core/tools/provider/builtin/github/tools/github_repositories.py +++ b/api/core/tools/provider/builtin/github/tools/github_repositories.py @@ -9,54 +9,62 @@ from core.tools.tool.builtin_tool import BuiltinTool -class GihubRepositoriesTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: +class GithubRepositoriesTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - top_n = tool_parameters.get('top_n', 5) - query = tool_parameters.get('query', '') + top_n = tool_parameters.get("top_n", 5) + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input symbol') + return self.create_text_message("Please input symbol") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Github API Access Tokens is required.") - if 'api_version' not in self.runtime.credentials or not self.runtime.credentials.get('api_version'): - api_version = '2022-11-28' + if "api_version" not in self.runtime.credentials or not self.runtime.credentials.get("api_version"): + api_version = "2022-11-28" else: - api_version = self.runtime.credentials.get('api_version') + api_version = self.runtime.credentials.get("api_version") try: headers = { "Content-Type": "application/vnd.github+json", "Authorization": f"Bearer {self.runtime.credentials.get('access_tokens')}", - "X-GitHub-Api-Version": api_version + "X-GitHub-Api-Version": api_version, } s = requests.session() - api_domain = 'https://api.github.com' - response = s.request(method='GET', headers=headers, - url=f"{api_domain}/search/repositories?" - f"q={quote(query)}&sort=stars&per_page={top_n}&order=desc") + api_domain = "https://api.github.com" + response = s.request( + method="GET", + headers=headers, + url=f"{api_domain}/search/repositories?q={quote(query)}&sort=stars&per_page={top_n}&order=desc", + ) response_data = response.json() - if response.status_code == 200 and isinstance(response_data.get('items'), list): + if response.status_code == 200 and isinstance(response_data.get("items"), list): contents = [] - if len(response_data.get('items')) > 0: - for item in response_data.get('items'): + if len(response_data.get("items")) > 0: + for item in response_data.get("items"): content = {} - updated_at_object = datetime.strptime(item['updated_at'], "%Y-%m-%dT%H:%M:%SZ") - content['owner'] = item['owner']['login'] - content['name'] = item['name'] - content['description'] = item['description'][:100] + '...' if len(item['description']) > 100 else item['description'] - content['url'] = item['html_url'] - content['star'] = item['watchers'] - content['forks'] = item['forks'] - content['updated'] = updated_at_object.strftime("%Y-%m-%d") + updated_at_object = datetime.strptime(item["updated_at"], "%Y-%m-%dT%H:%M:%SZ") + content["owner"] = item["owner"]["login"] + content["name"] = item["name"] + content["description"] = ( + item["description"][:100] + "..." 
if len(item["description"]) > 100 else item["description"] + ) + content["url"] = item["html_url"] + content["star"] = item["watchers"] + content["forks"] = item["forks"] + content["updated"] = updated_at_object.strftime("%Y-%m-%d") contents.append(content) s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(contents, ensure_ascii=False)) + ) else: - return self.create_text_message(f'No items related to {query} were found.') + return self.create_text_message(f"No items related to {query} were found.") else: - return self.create_text_message((response.json()).get('message')) + return self.create_text_message((response.json()).get("message")) except Exception as e: return self.create_text_message("Github API Key and Api Version is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/gitlab/gitlab.py b/api/core/tools/provider/builtin/gitlab/gitlab.py index 0c13ec662a4f98..9bd4a0bd52ea64 100644 --- a/api/core/tools/provider/builtin/gitlab/gitlab.py +++ b/api/core/tools/provider/builtin/gitlab/gitlab.py @@ -9,13 +9,13 @@ class GitlabProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if 'access_tokens' not in credentials or not credentials.get('access_tokens'): + if "access_tokens" not in credentials or not credentials.get("access_tokens"): raise ToolProviderCredentialValidationError("Gitlab Access Tokens is required.") - - if 'site_url' not in credentials or not credentials.get('site_url'): - site_url = 'https://gitlab.com' + + if "site_url" not in credentials or not credentials.get("site_url"): + site_url = "https://gitlab.com" else: - site_url = credentials.get('site_url') + site_url = credentials.get("site_url") try: headers = { @@ -23,12 +23,10 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "Authorization": f"Bearer {credentials.get('access_tokens')}", } - response = requests.get( - url= f"{site_url}/api/v4/user", - headers=headers) + response = requests.get(url=f"{site_url}/api/v4/user", headers=headers) if response.status_code != 200: - raise ToolProviderCredentialValidationError((response.json()).get('message')) + raise ToolProviderCredentialValidationError((response.json()).get("message")) except Exception as e: raise ToolProviderCredentialValidationError("Gitlab Access Tokens is invalid. 
{}".format(e)) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py index 880d722bda8e2f..716da7c8c110f3 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.py @@ -1,4 +1,5 @@ import json +import urllib.parse from datetime import datetime, timedelta from typing import Any, Union @@ -9,103 +10,125 @@ class GitlabCommitsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - project = tool_parameters.get('project', '') - employee = tool_parameters.get('employee', '') - start_time = tool_parameters.get('start_time', '') - end_time = tool_parameters.get('end_time', '') - change_type = tool_parameters.get('change_type', 'all') - - if not project: - return self.create_text_message('Project is required') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + branch = tool_parameters.get("branch", "") + repository = tool_parameters.get("repository", "") + employee = tool_parameters.get("employee", "") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + change_type = tool_parameters.get("change_type", "all") + + if not repository: + return self.create_text_message("Either repository is required") if not start_time: start_time = (datetime.utcnow() - timedelta(days=1)).isoformat() if not end_time: end_time = datetime.utcnow().isoformat() - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" + # Get commit content - result = self.fetch(user_id, site_url, access_token, project, employee, start_time, end_time, change_type) + result = self.fetch_commits( + site_url, access_token, repository, branch, employee, start_time, end_time, change_type, is_repository=True + ) return [self.create_json_message(item) for item in result] - - def fetch(self,user_id: str, site_url: str, access_token: str, project: str, employee: str = None, start_time: str = '', end_time: str = '', change_type: str = '') -> list[dict[str, Any]]: + + def fetch_commits( + self, + site_url: str, + access_token: str, + repository: str, + branch: str, + employee: str, + start_time: str, + end_time: str, + change_type: str, + is_repository: bool, + ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] try: - # Get all of projects - url = f"{domain}/api/v4/projects" - response = requests.get(url, headers=headers) - 
response.raise_for_status() - projects = response.json() - - filtered_projects = [p for p in projects if project == "*" or p['name'] == project] - - for project in filtered_projects: - project_id = project['id'] - project_name = project['name'] - print(f"Project: {project_name}") - - # Get all of proejct commits - commits_url = f"{domain}/api/v4/projects/{project_id}/repository/commits" - params = { - 'since': start_time, - 'until': end_time - } - if employee: - params['author'] = employee - - commits_response = requests.get(commits_url, headers=headers, params=params) - commits_response.raise_for_status() - commits = commits_response.json() - - for commit in commits: - commit_sha = commit['id'] - author_name = commit['author_name'] - - diff_url = f"{domain}/api/v4/projects/{project_id}/repository/commits/{commit_sha}/diff" - diff_response = requests.get(diff_url, headers=headers) - diff_response.raise_for_status() - diffs = diff_response.json() - - for diff in diffs: - # Caculate code lines of changed - added_lines = diff['diff'].count('\n+') - removed_lines = diff['diff'].count('\n-') - total_changes = added_lines + removed_lines - - if change_type == "new": - if added_lines > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if line.startswith('+') and not line.startswith('+++')]) - results.append({ + # URL encode the repository path + encoded_repository = urllib.parse.quote(repository, safe="") + commits_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits" + + # Fetch commits for the repository + params = {"since": start_time, "until": end_time} + if branch: + params["ref_name"] = branch + if employee: + params["author"] = employee + + commits_response = requests.get(commits_url, headers=headers, params=params) + commits_response.raise_for_status() + commits = commits_response.json() + + for commit in commits: + commit_sha = commit["id"] + author_name = commit["author_name"] + + diff_url = f"{domain}/api/v4/projects/{encoded_repository}/repository/commits/{commit_sha}/diff" + + diff_response = requests.get(diff_url, headers=headers) + diff_response.raise_for_status() + diffs = diff_response.json() + + for diff in diffs: + # Calculate code lines of changes + added_lines = diff["diff"].count("\n+") + removed_lines = diff["diff"].count("\n-") + total_changes = added_lines + removed_lines + + if change_type == "new": + if added_lines > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if line.startswith("+") and not line.startswith("+++") + ] + ) + results.append( + { + "diff_url": diff_url, "commit_sha": commit_sha, "author_name": author_name, - "diff": final_code - }) - else: - if total_changes > 1: - final_code = ''.join([line[1:] for line in diff['diff'].split('\n') if (line.startswith('+') or line.startswith('-')) and not line.startswith('+++') and not line.startswith('---')]) - final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code - results.append({ + "diff": final_code, + } + ) + else: + if total_changes > 1: + final_code = "".join( + [ + line[1:] + for line in diff["diff"].split("\n") + if (line.startswith("+") or line.startswith("-")) + and not line.startswith("+++") + and not line.startswith("---") + ] + ) + final_code_escaped = json.dumps(final_code)[1:-1] # Escape the final code + results.append( + { + "diff_url": diff_url, "commit_sha": commit_sha, "author_name": author_name, - "diff": final_code_escaped - }) + "diff": final_code_escaped, + } + ) except requests.RequestException as 
e: print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file + + return results diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml index dd4e31d6633d37..2ff5fb570ecc42 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_commits.yaml @@ -6,7 +6,7 @@ identity: zh_Hans: GitLab 提交内容查询 description: human: - en_US: A tool for query GitLab commits, Input should be a exists username or projec. + en_US: A tool for query GitLab commits, Input should be a exists username or project. zh_Hans: 一个用于查询 GitLab 代码提交内容的工具,输入的内容应该是一个已存在的用户名或者项目名。 llm: A tool for query GitLab commits, Input should be a exists username or project. parameters: @@ -21,16 +21,27 @@ parameters: zh_Hans: 员工用户名 llm_description: User name for GitLab form: llm - - name: project + - name: repository type: string required: true label: - en_US: project - zh_Hans: 项目名 + en_US: repository + zh_Hans: 仓库路径 human_description: - en_US: project - zh_Hans: 项目名 - llm_description: project for GitLab + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. + form: llm + - name: branch + type: string + required: false + label: + en_US: branch + zh_Hans: 分支名 + human_description: + en_US: branch + zh_Hans: 分支名 + llm_description: branch for GitLab form: llm - name: start_time type: string diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py index 7fa1d0d1124bab..1e77f3c6dfc678 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -1,3 +1,4 @@ +import urllib.parse from typing import Any, Union import requests @@ -7,89 +8,96 @@ class GitlabFilesTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - project = tool_parameters.get('project', '') - branch = tool_parameters.get('branch', '') - path = tool_parameters.get('path', '') - - - if not project: - return self.create_text_message('Project is required') - if not branch: - return self.create_text_message('Branch is required') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project = tool_parameters.get("project", "") + repository = tool_parameters.get("repository", "") + branch = tool_parameters.get("branch", "") + path = tool_parameters.get("path", "") + if not project and not repository: + return self.create_text_message("Either project or repository is required") + if not branch: + return self.create_text_message("Branch is required") if not path: - return self.create_text_message('Path is required') + return self.create_text_message("Path is required") - access_token = self.runtime.credentials.get('access_tokens') - site_url = self.runtime.credentials.get('site_url') + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") - if 'access_tokens' not in self.runtime.credentials or not self.runtime.credentials.get('access_tokens'): + if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): return self.create_text_message("Gitlab API Access Tokens is required.") - 
if 'site_url' not in self.runtime.credentials or not self.runtime.credentials.get('site_url'): - site_url = 'https://gitlab.com' - - # Get project ID from project name - project_id = self.get_project_id(site_url, access_token, project) - if not project_id: - return self.create_text_message(f"Project '{project}' not found.") + if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): + site_url = "https://gitlab.com" - # Get commit content - result = self.fetch(user_id, project_id, site_url, access_token, branch, path) + # Get file content + if repository: + result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True) + else: + result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False) return [self.create_json_message(item) for item in result] - - def extract_project_name_and_path(self, path: str) -> tuple[str, str]: - parts = path.split('/', 1) - if len(parts) < 2: - return None, None - return parts[0], parts[1] - def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: - headers = {"PRIVATE-TOKEN": access_token} - try: - url = f"{site_url}/api/v4/projects?search={project_name}" - response = requests.get(url, headers=headers) - response.raise_for_status() - projects = response.json() - for project in projects: - if project['name'] == project_name: - return project['id'] - except requests.RequestException as e: - print(f"Error fetching project ID from GitLab: {e}") - return None - - def fetch(self,user_id: str, project_id: str, site_url: str, access_token: str, branch: str, path: str = None) -> list[dict[str, Any]]: + def fetch_files( + self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool + ) -> list[dict[str, Any]]: domain = site_url headers = {"PRIVATE-TOKEN": access_token} results = [] try: - # List files and directories in the given path - url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" - response = requests.get(url, headers=headers) + if is_repository: + # URL encode the repository path + encoded_identifier = urllib.parse.quote(identifier, safe="") + tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}" + else: + # Get project ID from project name + project_id = self.get_project_id(site_url, access_token, identifier) + if not project_id: + return self.create_text_message(f"Project '{identifier}' not found.") + tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}" + + response = requests.get(tree_url, headers=headers) response.raise_for_status() items = response.json() for item in items: - item_path = item['path'] - if item['type'] == 'tree': # It's a directory - results.extend(self.fetch(project_id, site_url, access_token, branch, item_path)) + item_path = item["path"] + if item["type"] == "tree": # It's a directory + results.extend( + self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository) + ) else: # It's a file - file_url = f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + if is_repository: + file_url = ( + f"{domain}/api/v4/projects/{encoded_identifier}/repository/files" + f"/{item_path}/raw?ref={branch}" + ) + else: + file_url = ( + f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + ) + file_response = requests.get(file_url, headers=headers) 
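A quick note on the path-encoding idiom the reworked GitLab tools depend on, shown standalone with a hypothetical repository path. GitLab's REST API accepts a "namespace/project" path in place of a numeric project ID, but only when the path is URL-encoded as a single segment:

```python
import urllib.parse

# safe="" forces quote() to encode the slash as well, so the whole
# "namespace/project" path becomes one URL segment.
repository = "my-group/my-project"  # hypothetical repository path
encoded = urllib.parse.quote(repository, safe="")
assert encoded == "my-group%2Fmy-project"

commits_url = f"https://gitlab.com/api/v4/projects/{encoded}/repository/commits"
```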
file_response.raise_for_status() file_content = file_response.text - results.append({ - "path": item_path, - "branch": branch, - "content": file_content - }) + results.append({"path": item_path, "branch": branch, "content": file_content}) except requests.RequestException as e: print(f"Error fetching data from GitLab: {e}") - - return results \ No newline at end of file + + return results + + def get_project_id(self, site_url: str, access_token: str, project_name: str) -> Union[str, None]: + headers = {"PRIVATE-TOKEN": access_token} + try: + url = f"{site_url}/api/v4/projects?search={project_name}" + response = requests.get(url, headers=headers) + response.raise_for_status() + projects = response.json() + for project in projects: + if project["name"] == project_name: + return project["id"] + except requests.RequestException as e: + print(f"Error fetching project ID from GitLab: {e}") + return None diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml index d99b6254c1b99c..4c733673f15254 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.yaml @@ -10,9 +10,20 @@ description: zh_Hans: 一个用于查询 GitLab 文件的工具,输入的内容应该是分支和一个已存在文件或者文件夹路径。 llm: A tool for query GitLab files, Input should be a exists file or directory path. parameters: + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name. + form: llm - name: project type: string - required: true + required: false label: en_US: project zh_Hans: 项目 diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.py new file mode 100644 index 00000000000000..ef99fa82e9d9d6 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.py @@ -0,0 +1,78 @@ +import urllib.parse +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GitlabMergeRequestsTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + repository = tool_parameters.get("repository", "") + branch = tool_parameters.get("branch", "") + start_time = tool_parameters.get("start_time", "") + end_time = tool_parameters.get("end_time", "") + state = tool_parameters.get("state", "opened") # Default to "opened" + + if not repository: + return self.create_text_message("Repository is required") + + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") + + if not access_token: + return self.create_text_message("Gitlab API Access Tokens is required.") + if not site_url: + site_url = "https://gitlab.com" + + # Get merge requests + result = self.get_merge_requests(site_url, access_token, repository, branch, start_time, end_time, state) + + return [self.create_json_message(item) for item in result] + + def get_merge_requests( + self, site_url: str, access_token: str, repository: str, branch: str, start_time: str, end_time: str, state: str + ) -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: 
+ # URL encode the repository path + encoded_repository = urllib.parse.quote(repository, safe="") + merge_requests_url = f"{domain}/api/v4/projects/{encoded_repository}/merge_requests" + params = {"state": state} + + # Add time filters if provided + if start_time: + params["created_after"] = start_time + if end_time: + params["created_before"] = end_time + + response = requests.get(merge_requests_url, headers=headers, params=params) + response.raise_for_status() + merge_requests = response.json() + + for mr in merge_requests: + # Filter by target branch + if branch and mr["target_branch"] != branch: + continue + + results.append( + { + "id": mr["id"], + "title": mr["title"], + "author": mr["author"]["name"], + "web_url": mr["web_url"], + "target_branch": mr["target_branch"], + "created_at": mr["created_at"], + "state": mr["state"], + } + ) + except requests.RequestException as e: + print(f"Error fetching merge requests from GitLab: {e}") + + return results diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.yaml new file mode 100644 index 00000000000000..4c886b69c03a32 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.yaml @@ -0,0 +1,77 @@ +identity: + name: gitlab_mergerequests + author: Leo.Wang + label: + en_US: GitLab Merge Requests + zh_Hans: GitLab 合并请求查询 +description: + human: + en_US: A tool for querying GitLab merge requests. Input should be an existing repository or branch. + zh_Hans: 一个用于查询 GitLab 代码合并请求的工具,输入的内容应该是一个已存在的仓库名或者分支。 + llm: A tool for querying GitLab merge requests. Input should be an existing repository or branch. +parameters: + - name: repository + type: string + required: false + label: + en_US: repository + zh_Hans: 仓库路径 + human_description: + en_US: repository + zh_Hans: 仓库路径,以namespace/project_name的形式。 + llm_description: Repository path for GitLab, like namespace/project_name.
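To see the whole request shape of the new merge-request tool end to end, a hedged standalone sketch follows. The token, dates, repository path, and branch are placeholders; state and the creation window are filtered server-side via query parameters, while the target branch is filtered client-side, as in the tool:

```python
import urllib.parse

import requests

# Hypothetical end-to-end call mirroring the new gitlab_mergerequests tool.
site_url = "https://gitlab.com"
project = urllib.parse.quote("my-group/my-project", safe="")

response = requests.get(
    f"{site_url}/api/v4/projects/{project}/merge_requests",
    headers={"PRIVATE-TOKEN": "glpat-..."},  # placeholder token
    params={
        "state": "opened",
        "created_after": "2024-01-01T00:00:00Z",
        "created_before": "2024-02-01T00:00:00Z",
    },
)
response.raise_for_status()

for mr in response.json():
    if mr["target_branch"] != "main":  # client-side branch filter, as in the tool
        continue
    print(mr["title"], mr["web_url"])
```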
+ form: llm + - name: branch + type: string + required: false + label: + en_US: branch + zh_Hans: 分支名 + human_description: + en_US: branch + zh_Hans: 分支名 + llm_description: branch for GitLab + form: llm + - name: start_time + type: string + required: false + label: + en_US: start_time + zh_Hans: 开始时间 + human_description: + en_US: start_time + zh_Hans: 开始时间 + llm_description: Start time for GitLab + form: llm + - name: end_time + type: string + required: false + label: + en_US: end_time + zh_Hans: 结束时间 + human_description: + en_US: end_time + zh_Hans: 结束时间 + llm_description: End time for GitLab + form: llm + - name: state + type: select + required: false + options: + - value: opened + label: + en_US: opened + zh_Hans: 打开 + - value: closed + label: + en_US: closed + zh_Hans: 关闭 + default: opened + label: + en_US: state + zh_Hans: 变更状态 + human_description: + en_US: state + zh_Hans: 变更状态 + llm_description: Merge request state type for GitLab + form: llm diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_projects.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_projects.py new file mode 100644 index 00000000000000..ea0c028b4f3d07 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_projects.py @@ -0,0 +1,81 @@ +import urllib.parse +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GitlabProjectsTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + project_name = tool_parameters.get("project_name", "") + page = tool_parameters.get("page", 1) + page_size = tool_parameters.get("page_size", 20) + + access_token = self.runtime.credentials.get("access_tokens") + site_url = self.runtime.credentials.get("site_url") + + if not access_token: + return self.create_text_message("Gitlab API Access Tokens is required.") + if not site_url: + site_url = "https://gitlab.com" + + # Get project content + result = self.fetch_projects(site_url, access_token, project_name, page, page_size) + + return [self.create_json_message(item) for item in result] + + def fetch_projects( + self, + site_url: str, + access_token: str, + project_name: str, + page: str, + page_size: str, + ) -> list[dict[str, Any]]: + domain = site_url + headers = {"PRIVATE-TOKEN": access_token} + results = [] + + try: + if project_name: + # URL encode the project name for the search query + encoded_project_name = urllib.parse.quote(project_name, safe="") + projects_url = ( + f"{domain}/api/v4/projects?search={encoded_project_name}&page={page}&per_page={page_size}" + ) + else: + projects_url = f"{domain}/api/v4/projects?page={page}&per_page={page_size}" + + response = requests.get(projects_url, headers=headers) + response.raise_for_status() + projects = response.json() + + for project in projects: + # Filter projects by exact name match if necessary + if project_name and project["name"].lower() == project_name.lower(): + results.append( + { + "id": project["id"], + "name": project["name"], + "description": project.get("description", ""), + "web_url": project["web_url"], + } + ) + elif not project_name: + # If no specific project name is provided, add all projects + results.append( + { + "id": project["id"], + "name": project["name"], + "description": project.get("description", ""), + "web_url": project["web_url"], + } + ) + except requests.RequestException as e: + print(f"Error fetching 
data from GitLab: {e}") + + return results diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_projects.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_projects.yaml new file mode 100644 index 00000000000000..5fe098e1f7a647 --- /dev/null +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_projects.yaml @@ -0,0 +1,45 @@ +identity: + name: gitlab_projects + author: Leo.Wang + label: + en_US: GitLab Projects + zh_Hans: GitLab 项目列表查询 +description: + human: + en_US: A tool for querying GitLab projects. Input should be a project name. + zh_Hans: 一个用于查询 GitLab 项目列表的工具,输入的内容应该是一个项目名称。 + llm: A tool for querying GitLab projects. Input should be a project name. +parameters: + - name: project_name + type: string + required: false + label: + en_US: project_name + zh_Hans: 项目名称 + human_description: + en_US: project_name + zh_Hans: 项目名称 + llm_description: Project name for GitLab + form: llm + - name: page + type: string + required: false + label: + en_US: page + zh_Hans: 页码 + human_description: + en_US: page + zh_Hans: 页码 + llm_description: Page index for GitLab + form: llm + - name: page_size + type: string + required: false + label: + en_US: page_size + zh_Hans: 每页数量 + human_description: + en_US: page_size + zh_Hans: 每页数量 + llm_description: Page size for GitLab + form: llm diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 8f4b9a4a4e9784..6b5395f9d3e5b8 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -13,12 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 09d0326fb4a0b7..a9f65925d86f94 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -9,7 +9,6 @@ class GoogleSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledge_graph" in response: @@ -17,25 +16,23 @@ def _parse_response(self, response: dict) -> dict: result["description"] = response["knowledge_graph"].get("description", "") if "organic_results" in response: result["organic_results"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic_results"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: params = { - "api_key": self.runtime.credentials['serpapi_api_key'], - "q": tool_parameters['query'], + "api_key": self.runtime.credentials["serpapi_api_key"], + "q": tool_parameters["query"], "engine": "google", "google_domain": "google.com", "gl": "us", - "hl": "en" + "hl": "en", } response = requests.get(url=SERP_API_URL, params=params) response.raise_for_status() diff
--git a/api/core/tools/provider/builtin/google_translate/google_translate.py b/api/core/tools/provider/builtin/google_translate/google_translate.py index f6e1d65834798b..ea53aa4eeb906f 100644 --- a/api/core/tools/provider/builtin/google_translate/google_translate.py +++ b/api/core/tools/provider/builtin/google_translate/google_translate.py @@ -8,10 +8,6 @@ class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - GoogleTranslate().invoke(user_id='', - tool_parameters={ - "content": "这是一段测试文本", - "dest": "en" - }) + GoogleTranslate().invoke(user_id="", tool_parameters={"content": "这是一段测试文本", "dest": "en"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/google_translate/tools/translate.py b/api/core/tools/provider/builtin/google_translate/tools/translate.py index 4314182b06dbbc..ea3f2077d5d485 100644 --- a/api/core/tools/provider/builtin/google_translate/tools/translate.py +++ b/api/core/tools/provider/builtin/google_translate/tools/translate.py @@ -7,46 +7,41 @@ class GoogleTranslate(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - dest = tool_parameters.get('dest', '') + dest = tool_parameters.get("dest", "") if not dest: - return self.create_text_message('Invalid parameter destination language') + return self.create_text_message("Invalid parameter destination language") try: result = self._translate(content, dest) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Translation service error, please check the network') + return self.create_text_message("Translation service error, please check the network") def _translate(self, content: str, dest: str) -> str: try: url = "https://translate.googleapis.com/translate_a/single" - params = { - "client": "gtx", - "sl": "auto", - "tl": dest, - "dt": "t", - "q": content - } + params = {"client": "gtx", "sl": "auto", "tl": dest, "dt": "t", "q": content} headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/91.0.4472.124 Safari/537.36" } - response_json = requests.get( - url, params=params, headers=headers).json() + response_json = requests.get(url, params=params, headers=headers).json() result = response_json[0] - translated_text = ''.join([item[0] for item in result if item[0]]) + translated_text = "".join([item[0] for item in result if item[0]]) return str(translated_text) except Exception as e: return str(e) diff --git a/api/core/tools/provider/builtin/hap/hap.py b/api/core/tools/provider/builtin/hap/hap.py index e0a48e05a5ef8c..cbdf9504659568 100644 --- a/api/core/tools/provider/builtin/hap/hap.py +++ b/api/core/tools/provider/builtin/hap/hap.py @@ -5,4 +5,4 @@ class HapProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: 
dict[str, Any]) -> None: - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py index 0e101dc67daa13..597adc91db9768 100644 --- a/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/add_worksheet_record.py @@ -8,41 +8,40 @@ class AddWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Worksheet ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/addRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to add the new record. {res_json['error_msg']}") return self.create_text_message(f"New record added successfully. 
The record ID is {res_json['data']}.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py index ba25952c9f4ea6..5d42af4c490598 100644 --- a/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/delete_worksheet_record.py @@ -7,43 +7,42 @@ class DeleteWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/deleteRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=30) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to delete the record. 
{res_json['error_msg']}") return self.create_text_message("Successfully deleted the record.") except httpx.RequestError as e: return self.create_text_message(f"Failed to delete the record, request error: {e}") except Exception as e: - return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") \ No newline at end of file + return self.create_text_message(f"Failed to delete the record, unexpected error: {e}") diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 2c46d9dd4e7392..6887b8b4e99df6 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -8,43 +8,42 @@ class GetWorksheetFieldsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Worksheet ID") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} try: res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to get the worksheet information. 
{res_json['error_msg']}") - - fields_json, fields_table = self.get_controls(res_json['data']['controls']) - result_type = tool_parameters.get('result_type', 'table') + + fields_json, fields_table = self.get_controls(res_json["data"]["controls"]) + result_type = tool_parameters.get("result_type", "table") return self.create_text_message( - text=json.dumps(fields_json, ensure_ascii=False) if result_type == 'json' else fields_table + text=json.dumps(fields_json, ensure_ascii=False) if result_type == "json" else fields_table ) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet information, request error: {e}") @@ -88,61 +87,66 @@ def get_field_type_by_id(self, field_type_id: int) -> str: 50: "Text", 51: "Query Record", } - return field_type_map.get(field_type_id, '') + return field_type_map.get(field_type_id, "") def get_controls(self, controls: list) -> dict: fields = [] - fields_list = ['|fieldId|fieldName|fieldType|fieldTypeId|description|options|','|'+'---|'*6] + fields_list = ["|fieldId|fieldName|fieldType|fieldTypeId|description|options|", "|" + "---|" * 6] for control in controls: - if control['type'] in self._get_ignore_types(): + if control["type"] in self._get_ignore_types(): continue - field_type_id = control['type'] - field_type = self.get_field_type_by_id(control['type']) + field_type_id = control["type"] + field_type = self.get_field_type_by_id(control["type"]) if field_type_id == 30: - source_type = control['sourceControl']['type'] + source_type = control["sourceControl"]["type"] if source_type in self._get_ignore_types(): continue else: field_type_id = source_type field_type = self.get_field_type_by_id(source_type) field = { - 'id': control['controlId'], - 'name': control['controlName'], - 'type': field_type, - 'typeId': field_type_id, - 'description': control['remark'].replace('\n', ' ').replace('\t', ' '), - 'options': self._extract_options(control), + "id": control["controlId"], + "name": control["controlName"], + "type": field_type, + "typeId": field_type_id, + "description": control["remark"].replace("\n", " ").replace("\t", " "), + "options": self._extract_options(control), } fields.append(field) - fields_list.append(f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}|{field['options'] if field['options'] else ''}|") + fields_list.append( + f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}" + f"|{field['options'] or ''}|" + ) - fields.append({ - 'id': 'ctime', - 'name': 'Created Time', - 'type': self.get_field_type_by_id(16), - 'typeId': 16, - 'description': '', - 'options': [] - }) + fields.append( + { + "id": "ctime", + "name": "Created Time", + "type": self.get_field_type_by_id(16), + "typeId": 16, + "description": "", + "options": [], + } + ) fields_list.append("|ctime|Created Time|Date|16|||") - return fields, '\n'.join(fields_list) + return fields, "\n".join(fields_list) def _extract_options(self, control: dict) -> list: options = [] - if control['type'] in [9, 10, 11]: - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) - elif control['type'] in [28, 36]: - itemnames = control['advancedSetting'].get('itemnames') - if itemnames and itemnames.startswith('[{'): + if control["type"] in {9, 10, 11}: + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) + elif control["type"] in {28, 36}: + itemnames = control["advancedSetting"].get("itemnames") + if 
itemnames and itemnames.startswith("[{"): try: options = json.loads(itemnames) except json.JSONDecodeError: pass - elif control['type'] == 30: - source_type = control['sourceControl']['type'] + elif control["type"] == 30: + source_type = control["sourceControl"]["type"] if source_type not in self._get_ignore_types(): - options.extend([{"key": opt['key'], "value": opt['value']} for opt in control.get('options', [])]) + options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) return options - + def _get_ignore_types(self): - return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} \ No newline at end of file + return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py index 6bf1caa65ec337..26d7116869b6d9 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py @@ -8,64 +8,66 @@ class GetWorksheetPivotDataTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') - x_column_fields = tool_parameters.get('x_column_fields', '') - if not x_column_fields or not x_column_fields.startswith('['): - return self.create_text_message('Invalid parameter Column Fields') - y_row_fields = tool_parameters.get('y_row_fields', '') - if y_row_fields and not y_row_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Row Fields') + return self.create_text_message("Invalid parameter Worksheet ID") + x_column_fields = tool_parameters.get("x_column_fields", "") + if not x_column_fields or not x_column_fields.startswith("["): + return self.create_text_message("Invalid parameter Column Fields") + y_row_fields = tool_parameters.get("y_row_fields", "") + if y_row_fields and not y_row_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Row Fields") elif not y_row_fields: - y_row_fields = '[]' - value_fields = tool_parameters.get('value_fields', '') - if not value_fields or not value_fields.strip().startswith('['): - return self.create_text_message('Invalid parameter Value Fields') - - host = tool_parameters.get('host', '') + y_row_fields = "[]" + value_fields = tool_parameters.get("value_fields", "") + if not value_fields or not value_fields.strip().startswith("["): + return self.create_text_message("Invalid parameter Value Fields") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return 
self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/report/getPivotData" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "options": {"showTotal": True}} try: x_column_fields = json.loads(x_column_fields) - payload['columns'] = x_column_fields + payload["columns"] = x_column_fields y_row_fields = json.loads(y_row_fields) - if y_row_fields: payload['rows'] = y_row_fields + if y_row_fields: + payload["rows"] = y_row_fields value_fields = json.loads(value_fields) - payload['values'] = value_fields - sort_fields = tool_parameters.get('sort_fields', '') - if not sort_fields: sort_fields = '[]' + payload["values"] = value_fields + sort_fields = tool_parameters.get("sort_fields", "") + if not sort_fields: + sort_fields = "[]" sort_fields = json.loads(sort_fields) - if sort_fields: payload['options']['sort'] = sort_fields + if sort_fields: + payload["options"]["sort"] = sort_fields res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('status') != 1: + if res_json.get("status") != 1: return self.create_text_message(f"Failed to get the worksheet pivot data. {res_json['msg']}") - - pivot_json = self.generate_pivot_json(res_json['data']) - pivot_table = self.generate_pivot_table(res_json['data']) - result_type = tool_parameters.get('result_type', '') - text = pivot_table if result_type == 'table' else json.dumps(pivot_json, ensure_ascii=False) + + pivot_json = self.generate_pivot_json(res_json["data"]) + pivot_table = self.generate_pivot_table(res_json["data"]) + result_type = tool_parameters.get("result_type", "") + text = pivot_table if result_type == "table" else json.dumps(pivot_json, ensure_ascii=False) return self.create_text_message(text) except httpx.RequestError as e: return self.create_text_message(f"Failed to get the worksheet pivot data, request error: {e}") @@ -75,27 +77,31 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message(f"Failed to get the worksheet pivot data, unexpected error: {e}") def generate_pivot_table(self, data: dict[str, Any]) -> str: - columns = data['metadata']['columns'] - rows = data['metadata']['rows'] - values = data['metadata']['values'] + columns = data["metadata"]["columns"] + rows = data["metadata"]["rows"] + values = data["metadata"]["values"] - rows_data = data['data'] + rows_data = data["data"] - header = ([row['displayName'] for row in rows] if rows else []) + [column['displayName'] for column in columns] + [value['displayName'] for value in values] - line = (['---'] * len(rows) if rows else []) + ['---'] * len(columns) + ['--:'] * len(values) + header = ( + ([row["displayName"] for row in rows] if rows else []) + + [column["displayName"] for column in columns] + + [value["displayName"] for value in values] + ) + line = (["---"] * len(rows) if rows else []) + ["---"] * len(columns) + ["--:"] * len(values) table = [header, line] for row in rows_data: - row_data = [self.replace_pipe(row['rows'][r['controlId']]) for r in rows] if rows else [] - row_data.extend([self.replace_pipe(row['columns'][column['controlId']]) for column in columns]) - row_data.extend([self.replace_pipe(str(row['values'][value['controlId']])) for value in 
values]) + row_data = [self.replace_pipe(row["rows"][r["controlId"]]) for r in rows] if rows else [] + row_data.extend([self.replace_pipe(row["columns"][column["controlId"]]) for column in columns]) + row_data.extend([self.replace_pipe(str(row["values"][value["controlId"]])) for value in values]) table.append(row_data) - return '\n'.join([('|'+'|'.join(row) +'|') for row in table]) - + return "\n".join([("|" + "|".join(row) + "|") for row in table]) + def replace_pipe(self, text: str) -> str: - return text.replace('|', '▏').replace('\n', ' ') - + return text.replace("|", "▏").replace("\n", " ") + def generate_pivot_json(self, data: dict[str, Any]) -> dict: fields = { "x-axis": [ @@ -103,13 +109,14 @@ def generate_pivot_json(self, data: dict[str, Any]) -> dict: for column in data["metadata"]["columns"] ], "y-axis": [ - {"fieldId": row["controlId"], "fieldName": row["displayName"]} - for row in data["metadata"]["rows"] - ] if data["metadata"]["rows"] else [], + {"fieldId": row["controlId"], "fieldName": row["displayName"]} for row in data["metadata"]["rows"] + ] + if data["metadata"]["rows"] + else [], "values": [ {"fieldId": value["controlId"], "fieldName": value["displayName"]} for value in data["metadata"]["values"] - ] + ], } # fields = ([ # {"fieldId": row["controlId"], "fieldName": row["displayName"]} @@ -123,8 +130,8 @@ def generate_pivot_json(self, data: dict[str, Any]) -> dict: # ] rows = [] for row in data["data"]: - row_data = row["rows"] if row["rows"] else {} + row_data = row["rows"] or {} row_data.update(row["columns"]) row_data.update(row["values"]) rows.append(row_data) - return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} \ No newline at end of file + return {"fields": fields, "rows": rows, "summary": data["metadata"]["totalRow"]} diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index dddc041cc19fb3..d6ac3688b7794a 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -9,191 +9,213 @@ class ListWorksheetRecordsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') + return self.create_text_message("Invalid parameter App Key") - sign = tool_parameters.get('sign', '') + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') + return self.create_text_message("Invalid parameter Sign") - worksheet_id = tool_parameters.get('worksheet_id', '') + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet ID') + return self.create_text_message("Invalid parameter Worksheet ID") - host = tool_parameters.get('host', '') + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") 
else: - host = f"{host[:-1] if host.endswith('/') else host}/api" - + host = f"{host.removesuffix('/')}/api" + url_fields = f"{host}/v2/open/worksheet/getWorksheetInfo" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id} - field_ids = tool_parameters.get('field_ids', '') + field_ids = tool_parameters.get("field_ids", "") try: res = httpx.post(url_fields, headers=headers, json=payload, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the worksheet information. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to get the worksheet information. {}".format(res_json["error_msg"]) + ) else: - worksheet_name = res_json['data']['name'] - fields, schema, table_header = self.get_schema(res_json['data']['controls'], field_ids) + worksheet_name = res_json["data"]["name"] + fields, schema, table_header = self.get_schema(res_json["data"]["controls"], field_ids) else: return self.create_text_message( - f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}") + f"Failed to get the worksheet information, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to get the worksheet information, something went wrong: {}".format(e)) + return self.create_text_message( + "Failed to get the worksheet information, something went wrong: {}".format(e) + ) if field_ids: - payload['controls'] = [v.strip() for v in field_ids.split(',')] if field_ids else [] - filters = tool_parameters.get('filters', '') + payload["controls"] = [v.strip() for v in field_ids.split(",")] if field_ids else [] + filters = tool_parameters.get("filters", "") if filters: - payload['filters'] = json.loads(filters) - sort_id = tool_parameters.get('sort_id', '') - sort_is_asc = tool_parameters.get('sort_is_asc', False) + payload["filters"] = json.loads(filters) + sort_id = tool_parameters.get("sort_id", "") + sort_is_asc = tool_parameters.get("sort_is_asc", False) if sort_id: - payload['sortId'] = sort_id - payload['isAsc'] = sort_is_asc - limit = tool_parameters.get('limit', 50) - payload['pageSize'] = limit - page_index = tool_parameters.get('page_index', 1) - payload['pageIndex'] = page_index - payload['useControlId'] = True - payload['listType'] = 1 + payload["sortId"] = sort_id + payload["isAsc"] = sort_is_asc + limit = tool_parameters.get("limit", 50) + payload["pageSize"] = limit + page_index = tool_parameters.get("page_index", 1) + payload["pageIndex"] = page_index + payload["useControlId"] = True + payload["listType"] = 1 url = f"{host}/v2/open/worksheet/getFilterRows" try: res = httpx.post(url, headers=headers, json=payload, timeout=90) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to get the records. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message("Failed to get the records. 
{}".format(res_json["error_msg"])) else: result = { "fields": fields, "rows": [], "total": res_json.get("data", {}).get("total"), - "payload": {key: payload[key] for key in ['worksheetId', 'controls', 'filters', 'sortId', 'isAsc', 'pageSize', 'pageIndex'] if key in payload} + "payload": { + key: payload[key] + for key in [ + "worksheetId", + "controls", + "filters", + "sortId", + "isAsc", + "pageSize", + "pageIndex", + ] + if key in payload + }, } rows = res_json.get("data", {}).get("rows", []) - result_type = tool_parameters.get('result_type', '') - if not result_type: result_type = 'table' - if result_type == 'json': + result_type = tool_parameters.get("result_type", "") + if not result_type: + result_type = "table" + if result_type == "json": for row in rows: - result['rows'].append(self.get_row_field_value(row, schema)) + result["rows"].append(self.get_row_field_value(row, schema)) return self.create_text_message(json.dumps(result, ensure_ascii=False)) else: result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"." - if result['total'] > 0: - result_text += f" The following are {result['total'] if result['total'] < limit else limit} pieces of data presented in a table format:\n\n{table_header}" + if result["total"] > 0: + result_text += ( + f" The following are {min(limit, result['total'])}" + f" pieces of data presented in a table format:\n\n{table_header}" + ) for row in rows: result_values = [] for f in fields: - result_values.append(self.handle_value_type(row[f['fieldId']], schema[f['fieldId']])) - result_text += '\n|'+'|'.join(result_values)+'|' + result_values.append( + self.handle_value_type(row[f["fieldId"]], schema[f["fieldId"]]) + ) + result_text += "\n|" + "|".join(result_values) + "|" return self.create_text_message(result_text) else: return self.create_text_message( - f"Failed to get the records, status code: {res.status_code}, response: {res.text}") + f"Failed to get the records, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to get the records, something went wrong: {}".format(e)) - def get_row_field_value(self, row: dict, schema: dict): row_value = {"rowid": row["rowid"]} for field in schema: row_value[field] = self.handle_value_type(row[field], schema[field]) return row_value - - def get_schema(self, controls: list, fieldids: str): - allow_fields = {v.strip() for v in fieldids.split(',')} if fieldids else set() + def get_schema(self, controls: list, fieldids: str): + allow_fields = {v.strip() for v in fieldids.split(",")} if fieldids else set() fields = [] schema = {} field_names = [] for control in controls: control_type_id = self.get_real_type_id(control) - if (control_type_id in self._get_ignore_types()) or (allow_fields and not control['controlId'] in allow_fields): + if (control_type_id in self._get_ignore_types()) or ( + allow_fields and control["controlId"] not in allow_fields + ): continue else: - fields.append({'fieldId': control['controlId'], 'fieldName': control['controlName']}) - schema[control['controlId']] = {'typeId': control_type_id, 'options': self.set_option(control)} - field_names.append(control['controlName']) - if (not allow_fields or ('ctime' in allow_fields)): - fields.append({'fieldId': 'ctime', 'fieldName': 'Created Time'}) - schema['ctime'] = {'typeId': 16, 'options': {}} + fields.append({"fieldId": control["controlId"], "fieldName": control["controlName"]}) + schema[control["controlId"]] = {"typeId": control_type_id, "options": 
self.set_option(control)} + field_names.append(control["controlName"]) + if not allow_fields or ("ctime" in allow_fields): + fields.append({"fieldId": "ctime", "fieldName": "Created Time"}) + schema["ctime"] = {"typeId": 16, "options": {}} field_names.append("Created Time") - fields.append({'fieldId':'rowid', 'fieldName': 'Record Row ID'}) - schema['rowid'] = {'typeId': 2, 'options': {}} + fields.append({"fieldId": "rowid", "fieldName": "Record Row ID"}) + schema["rowid"] = {"typeId": 2, "options": {}} field_names.append("Record Row ID") - return fields, schema, '|'+'|'.join(field_names)+'|\n|'+'---|'*len(field_names) - + return fields, schema, "|" + "|".join(field_names) + "|\n|" + "---|" * len(field_names) + def get_real_type_id(self, control: dict) -> int: - return control['sourceControlType'] if control['type'] == 30 else control['type'] - + return control["sourceControlType"] if control["type"] == 30 else control["type"] + def set_option(self, control: dict) -> dict: options = {} - if control.get('options'): - options = {option['key']: option['value'] for option in control['options']} - elif control.get('advancedSetting', {}).get('itemnames'): + if control.get("options"): + options = {option["key"]: option["value"] for option in control["options"]} + elif control.get("advancedSetting", {}).get("itemnames"): try: - itemnames = json.loads(control['advancedSetting']['itemnames']) - options = {item['key']: item['value'] for item in itemnames} + itemnames = json.loads(control["advancedSetting"]["itemnames"]) + options = {item["key"]: item["value"] for item in itemnames} except json.JSONDecodeError: pass return options def _get_ignore_types(self): return {14, 21, 22, 34, 42, 43, 45, 47, 49, 10010} - + def handle_value_type(self, value, field): type_id = field.get("typeId") if type_id == 10: value = value if isinstance(value, str) else "、".join(value) - elif type_id in [28, 36]: + elif type_id in {28, 36}: value = field.get("options", {}).get(value, value) - elif type_id in [26, 27, 48, 14]: + elif type_id in {26, 27, 48, 14}: value = self.process_value(value) - elif type_id in [35, 29]: + elif type_id in {35, 29}: value = self.parse_cascade_or_associated(field, value) elif type_id == 40: value = self.parse_location(value) - return self.rich_text_to_plain_text(value) if value else '' + return self.rich_text_to_plain_text(value) if value else "" def process_value(self, value): if isinstance(value, str): - if value.startswith("[{\"accountId\""): + if value.startswith('[{"accountId"'): value = json.loads(value) - value = ', '.join([item['fullname'] for item in value]) - elif value.startswith("[{\"departmentId\""): + value = ", ".join([item["fullname"] for item in value]) + elif value.startswith('[{"departmentId"'): value = json.loads(value) - value = '、'.join([item['departmentName'] for item in value]) - elif value.startswith("[{\"organizeId\""): + value = "、".join([item["departmentName"] for item in value]) + elif value.startswith('[{"organizeId"'): value = json.loads(value) - value = '、'.join([item['organizeName'] for item in value]) - elif value.startswith("[{\"file_id\""): - value = '' - elif value == '[]': - value = '' - elif hasattr(value, 'accountId'): - value = value['fullname'] + value = "、".join([item["organizeName"] for item in value]) + elif value.startswith('[{"file_id"') or value == "[]": + value = "" + elif hasattr(value, "accountId"): + value = value["fullname"] return value def parse_cascade_or_associated(self, field, value): - if (field['typeId'] == 35 and 
value.startswith('[')) or (field['typeId'] == 29 and value.startswith('[{')): + if (field["typeId"] == 35 and value.startswith("[")) or (field["typeId"] == 29 and value.startswith("[{")): value = json.loads(value) - value = value[0]['name'] if len(value) > 0 else '' + value = value[0]["name"] if len(value) > 0 else "" else: - value = '' + value = "" return value def parse_location(self, value): @@ -205,5 +227,5 @@ def parse_location(self, value): return value def rich_text_to_plain_text(self, rich_text): - text = re.sub(r'<[^>]+>', '', rich_text) if '<' in rich_text else rich_text - return text.replace("|", "▏").replace("\n", " ") \ No newline at end of file + text = re.sub(r"<[^>]+>", "", rich_text) if "<" in rich_text else rich_text + return text.replace("|", "▏").replace("\n", " ") diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py index 960cbd10acadb1..4e852c0028497c 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheets.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheets.py @@ -8,75 +8,76 @@ class ListWorksheetsTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Sign") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not (host.startswith("http://") or host.startswith("https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v1/open/app/get" - result_type = tool_parameters.get('result_type', '') + result_type = tool_parameters.get("result_type", "") if not result_type: - result_type = 'table' + result_type = "table" - headers = { 'Content-Type': 'application/json' } - params = { "appKey": appkey, "sign": sign, } + headers = {"Content-Type": "application/json"} + params = { + "appKey": appkey, + "sign": sign, + } try: res = httpx.get(url, headers=headers, params=params, timeout=30) res_json = res.json() if res.is_success: - if res_json['error_code'] != 1: - return self.create_text_message("Failed to access the application. {}".format(res_json['error_msg'])) + if res_json["error_code"] != 1: + return self.create_text_message( + "Failed to access the application. 
{}".format(res_json["error_msg"]) + ) else: - if result_type == 'json': + if result_type == "json": worksheets = [] - for section in res_json['data']['sections']: + for section in res_json["data"]["sections"]: worksheets.extend(self._extract_worksheets(section, result_type)) return self.create_text_message(text=json.dumps(worksheets, ensure_ascii=False)) else: - worksheets = '|worksheetId|worksheetName|description|\n|---|---|---|' - for section in res_json['data']['sections']: + worksheets = "|worksheetId|worksheetName|description|\n|---|---|---|" + for section in res_json["data"]["sections"]: worksheets += self._extract_worksheets(section, result_type) return self.create_text_message(worksheets) else: return self.create_text_message( - f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}") + f"Failed to list worksheets, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to list worksheets, something went wrong: {}".format(e)) def _extract_worksheets(self, section, type): items = [] - tables = '' - for item in section.get('items', []): - if item.get('type') == 0 and (not 'notes' in item or item.get('notes') != 'NO'): - if type == 'json': - filtered_item = { - 'id': item['id'], - 'name': item['name'], - 'notes': item.get('notes', '') - } + tables = "" + for item in section.get("items", []): + if item.get("type") == 0 and ("notes" not in item or item.get("notes") != "NO"): + if type == "json": + filtered_item = {"id": item["id"], "name": item["name"], "notes": item.get("notes", "")} items.append(filtered_item) else: tables += f"\n|{item['id']}|{item['name']}|{item.get('notes', '')}|" - for child_section in section.get('childSections', []): - if type == 'json': - items.extend(self._extract_worksheets(child_section, 'json')) + for child_section in section.get("childSections", []): + if type == "json": + items.extend(self._extract_worksheets(child_section, "json")) else: - tables += self._extract_worksheets(child_section, 'table') - - return items if type == 'json' else tables \ No newline at end of file + tables += self._extract_worksheets(child_section, "table") + + return items if type == "json" else tables diff --git a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py index 6ca1b98d90b3f1..971f3d37f6dfbf 100644 --- a/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py +++ b/api/core/tools/provider/builtin/hap/tools/update_worksheet_record.py @@ -8,44 +8,43 @@ class UpdateWorksheetRecordTool(BuiltinTool): - - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - - appkey = tool_parameters.get('appkey', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + appkey = tool_parameters.get("appkey", "") if not appkey: - return self.create_text_message('Invalid parameter App Key') - sign = tool_parameters.get('sign', '') + return self.create_text_message("Invalid parameter App Key") + sign = tool_parameters.get("sign", "") if not sign: - return self.create_text_message('Invalid parameter Sign') - worksheet_id = tool_parameters.get('worksheet_id', '') + return self.create_text_message("Invalid parameter Sign") + worksheet_id = tool_parameters.get("worksheet_id", "") if not worksheet_id: - return self.create_text_message('Invalid parameter Worksheet 
ID') - row_id = tool_parameters.get('row_id', '') + return self.create_text_message("Invalid parameter Worksheet ID") + row_id = tool_parameters.get("row_id", "") if not row_id: - return self.create_text_message('Invalid parameter Record Row ID') - record_data = tool_parameters.get('record_data', '') + return self.create_text_message("Invalid parameter Record Row ID") + record_data = tool_parameters.get("record_data", "") if not record_data: - return self.create_text_message('Invalid parameter Record Row Data') - - host = tool_parameters.get('host', '') + return self.create_text_message("Invalid parameter Record Row Data") + + host = tool_parameters.get("host", "") if not host: - host = 'https://api.mingdao.com' + host = "https://api.mingdao.com" elif not host.startswith(("http://", "https://")): - return self.create_text_message('Invalid parameter Host Address') + return self.create_text_message("Invalid parameter Host Address") else: - host = f"{host[:-1] if host.endswith('/') else host}/api" + host = f"{host.removesuffix('/')}/api" url = f"{host}/v2/open/worksheet/editRow" - headers = {'Content-Type': 'application/json'} + headers = {"Content-Type": "application/json"} payload = {"appKey": appkey, "sign": sign, "worksheetId": worksheet_id, "rowId": row_id} try: - payload['controls'] = json.loads(record_data) + payload["controls"] = json.loads(record_data) res = httpx.post(url, headers=headers, json=payload, timeout=60) res.raise_for_status() res_json = res.json() - if res_json.get('error_code') != 1: + if res_json.get("error_code") != 1: return self.create_text_message(f"Failed to update the record. {res_json['error_msg']}") return self.create_text_message("Record updated successfully.") except httpx.RequestError as e: diff --git a/api/core/tools/provider/builtin/jina/jina.py b/api/core/tools/provider/builtin/jina/jina.py index 12e5058cdc92f0..154e15db016dd1 100644 --- a/api/core/tools/provider/builtin/jina/jina.py +++ b/api/core/tools/provider/builtin/jina/jina.py @@ -10,27 +10,29 @@ class GoogleProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - if credentials['api_key'] is None: - credentials['api_key'] = '' + if credentials["api_key"] is None: + credentials["api_key"] = "" else: - result = JinaReaderTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - "url": "https://example.com", - }, - )[0] + result = ( + JinaReaderTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "url": "https://example.com", + }, + )[0] + ) message = json.loads(result.message) - if message['code'] != 200: - raise ToolProviderCredentialValidationError(message['message']) + if message["code"] != 200: + raise ToolProviderCredentialValidationError(message["message"]) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - + def _get_tool_labels(self) -> list[ToolLabelEnum]: - return [ - ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY - ] \ No newline at end of file + return [ToolLabelEnum.SEARCH, ToolLabelEnum.PRODUCTIVITY] diff --git a/api/core/tools/provider/builtin/jina/jina.yaml b/api/core/tools/provider/builtin/jina/jina.yaml index 06f23382d92a3a..af3ca23ffaff46 100644 --- a/api/core/tools/provider/builtin/jina/jina.yaml +++ b/api/core/tools/provider/builtin/jina/jina.yaml @@ -2,13 +2,13 @@ identity: author: Dify name: jina label: - en_US: Jina - zh_Hans: Jina - pt_BR: Jina + en_US: 
Jina AI + zh_Hans: Jina AI + pt_BR: Jina AI description: - en_US: Convert any URL to an LLM-friendly input or perform searches on the web for grounding information. Experience improved output for your agent and RAG systems at no cost. - zh_Hans: 将任何URL转换为LLM易读的输入或在网页上搜索引擎上搜索引擎。 - pt_BR: Converte qualquer URL em uma entrada LLm-fácil de ler ou realize pesquisas na web para obter informação de grounding. Tenha uma experiência melhor para seu agente e sistemas RAG sem custo. + en_US: Your Search Foundation, Supercharged! + zh_Hans: 您的搜索底座,从此不同! + pt_BR: Your Search Foundation, Supercharged! icon: icon.svg tags: - search @@ -22,11 +22,11 @@ credentials_for_provider: zh_Hans: API 密钥(可留空) pt_BR: Chave API (deixe vazio se você não tiver uma) placeholder: - en_US: Please enter your Jina API key - zh_Hans: 请输入你的 Jina API 密钥 - pt_BR: Por favor, insira sua chave de API do Jina + en_US: Please enter your Jina AI API key + zh_Hans: 请输入你的 Jina AI API 密钥 + pt_BR: Por favor, insira sua chave de API do Jina AI help: - en_US: Get your Jina API key from Jina (optional, but you can get a higher rate) - zh_Hans: 从 Jina 获取您的 Jina API 密钥(非必须,能得到更高的速率) - pt_BR: Obtenha sua chave de API do Jina na Jina (opcional, mas você pode obter uma taxa mais alta) + en_US: Get your Jina AI API key from Jina AI (optional, but you can get a higher rate) + zh_Hans: 从 Jina AI 获取您的 Jina AI API 密钥(非必须,能得到更高的速率) + pt_BR: Obtenha sua chave de API do Jina AI na Jina AI (opcional, mas você pode obter uma taxa mais alta) url: https://jina.ai diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.py b/api/core/tools/provider/builtin/jina/tools/jina_reader.py index cee46cee2390e1..0dd55c65291783 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.py @@ -9,26 +9,25 @@ class JinaReaderTool(BuiltinTool): - _jina_reader_endpoint = 'https://r.jina.ai/' + _jina_reader_endpoint = "https://r.jina.ai/" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - url = tool_parameters['url'] + url = tool_parameters["url"] - headers = { - 'Accept': 'application/json' - } + headers = {"Accept": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - request_params = tool_parameters.get('request_params') - if request_params is not None and request_params != '': + request_params = tool_parameters.get("request_params") + if request_params is not None and request_params != "": try: request_params = json.loads(request_params) if not isinstance(request_params, dict): @@ -36,40 +35,40 @@ def _invoke(self, except (json.JSONDecodeError, ValueError) as e: raise ValueError(f"Invalid request_params: {e}") - target_selector = tool_parameters.get('target_selector') - if target_selector is not None and target_selector != '': - headers['X-Target-Selector'] = target_selector + target_selector = tool_parameters.get("target_selector") + if target_selector is not None and target_selector != "": + headers["X-Target-Selector"] = target_selector - 
wait_for_selector = tool_parameters.get('wait_for_selector') - if wait_for_selector is not None and wait_for_selector != '': - headers['X-Wait-For-Selector'] = wait_for_selector + wait_for_selector = tool_parameters.get("wait_for_selector") + if wait_for_selector is not None and wait_for_selector != "": + headers["X-Wait-For-Selector"] = wait_for_selector - if tool_parameters.get('image_caption', False): - headers['X-With-Generated-Alt'] = 'true' + if tool_parameters.get("image_caption", False): + headers["X-With-Generated-Alt"] = "true" - if tool_parameters.get('gather_all_links_at_the_end', False): - headers['X-With-Links-Summary'] = 'true' + if tool_parameters.get("gather_all_links_at_the_end", False): + headers["X-With-Links-Summary"] = "true" - if tool_parameters.get('gather_all_images_at_the_end', False): - headers['X-With-Images-Summary'] = 'true' + if tool_parameters.get("gather_all_images_at_the_end", False): + headers["X-With-Images-Summary"] = "true" - proxy_server = tool_parameters.get('proxy_server') - if proxy_server is not None and proxy_server != '': - headers['X-Proxy-Url'] = proxy_server + proxy_server = tool_parameters.get("proxy_server") + if proxy_server is not None and proxy_server != "": + headers["X-Proxy-Url"] = proxy_server - if tool_parameters.get('no_cache', False): - headers['X-No-Cache'] = 'true' + if tool_parameters.get("no_cache", False): + headers["X-No-Cache"] = "true" - max_retries = tool_parameters.get('max_retries', 3) + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( str(URL(self._jina_reader_endpoint + url)), headers=headers, params=request_params, timeout=(10, 60), - max_retries=max_retries + max_retries=max_retries, ) - if tool_parameters.get('summary', False): + if tool_parameters.get("summary", False): return self.create_text_message(self.summary(user_id, response.text)) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml index 58ad6d8694222d..589bc3433d9478 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_reader.yaml @@ -2,14 +2,14 @@ identity: name: jina_reader author: Dify label: - en_US: JinaReader - zh_Hans: JinaReader - pt_BR: JinaReader + en_US: Fetch Single Page + zh_Hans: 获取单页面 + pt_BR: Fetch Single Page description: human: - en_US: Convert any URL to an LLM-friendly input. Experience improved output for your agent and RAG systems at no cost. - zh_Hans: 将任何 URL 转换为 LLM 友好的输入。无需付费即可体验为您的 Agent 和 RAG 系统提供的改进输出。 - pt_BR: Converta qualquer URL em uma entrada amigável ao LLM. Experimente uma saída aprimorada para seus sistemas de agente e RAG sem custo. + en_US: Fetch the target URL (can be a PDF) and convert it into LLM-friendly Markdown. + zh_Hans: 获取目标网址(可以是 PDF),并将其转换为适合大模型处理的 Markdown 格式。 + pt_BR: Busque a URL de destino (que pode ser um PDF) e converta-a em Markdown adequado para LLMs. llm: A tool for scraping webpages. Input should be a URL.
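As a quick illustration of the reader flow in the diff above, here is a minimal, self-contained sketch that calls the same `https://r.jina.ai/` endpoint with the same feature headers. Plain `requests` stands in for Dify's internal `ssrf_proxy` helper, and `fetch_page` plus the hard-coded flags are illustrative only, not part of this PR:

```python
import requests

JINA_READER_ENDPOINT = "https://r.jina.ai/"  # same endpoint as JinaReaderTool

def fetch_page(url: str, api_key: str | None = None) -> str:
    # The reader API is driven entirely by HTTP headers, mirroring the tool
    # parameters: image_caption, gather_all_links_at_the_end, no_cache, etc.
    headers = {"Accept": "application/json"}
    if api_key:  # optional; an API key only raises the rate limit
        headers["Authorization"] = f"Bearer {api_key}"
    headers["X-With-Generated-Alt"] = "true"   # caption images (image_caption)
    headers["X-With-Links-Summary"] = "true"   # append a "Buttons & Links" section
    headers["X-No-Cache"] = "true"             # bypass the cache
    # The target URL is simply appended to the reader endpoint.
    response = requests.get(JINA_READER_ENDPOINT + url, headers=headers, timeout=(10, 60))
    response.raise_for_status()
    return response.text

if __name__ == "__main__":
    print(fetch_page("https://example.com"))
```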
parameters: - name: url @@ -17,13 +17,13 @@ parameters: required: true label: en_US: URL - zh_Hans: 网页链接 + zh_Hans: 网址 pt_BR: URL human_description: - en_US: used for linking to webpages - zh_Hans: 用于链接到网页 - pt_BR: used for linking to webpages - llm_description: url for scraping + en_US: Web link + zh_Hans: 网页链接 + pt_BR: URL da web + llm_description: url para scraping form: llm - name: request_params type: string @@ -31,14 +31,14 @@ parameters: label: en_US: Request params zh_Hans: 请求参数 - pt_BR: Request params + pt_BR: Parâmetros de solicitação human_description: en_US: | request parameters, format: {"key1": "value1", "key2": "value2"} zh_Hans: | 请求参数,格式:{"key1": "value1", "key2": "value2"} pt_BR: | - request parameters, format: {"key1": "value1", "key2": "value2"} + parâmetros de solicitação, formato: {"key1": "value1", "key2": "value2"} llm_description: request parameters form: llm - name: target_selector @@ -51,7 +51,7 @@ parameters: human_description: en_US: css selector for scraping specific elements zh_Hans: css 选择器用于抓取特定元素 - pt_BR: css selector for scraping specific elements + pt_BR: css selector para scraping de elementos específicos llm_description: css selector of the target element to scrape form: form - name: wait_for_selector @@ -64,7 +64,7 @@ parameters: human_description: en_US: css selector for waiting for specific elements zh_Hans: css 选择器用于等待特定元素 - pt_BR: css selector for waiting for specific elements + pt_BR: css selector para aguardar elementos específicos llm_description: css selector of the target element to wait for form: form - name: image_caption @@ -77,8 +77,8 @@ parameters: pt_BR: Legenda da imagem human_description: en_US: "Captions all images at the specified URL, adding 'Image [idx]: [caption]' as an alt tag for those without one. This allows downstream LLMs to interact with the images in activities such as reasoning and summarizing." - zh_Hans: "为指定 URL 上的所有图像添加标题,为没有标题的图像添加“Image [idx]: [caption]”作为 alt 标签。这允许下游 LLM 在推理和总结等活动中与图像进行交互。" - pt_BR: "Captions all images at the specified URL, adding 'Image [idx]: [caption]' as an alt tag for those without one. This allows downstream LLMs to interact with the images in activities such as reasoning and summarizing." + zh_Hans: "为指定 URL 上的所有图像添加标题,为没有标题的图像添加“Image [idx]: [caption]”作为 alt 标签,以支持下游模型的图像交互。" + pt_BR: "Adiciona legendas a todas as imagens na URL especificada, adicionando 'Imagem [idx]: [legenda]' como uma tag alt para aquelas que não têm uma. Isso permite que os LLMs downstream interajam com as imagens em atividades como raciocínio e resumo." llm_description: Captions all images at the specified URL form: form - name: gather_all_links_at_the_end @@ -91,8 +91,8 @@ parameters: pt_BR: Coletar todos os links ao final human_description: en_US: A "Buttons & Links" section will be created at the end. This helps the downstream LLMs or web agents navigating the page or take further actions. - zh_Hans: 最后会创建一个“按钮和链接”部分。这可以帮助下游 LLM 或 Web 代理浏览页面或采取进一步的行动。 - pt_BR: A "Buttons & Links" section will be created at the end. This helps the downstream LLMs or web agents navigating the page or take further actions. + zh_Hans: 末尾将添加“按钮和链接”部分,方便下游模型或网络代理做页面导航或执行进一步操作。 + pt_BR: Uma seção "Botões & Links" será criada no final. Isso ajuda os LLMs downstream ou agentes web a navegar pela página ou executar outras ações.
llm_description: Gather all links at the end form: form - name: gather_all_images_at_the_end @@ -105,8 +105,8 @@ parameters: pt_BR: Coletar todas as imagens ao final human_description: en_US: An "Images" section will be created at the end. This gives the downstream LLMs an overview of all visuals on the page, which may improve reasoning. - zh_Hans: 最后会创建一个“图像”部分。这可以让下游的 LLM 概览页面上的所有视觉效果,从而提高推理能力。 - pt_BR: An "Images" section will be created at the end. This gives the downstream LLMs an overview of all visuals on the page, which may improve reasoning. + zh_Hans: 末尾会新增“图片”部分,方便下游模型全面了解页面的视觉内容,提升推理效果。 + pt_BR: Uma seção "Imagens" será criada no final. Isso dá aos LLMs downstream uma visão geral de todos os elementos visuais da página, o que pode melhorar o raciocínio. llm_description: Gather all images at the end form: form - name: proxy_server diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.py b/api/core/tools/provider/builtin/jina/tools/jina_search.py index d4a81cd0965142..30af6de7831e59 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.py +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.py @@ -8,44 +8,39 @@ class JinaSearchTool(BuiltinTool): - _jina_search_endpoint = 'https://s.jina.ai/' + _jina_search_endpoint = "https://s.jina.ai/" def _invoke( self, user_id: str, tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters['query'] + query = tool_parameters["query"] - headers = { - 'Accept': 'application/json' - } + headers = {"Accept": "application/json"} - if 'api_key' in self.runtime.credentials and self.runtime.credentials.get('api_key'): - headers['Authorization'] = "Bearer " + self.runtime.credentials.get('api_key') + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") - if tool_parameters.get('image_caption', False): - headers['X-With-Generated-Alt'] = 'true' + if tool_parameters.get("image_caption", False): + headers["X-With-Generated-Alt"] = "true" - if tool_parameters.get('gather_all_links_at_the_end', False): - headers['X-With-Links-Summary'] = 'true' + if tool_parameters.get("gather_all_links_at_the_end", False): + headers["X-With-Links-Summary"] = "true" - if tool_parameters.get('gather_all_images_at_the_end', False): - headers['X-With-Images-Summary'] = 'true' + if tool_parameters.get("gather_all_images_at_the_end", False): + headers["X-With-Images-Summary"] = "true" - proxy_server = tool_parameters.get('proxy_server') - if proxy_server is not None and proxy_server != '': - headers['X-Proxy-Url'] = proxy_server + proxy_server = tool_parameters.get("proxy_server") + if proxy_server is not None and proxy_server != "": + headers["X-Proxy-Url"] = proxy_server - if tool_parameters.get('no_cache', False): - headers['X-No-Cache'] = 'true' + if tool_parameters.get("no_cache", False): + headers["X-No-Cache"] = "true" - max_retries = tool_parameters.get('max_retries', 3) + max_retries = tool_parameters.get("max_retries", 3) response = ssrf_proxy.get( - str(URL(self._jina_search_endpoint + query)), - headers=headers, - timeout=(10, 60), - max_retries=max_retries + str(URL(self._jina_search_endpoint + query)), headers=headers, timeout=(10, 60), max_retries=max_retries ) return self.create_text_message(response.text) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml index
2bc70e1be1934d..e58c639e5690d0 100644 --- a/api/core/tools/provider/builtin/jina/tools/jina_search.yaml +++ b/api/core/tools/provider/builtin/jina/tools/jina_search.yaml @@ -2,13 +2,14 @@ identity: name: jina_search author: Dify label: - en_US: JinaSearch - zh_Hans: JinaSearch - pt_BR: JinaSearch + en_US: Search the web + zh_Hans: 联网搜索 + pt_BR: Pesquisar na web description: human: - en_US: Search on the web and get the top 5 results. Useful for grounding using information from the web. - zh_Hans: 在网络上搜索返回前 5 个结果。 + en_US: Search the public web for a given query and return the top results as LLM-friendly markdown. + zh_Hans: 针对给定的查询在互联网上进行搜索,并以适合大模型处理的 Markdown 格式返回最相关的结果。 + pt_BR: Pesquisar na web pública por uma consulta fornecida e retornar os melhores resultados em markdown amigável para LLMs. llm: A tool for searching results on the web for grounding. Input should be a simple question. parameters: - name: query @@ -16,11 +17,13 @@ parameters: required: true label: en_US: Question (Query) - zh_Hans: 信息查询 + zh_Hans: 查询 + pt_BR: Pergunta (Consulta) human_description: en_US: used to find information on the web zh_Hans: 在网络上搜索信息 - llm_description: simple question to ask on the web + pt_BR: Usado para encontrar informações na web + llm_description: Simple question to ask on the web form: llm - name: image_caption type: boolean @@ -32,7 +35,7 @@ parameters: pt_BR: Legenda da imagem human_description: en_US: "Captions all images at the specified URL, adding 'Image [idx]: [caption]' as an alt tag for those without one. This allows downstream LLMs to interact with the images in activities such as reasoning and summarizing." - zh_Hans: "为指定 URL 上的所有图像添加标题,为没有标题的图像添加“Image [idx]: [caption]”作为 alt 标签。这允许下游 LLM 在推理和总结等活动中与图像进行交互。" + zh_Hans: "为指定 URL 上的所有图像添加标题,为没有标题的图像添加“Image [idx]: [caption]”作为 alt 标签,以支持下游模型的图像交互。" pt_BR: "Captions all images at the specified URL, adding 'Image [idx]: [caption]' as an alt tag for those without one. This allows downstream LLMs to interact with the images in activities such as reasoning and summarizing." llm_description: Captions all images at the specified URL form: form @@ -46,8 +49,8 @@ parameters: pt_BR: Coletar todos os links ao final human_description: en_US: A "Buttons & Links" section will be created at the end. This helps the downstream LLMs or web agents navigating the page or take further actions. - zh_Hans: 最后会创建一个“按钮和链接”部分。这可以帮助下游 LLM 或 Web 代理浏览页面或采取进一步的行动。 - pt_BR: A "Buttons & Links" section will be created at the end. This helps the downstream LLMs or web agents navigating the page or take further actions. + zh_Hans: 末尾将添加“按钮和链接”部分,汇总页面上的所有链接。方便下游模型或网络代理做页面导航或执行进一步操作。 + pt_BR: Uma seção "Botões & Links" será criada no final. Isso ajuda os LLMs subsequentes ou agentes da web a navegar pela página ou executar ações adicionais. llm_description: Gather all links at the end form: form - name: gather_all_images_at_the_end @@ -60,8 +63,8 @@ parameters: pt_BR: Coletar todas as imagens ao final human_description: en_US: An "Images" section will be created at the end. This gives the downstream LLMs an overview of all visuals on the page, which may improve reasoning. - zh_Hans: 最后会创建一个“图像”部分。这可以让下游的 LLM 概览页面上的所有视觉效果,从而提高推理能力。 - pt_BR: An "Images" section will be created at the end. This gives the downstream LLMs an overview of all visuals on the page, which may improve reasoning. + zh_Hans: 末尾会新增“图片”部分,汇总页面上的所有图片。方便下游模型概览页面的视觉内容,提升推理效果。 + pt_BR: Uma seção "Imagens" será criada no final.
Isso dá aos LLMs subsequentes uma visão geral de todos os elementos visuais da página, o que pode melhorar o raciocínio. llm_description: Gather all images at the end form: form - name: proxy_server @@ -74,7 +77,7 @@ parameters: human_description: en_US: Use proxy to access URLs zh_Hans: 利用代理访问 URL - pt_BR: Use proxy to access URLs + pt_BR: Usar proxy para acessar URLs llm_description: Use proxy to access URLs form: form - name: no_cache @@ -83,7 +86,7 @@ parameters: default: false label: en_US: Bypass the Cache - zh_Hans: 绕过缓存 + zh_Hans: 是否绕过缓存 pt_BR: Ignorar o cache human_description: en_US: Bypass the Cache diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py new file mode 100644 index 00000000000000..06dabcc9c2a74e --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.py @@ -0,0 +1,39 @@ +from typing import Any + +from core.helper import ssrf_proxy +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class JinaTokenizerTool(BuiltinTool): + _jina_tokenizer_endpoint = "https://tokenize.jina.ai/" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> ToolInvokeMessage: + content = tool_parameters["content"] + body = {"content": content} + + headers = {"Content-Type": "application/json"} + + if "api_key" in self.runtime.credentials and self.runtime.credentials.get("api_key"): + headers["Authorization"] = "Bearer " + self.runtime.credentials.get("api_key") + + if tool_parameters.get("return_chunks", False): + body["return_chunks"] = True + + if tool_parameters.get("return_tokens", False): + body["return_tokens"] = True + + if tokenizer := tool_parameters.get("tokenizer"): + body["tokenizer"] = tokenizer + + response = ssrf_proxy.post( + self._jina_tokenizer_endpoint, + headers=headers, + json=body, + ) + + return self.create_json_message(response.json()) diff --git a/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml new file mode 100644 index 00000000000000..74885cdf9a7048 --- /dev/null +++ b/api/core/tools/provider/builtin/jina/tools/jina_tokenizer.yaml @@ -0,0 +1,78 @@ +identity: + name: jina_tokenizer + author: hjlarry + label: + en_US: Segment + zh_Hans: 切分器 + pt_BR: Segment +description: + human: + en_US: Split long text into chunks and do tokenization. + zh_Hans: 将长文本拆分成小段落,并做分词处理。 + pt_BR: Dividir o texto longo em pedaços e fazer tokenização. + llm: Free API to tokenize text and segment long text into chunks. +parameters: + - name: content + type: string + required: true + label: + en_US: Content + zh_Hans: 内容 + pt_BR: Conteúdo + llm_description: the content which needs to be tokenized or segmented + form: llm + - name: return_tokens + type: boolean + required: false + label: + en_US: Return the tokens + zh_Hans: 是否返回tokens + pt_BR: Retornar os tokens + human_description: + en_US: Return the tokens and their corresponding ids in the response. + zh_Hans: 返回tokens及其对应的ids。 + pt_BR: Retornar os tokens e seus respectivos ids na resposta. + form: form + - name: return_chunks + type: boolean + label: + en_US: Return the chunks + zh_Hans: 是否分块 + pt_BR: Retornar os chunks + human_description: + en_US: Chunking the input into semantically meaningful segments while handling a wide variety of text types and edge cases based on common structural cues.
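For reference, the new JinaTokenizerTool above amounts to a single POST against the public tokenizer endpoint. A minimal standalone sketch of the same request, using plain requests instead of the project's ssrf_proxy helper (the sample content and the commented-out API key are placeholders):

import requests

# Mirrors the body JinaTokenizerTool builds: content plus optional flags.
body = {
    "content": "Dify is an open-source LLM app development platform.",
    "return_chunks": True,       # segment the text, as described below
    "tokenizer": "cl100k_base",  # one of the tokenizers listed further down
}
headers = {"Content-Type": "application/json"}
# headers["Authorization"] = "Bearer <api_key>"  # optional, as in the tool
resp = requests.post("https://tokenize.jina.ai/", headers=headers, json=body)
print(resp.json())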
+ zh_Hans: 将输入文本分块为语义有意义的片段,同时基于常见的结构线索处理各种文本类型和特殊情况。 + pt_BR: Dividir o texto de entrada em segmentos semanticamente significativos, enquanto lida com uma ampla variedade de tipos de texto e casos de borda com base em pistas estruturais comuns. + form: form + - name: tokenizer + type: select + options: + - value: cl100k_base + label: + en_US: cl100k_base + - value: o200k_base + label: + en_US: o200k_base + - value: p50k_base + label: + en_US: p50k_base + - value: r50k_base + label: + en_US: r50k_base + - value: p50k_edit + label: + en_US: p50k_edit + - value: gpt2 + label: + en_US: gpt2 + label: + en_US: Tokenizer + human_description: + en_US: | + · cl100k_base --- gpt-4, gpt-3.5-turbo, gpt-3.5 + · o200k_base --- gpt-4o, gpt-4o-mini + · p50k_base --- text-davinci-003, text-davinci-002 + · r50k_base --- text-davinci-001, text-curie-001 + · p50k_edit --- text-davinci-edit-001, code-davinci-edit-001 + · gpt2 --- gpt-2 + form: form diff --git a/api/core/tools/provider/builtin/json_process/json_process.py b/api/core/tools/provider/builtin/json_process/json_process.py index f6eed3c6282314..10746210b5c652 100644 --- a/api/core/tools/provider/builtin/json_process/json_process.py +++ b/api/core/tools/provider/builtin/json_process/json_process.py @@ -8,10 +8,9 @@ class JsonExtractProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - JSONParseTool().invoke(user_id='', - tool_parameters={ - 'content': '{"name": "John", "age": 30, "city": "New York"}', - 'json_filter': '$.name' - }) + JSONParseTool().invoke( + user_id="", + tool_parameters={"content": '{"name": "John", "age": 30, "city": "New York"}', "json_filter": "$.name"}, + ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/json_process/tools/delete.py b/api/core/tools/provider/builtin/json_process/tools/delete.py index 1b49cfe2f300f8..fcab3d71a93cf9 100644 --- a/api/core/tools/provider/builtin/json_process/tools/delete.py +++ b/api/core/tools/provider/builtin/json_process/tools/delete.py @@ -8,34 +8,35 @@ class JSONDeleteTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the JSON delete tool """ # Get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # Get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._delete(content, query, ensure_ascii) return self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to delete JSON content: {str(e)}') + return self.create_text_message(f"Failed to delete JSON content: {str(e)}") def _delete(self, origin_json: str, query: str, ensure_ascii: bool) -> str: try: input_data = json.loads(origin_json) - expr = 
parse('$.' + query.lstrip('$.')) # Ensure query path starts with $ + expr = parse("$." + query.lstrip("$.")) # Ensure query path starts with $ matches = expr.find(input_data) diff --git a/api/core/tools/provider/builtin/json_process/tools/insert.py b/api/core/tools/provider/builtin/json_process/tools/insert.py index 48d1bdcab48885..793c74e5f9df51 100644 --- a/api/core/tools/provider/builtin/json_process/tools/insert.py +++ b/api/core/tools/provider/builtin/json_process/tools/insert.py @@ -8,46 +8,49 @@ class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get new value - new_value = tool_parameters.get('new_value', '') + new_value = tool_parameters.get("new_value", "") if not new_value: - return self.create_text_message('Invalid parameter new_value') + return self.create_text_message("Invalid parameter new_value") # get insert position - index = tool_parameters.get('index') + index = tool_parameters.get("index") # get create path - create_path = tool_parameters.get('create_path', False) + create_path = tool_parameters.get("create_path", False) # get value decode. 
# if true, it will be decoded to an dict - value_decode = tool_parameters.get('value_decode', False) + value_decode = tool_parameters.get("value_decode", False) - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._insert(content, query, new_value, ensure_ascii, value_decode, index, create_path) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to insert JSON content') + return self.create_text_message("Failed to insert JSON content") - def _insert(self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False): + def _insert( + self, origin_json, query, new_value, ensure_ascii: bool, value_decode: bool, index=None, create_path=False + ): try: input_data = json.loads(origin_json) expr = parse(query) @@ -61,13 +64,13 @@ def _insert(self, origin_json, query, new_value, ensure_ascii: bool, value_decod if not matches and create_path: # create new path - path_parts = query.strip('$').strip('.').split('.') + path_parts = query.strip("$").strip(".").split(".") current = input_data for i, part in enumerate(path_parts): - if '[' in part and ']' in part: + if "[" in part and "]" in part: # process array index - array_name, index = part.split('[') - index = int(index.rstrip(']')) + array_name, index = part.split("[") + index = int(index.rstrip("]")) if array_name not in current: current[array_name] = [] while len(current[array_name]) <= index: diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index ecd39113ae5498..37cae401533190 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -8,29 +8,30 @@ class JSONParseTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get json filter - json_filter = tool_parameters.get('json_filter', '') + json_filter = tool_parameters.get("json_filter", "") if not json_filter: - return self.create_text_message('Invalid parameter json_filter') + return self.create_text_message("Invalid parameter json_filter") - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: result = self._extract(content, json_filter, ensure_ascii) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to extract JSON content') + return self.create_text_message("Failed to extract JSON content") # Extract data from JSON content def _extract(self, content: str, json_filter: str, ensure_ascii: bool) -> str: diff --git a/api/core/tools/provider/builtin/json_process/tools/replace.py b/api/core/tools/provider/builtin/json_process/tools/replace.py index b19198aa938942..383825c2d0b259 100644 --- a/api/core/tools/provider/builtin/json_process/tools/replace.py +++ b/api/core/tools/provider/builtin/json_process/tools/replace.py @@ -8,55 
+8,60 @@ class JSONReplaceTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get query - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Invalid parameter query') + return self.create_text_message("Invalid parameter query") # get replace value - replace_value = tool_parameters.get('replace_value', '') + replace_value = tool_parameters.get("replace_value", "") if not replace_value: - return self.create_text_message('Invalid parameter replace_value') + return self.create_text_message("Invalid parameter replace_value") # get replace model - replace_model = tool_parameters.get('replace_model', '') + replace_model = tool_parameters.get("replace_model", "") if not replace_model: - return self.create_text_message('Invalid parameter replace_model') + return self.create_text_message("Invalid parameter replace_model") # get value decode. # if true, it will be decoded to an dict - value_decode = tool_parameters.get('value_decode', False) + value_decode = tool_parameters.get("value_decode", False) - ensure_ascii = tool_parameters.get('ensure_ascii', True) + ensure_ascii = tool_parameters.get("ensure_ascii", True) try: - if replace_model == 'pattern': + if replace_model == "pattern": # get replace pattern - replace_pattern = tool_parameters.get('replace_pattern', '') + replace_pattern = tool_parameters.get("replace_pattern", "") if not replace_pattern: - return self.create_text_message('Invalid parameter replace_pattern') - result = self._replace_pattern(content, query, replace_pattern, replace_value, ensure_ascii, value_decode) - elif replace_model == 'key': + return self.create_text_message("Invalid parameter replace_pattern") + result = self._replace_pattern( + content, query, replace_pattern, replace_value, ensure_ascii, value_decode + ) + elif replace_model == "key": result = self._replace_key(content, query, replace_value, ensure_ascii) - elif replace_model == 'value': + elif replace_model == "value": result = self._replace_value(content, query, replace_value, ensure_ascii, value_decode) return self.create_text_message(str(result)) except Exception: - return self.create_text_message('Failed to replace JSON content') + return self.create_text_message("Failed to replace JSON content") # Replace pattern - def _replace_pattern(self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: + def _replace_pattern( + self, content: str, query: str, replace_pattern: str, replace_value: str, ensure_ascii: bool, value_decode: bool + ) -> str: try: input_data = json.loads(content) expr = parse(query) @@ -102,7 +107,9 @@ def _replace_key(self, content: str, query: str, replace_value: str, ensure_asci return str(e) # Replace value - def _replace_value(self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: bool) -> str: + def _replace_value( + self, content: str, query: str, replace_value: str, ensure_ascii: bool, value_decode: 
bool + ) -> str: try: input_data = json.loads(content) expr = parse(query) diff --git a/api/core/tools/provider/builtin/judge0ce/judge0ce.py b/api/core/tools/provider/builtin/judge0ce/judge0ce.py index bac6576797f067..50db74dd9ebced 100644 --- a/api/core/tools/provider/builtin/judge0ce/judge0ce.py +++ b/api/core/tools/provider/builtin/judge0ce/judge0ce.py @@ -13,7 +13,7 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "source_code": "print('hello world')", "language_id": 71, @@ -21,4 +21,3 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py index 6031687c03f48b..b8d654ff639575 100644 --- a/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py +++ b/api/core/tools/provider/builtin/judge0ce/tools/executeCode.py @@ -9,11 +9,13 @@ class ExecuteCodeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools """ - api_key = self.runtime.credentials['X-RapidAPI-Key'] + api_key = self.runtime.credentials["X-RapidAPI-Key"] url = "https://judge0-ce.p.rapidapi.com/submissions" @@ -22,15 +24,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn headers = { "Content-Type": "application/json", "X-RapidAPI-Key": api_key, - "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com" + "X-RapidAPI-Host": "judge0-ce.p.rapidapi.com", } payload = { - "language_id": tool_parameters['language_id'], - "source_code": tool_parameters['source_code'], - "stdin": tool_parameters.get('stdin', ''), - "expected_output": tool_parameters.get('expected_output', ''), - "additional_files": tool_parameters.get('additional_files', ''), + "language_id": tool_parameters["language_id"], + "source_code": tool_parameters["source_code"], + "stdin": tool_parameters.get("stdin", ""), + "expected_output": tool_parameters.get("expected_output", ""), + "additional_files": tool_parameters.get("additional_files", ""), } response = post(url, data=json.dumps(payload), headers=headers, params=querystring) @@ -38,22 +40,22 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn if response.status_code != 201: raise Exception(response.text) - token = response.json()['token'] + token = response.json()["token"] url = f"https://judge0-ce.p.rapidapi.com/submissions/{token}" - headers = { - "X-RapidAPI-Key": api_key - } - + headers = {"X-RapidAPI-Key": api_key} + response = requests.get(url, headers=headers) if response.status_code == 200: result = response.json() - return self.create_text_message(text=f"stdout: {result.get('stdout', '')}\n" - f"stderr: {result.get('stderr', '')}\n" - f"compile_output: {result.get('compile_output', '')}\n" - f"message: {result.get('message', '')}\n" - f"status: {result['status']['description']}\n" - f"time: {result.get('time', '')} seconds\n" - f"memory: {result.get('memory', '')} bytes") + return self.create_text_message( + text=f"stdout: {result.get('stdout', '')}\n" + f"stderr: {result.get('stderr', '')}\n" + f"compile_output: {result.get('compile_output', '')}\n" + 
f"message: {result.get('message', '')}\n" + f"status: {result['status']['description']}\n" + f"time: {result.get('time', '')} seconds\n" + f"memory: {result.get('memory', '')} bytes" + ) else: - return self.create_text_message(text=f"Error retrieving submission details: {response.text}") \ No newline at end of file + return self.create_text_message(text=f"Error retrieving submission details: {response.text}") diff --git a/api/core/tools/provider/builtin/lark_base/_assets/icon.png b/api/core/tools/provider/builtin/lark_base/_assets/icon.png new file mode 100644 index 00000000000000..036e586772ef50 Binary files /dev/null and b/api/core/tools/provider/builtin/lark_base/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/lark_base/lark_base.py b/api/core/tools/provider/builtin/lark_base/lark_base.py new file mode 100644 index 00000000000000..de9b3683119844 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/lark_base.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.lark_api_utils import lark_auth + + +class LarkBaseProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + lark_auth(credentials) diff --git a/api/core/tools/provider/builtin/lark_base/lark_base.yaml b/api/core/tools/provider/builtin/lark_base/lark_base.yaml new file mode 100644 index 00000000000000..200b2e22cfa558 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/lark_base.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: lark_base + label: + en_US: Lark Base + zh_Hans: Lark 多维表格 + description: + en_US: | + Lark base, requires the following permissions: bitable:app. + zh_Hans: | + Lark 多维表格,需要开通以下权限: bitable:app。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your Lark app id + zh_Hans: 请输入你的 Lark app id + help: + en_US: Get your app_id and app_secret from Lark + zh_Hans: 从 Lark 获取您的 app_id 和 app_secret + url: https://open.larksuite.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的 Lark app secret diff --git a/api/core/tools/provider/builtin/lark_base/tools/add_records.py b/api/core/tools/provider/builtin/lark_base/tools/add_records.py new file mode 100644 index 00000000000000..c46898062a8cc2 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/add_records.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class AddRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + records = tool_parameters.get("records") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.add_records(app_token, table_id, table_name, records, user_id_type) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/add_records.yaml 
b/api/core/tools/provider/builtin/lark_base/tools/add_records.yaml new file mode 100644 index 00000000000000..f2a93490dc0c31 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/add_records.yaml @@ -0,0 +1,91 @@ +identity: + name: add_records + author: Doug Lea + label: + en_US: Add Records + zh_Hans: 新增多条记录 +description: + human: + en_US: Add Multiple Records to Multidimensional Table + zh_Hans: 在多维表格数据表中新增多条记录 + llm: A tool for adding multiple records to a multidimensional table. (在多维表格数据表中新增多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: records + type: string + required: true + label: + en_US: records + zh_Hans: 记录列表 + human_description: + en_US: | + List of records to be added in this request. Example value: [{"multi-line-text":"text content","single_select":"option 1","date":1674206443000}] + For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure). + zh_Hans: | + 本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + llm_description: | + 本次请求将要新增的记录列表,示例值:[{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
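The records parameter above is a JSON string keyed by the table's field names, per the documented example value. A small sketch of assembling the tool parameters in Python (the app_token and field names are illustrative placeholders):

import json

# Hypothetical AddRecordsTool parameters; records is serialized exactly
# like the example value in the yaml above.
records = json.dumps(
    [{"multi-line-text": "text content", "single_select": "option 1", "date": 1674206443000}],
    ensure_ascii=False,
)
tool_parameters = {
    "app_token": "<app_token or document URL>",  # placeholder
    "table_name": "Table1",  # either table_id or table_name must be set
    "records": records,
    "user_id_type": "open_id",
}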
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/lark_base/tools/create_base.py b/api/core/tools/provider/builtin/lark_base/tools/create_base.py new file mode 100644 index 00000000000000..a857c6ced6f94b --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/create_base.py @@ -0,0 +1,18 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class CreateBaseTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + name = tool_parameters.get("name") + folder_token = tool_parameters.get("folder_token") + + res = client.create_base(name, folder_token) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/create_base.yaml b/api/core/tools/provider/builtin/lark_base/tools/create_base.yaml new file mode 100644 index 00000000000000..e622edf3362ba4 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/create_base.yaml @@ -0,0 +1,42 @@ +identity: + name: create_base + author: Doug Lea + label: + en_US: Create Base + zh_Hans: 创建多维表格 +description: + human: + en_US: Create Multidimensional Table in Specified Directory + zh_Hans: 在指定目录下创建多维表格 + llm: A tool for creating a multidimensional table in a specified directory. (在指定目录下创建多维表格) +parameters: + - name: name + type: string + required: false + label: + en_US: name + zh_Hans: 多维表格 App 名字 + human_description: + en_US: | + Name of the multidimensional table App. Example value: "A new multidimensional table". + zh_Hans: 多维表格 App 名字,示例值:"一篇新的多维表格"。 + llm_description: 多维表格 App 名字,示例值:"一篇新的多维表格"。 + form: llm + + - name: folder_token + type: string + required: false + label: + en_US: folder_token + zh_Hans: 多维表格 App 归属文件夹 + human_description: + en_US: | + Folder where the multidimensional table App belongs. Default is empty, meaning the table will be created in the root directory of the cloud space. Example values: Lf8uf6BoAlWkUfdGtpMjUV0PpZd or https://lark-japan.jp.larksuite.com/drive/folder/Lf8uf6BoAlWkUfdGtpMjUV0PpZd. + The folder_token must be an existing folder and supports inputting folder token or folder URL. 
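All of the lark_base tools above construct a LarkRequest client from the provider's app_id and app_secret. The helper itself (lark_api_utils) is not part of this diff; as a rough sketch under that assumption, the first thing such a client has to do is exchange the credentials for a tenant access token via the standard Lark OpenAPI endpoint:

import requests

# Sketch only: the credential exchange a LarkRequest-style client performs
# before calling the Base APIs. Error handling kept minimal.
def get_tenant_access_token(app_id: str, app_secret: str) -> str:
    resp = requests.post(
        "https://open.larksuite.com/open-apis/auth/v3/tenant_access_token/internal",
        json={"app_id": app_id, "app_secret": app_secret},
    )
    data = resp.json()
    if data.get("code") != 0:
        raise RuntimeError(f"Lark auth failed: {data}")
    return data["tenant_access_token"]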
+ zh_Hans: | + 多维表格 App 归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。示例值: Lf8uf6BoAlWkUfdGtpMjUV0PpZd 或者 https://lark-japan.jp.larksuite.com/drive/folder/Lf8uf6BoAlWkUfdGtpMjUV0PpZd。 + folder_token 必须是已存在的文件夹,支持输入文件夹 token 或者文件夹 URL。 + llm_description: | + 多维表格 App 归属文件夹。默认为空,表示多维表格将被创建在云空间根目录。示例值: Lf8uf6BoAlWkUfdGtpMjUV0PpZd 或者 https://lark-japan.jp.larksuite.com/drive/folder/Lf8uf6BoAlWkUfdGtpMjUV0PpZd。 + folder_token 必须是已存在的文件夹,支持输入文件夹 token 或者文件夹 URL。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_base/tools/create_table.py b/api/core/tools/provider/builtin/lark_base/tools/create_table.py new file mode 100644 index 00000000000000..aff7e715b73a73 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/create_table.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class CreateTableTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_name = tool_parameters.get("table_name") + default_view_name = tool_parameters.get("default_view_name") + fields = tool_parameters.get("fields") + + res = client.create_table(app_token, table_name, default_view_name, fields) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/create_table.yaml b/api/core/tools/provider/builtin/lark_base/tools/create_table.yaml new file mode 100644 index 00000000000000..8b1007b9a53166 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/create_table.yaml @@ -0,0 +1,61 @@ +identity: + name: create_table + author: Doug Lea + label: + en_US: Create Table + zh_Hans: 新增数据表 +description: + human: + en_US: Add a Data Table to Multidimensional Table + zh_Hans: 在多维表格中新增一个数据表 + llm: A tool for adding a data table to a multidimensional table. (在多维表格中新增一个数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_name + type: string + required: true + label: + en_US: Table Name + zh_Hans: 数据表名称 + human_description: + en_US: | + The name of the data table, length range: 1 character to 100 characters. + zh_Hans: 数据表名称,长度范围:1 字符 ~ 100 字符。 + llm_description: 数据表名称,长度范围:1 字符 ~ 100 字符。 + form: llm + + - name: default_view_name + type: string + required: false + label: + en_US: Default View Name + zh_Hans: 默认表格视图的名称 + human_description: + en_US: The name of the default table view, defaults to "Table" if not filled. + zh_Hans: 默认表格视图的名称,不填则默认为"表格"。 + llm_description: 默认表格视图的名称,不填则默认为"表格"。 + form: llm + + - name: fields + type: string + required: true + label: + en_US: Initial Fields + zh_Hans: 初始字段 + human_description: + en_US: | + Initial fields of the data table, format: [ { "field_name": "Multi-line Text","type": 1 },{ "field_name": "Number","type": 2 },{ "field_name": "Single Select","type": 3 },{ "field_name": "Multiple Select","type": 4 },{ "field_name": "Date","type": 5 } ]. 
For field details, refer to: https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + zh_Hans: 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。字段详情参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + llm_description: 数据表的初始字段,格式为:[{"field_name":"多行文本","type":1},{"field_name":"数字","type":2},{"field_name":"单选","type":3},{"field_name":"多选","type":4},{"field_name":"日期","type":5}]。字段详情参考:https://open.larkoffice.com/document/server-docs/docs/bitable-v1/app-table-field/guide + form: llm diff --git a/api/core/tools/provider/builtin/lark_base/tools/delete_records.py b/api/core/tools/provider/builtin/lark_base/tools/delete_records.py new file mode 100644 index 00000000000000..1b0a7470505e4d --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/delete_records.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class DeleteRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + record_ids = tool_parameters.get("record_ids") + + res = client.delete_records(app_token, table_id, table_name, record_ids) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/delete_records.yaml b/api/core/tools/provider/builtin/lark_base/tools/delete_records.yaml new file mode 100644 index 00000000000000..c30ebd630ce9d8 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/delete_records.yaml @@ -0,0 +1,86 @@ +identity: + name: delete_records + author: Doug Lea + label: + en_US: Delete Records + zh_Hans: 删除多条记录 +description: + human: + en_US: Delete Multiple Records from Multidimensional Table + zh_Hans: 删除多维表格数据表中的多条记录 + llm: A tool for deleting multiple records from a multidimensional table. (删除多维表格数据表中的多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. 
+ zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: record_ids + type: string + required: true + label: + en_US: Record IDs + zh_Hans: 记录 ID 列表 + human_description: + en_US: | + List of IDs for the records to be deleted, example value: ["recwNXzPQv"]. + zh_Hans: 删除的多条记录 ID 列表,示例值:["recwNXzPQv"]。 + llm_description: 删除的多条记录 ID 列表,示例值:["recwNXzPQv"]。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/lark_base/tools/delete_tables.py b/api/core/tools/provider/builtin/lark_base/tools/delete_tables.py new file mode 100644 index 00000000000000..e0ecae2f175050 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/delete_tables.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class DeleteTablesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_ids = tool_parameters.get("table_ids") + table_names = tool_parameters.get("table_names") + + res = client.delete_tables(app_token, table_ids, table_names) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/delete_tables.yaml b/api/core/tools/provider/builtin/lark_base/tools/delete_tables.yaml new file mode 100644 index 00000000000000..498126eae53302 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/delete_tables.yaml @@ -0,0 +1,49 @@ +identity: + name: delete_tables + author: Doug Lea + label: + en_US: Delete Tables + zh_Hans: 删除数据表 +description: + human: + en_US: Batch Delete Data Tables from Multidimensional Table + zh_Hans: 批量删除多维表格中的数据表 + llm: A tool for batch deleting data tables from a multidimensional table. (批量删除多维表格中的数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_ids + type: string + required: false + label: + en_US: Table IDs + zh_Hans: 数据表 ID + human_description: + en_US: | + IDs of the tables to be deleted. Each operation supports deleting up to 50 tables. Example: ["tbl1TkhyTWDkSoZ3"]. Ensure that either table_ids or table_names is not empty. 
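Since each delete_tables call is capped at 50 tables, a caller holding a longer list has to batch client-side. A sketch against the DeleteTablesTool signature shown above (client stands for the LarkRequest helper and is assumed here to accept a Python list of ids):

# Sketch: batch deletes in chunks of 50, the documented per-call maximum.
def delete_all_tables(client, app_token: str, table_ids: list[str]) -> None:
    for i in range(0, len(table_ids), 50):
        client.delete_tables(app_token, table_ids[i : i + 50], None)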
+ zh_Hans: 待删除的数据表的 ID,每次操作最多支持删除 50 个数据表。示例值:["tbl1TkhyTWDkSoZ3"]。请确保 table_ids 和 table_names 至少有一个不为空。 + llm_description: 待删除的数据表的 ID,每次操作最多支持删除 50 个数据表。示例值:["tbl1TkhyTWDkSoZ3"]。请确保 table_ids 和 table_names 至少有一个不为空。 + form: llm + + - name: table_names + type: string + required: false + label: + en_US: Table Names + zh_Hans: 数据表名称 + human_description: + en_US: | + Names of the tables to be deleted. Each operation supports deleting up to 50 tables. Example: ["Table1", "Table2"]. Ensure that either table_names or table_ids is not empty. + zh_Hans: 待删除的数据表的名称,每次操作最多支持删除 50 个数据表。示例值:["数据表1", "数据表2"]。请确保 table_names 和 table_ids 至少有一个不为空。 + llm_description: 待删除的数据表的名称,每次操作最多支持删除 50 个数据表。示例值:["数据表1", "数据表2"]。请确保 table_names 和 table_ids 至少有一个不为空。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_base/tools/get_base_info.py b/api/core/tools/provider/builtin/lark_base/tools/get_base_info.py new file mode 100644 index 00000000000000..2c23248b88765a --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/get_base_info.py @@ -0,0 +1,17 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class GetBaseInfoTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + + res = client.get_base_info(app_token) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/get_base_info.yaml b/api/core/tools/provider/builtin/lark_base/tools/get_base_info.yaml new file mode 100644 index 00000000000000..eb0e7a26c06a55 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/get_base_info.yaml @@ -0,0 +1,23 @@ +identity: + name: get_base_info + author: Doug Lea + label: + en_US: Get Base Info + zh_Hans: 获取多维表格元数据 +description: + human: + en_US: Get Metadata Information of Specified Multidimensional Table + zh_Hans: 获取指定多维表格的元数据信息 + llm: A tool for getting metadata information of a specified multidimensional table. (获取指定多维表格的元数据信息) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. 
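Several app_token parameters note that a document URL may be passed instead of the bare token. A hedged sketch of the normalization a client might apply (the URL shape follows the view_id example further down; the regex is illustrative, not the repository's actual parsing logic):

import re

# Illustrative: extract app_token (plus optional table/view ids) from a
# Base URL such as
# https://.../base/XXX0bfYEraW5OWsbhcFjEqj6pxh?table=tbl5I6jqwz8wBRMv&view=vewW5zXVEU
def parse_base_url(value: str) -> dict[str, str]:
    m = re.search(r"/base/([A-Za-z0-9]+)", value)
    if not m:
        return {"app_token": value}  # already a bare token
    result = {"app_token": m.group(1)}
    for key in ("table", "view"):
        q = re.search(rf"[?&]{key}=([A-Za-z0-9]+)", value)
        if q:
            result[key] = q.group(1)
    return result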
+ zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_base/tools/list_tables.py b/api/core/tools/provider/builtin/lark_base/tools/list_tables.py new file mode 100644 index 00000000000000..55b706854b2735 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/list_tables.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class ListTablesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + page_token = tool_parameters.get("page_token") + page_size = tool_parameters.get("page_size", 20) + + res = client.list_tables(app_token, page_token, page_size) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/list_tables.yaml b/api/core/tools/provider/builtin/lark_base/tools/list_tables.yaml new file mode 100644 index 00000000000000..7571519039bd24 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/list_tables.yaml @@ -0,0 +1,50 @@ +identity: + name: list_tables + author: Doug Lea + label: + en_US: List Tables + zh_Hans: 列出数据表 +description: + human: + en_US: Get All Data Tables under Multidimensional Table + zh_Hans: 获取多维表格下的所有数据表 + llm: A tool for getting all data tables under a multidimensional table. (获取多维表格下的所有数据表) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: | + Page size, default value: 20, maximum value: 100. + zh_Hans: 分页大小,默认值:20,最大值:100。 + llm_description: 分页大小,默认值:20,最大值:100。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: | + Page token, leave empty for the first request to start from the beginning; a new page_token will be returned if there are more items in the paginated query results, which can be used for the next traversal. Example value: "tblsRc9GRRXKqhvW". 
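The page_token contract above is the usual cursor pattern: omit it on the first call, then feed back the returned token while more pages remain. A sketch over ListTablesTool (client again stands for the LarkRequest helper; the items, has_more, and page_token response fields are assumed from the description above):

# Sketch: walk every table by cursoring with page_token.
def iter_tables(client, app_token: str, page_size: int = 20):
    page_token = None
    while True:
        res = client.list_tables(app_token, page_token, page_size)
        yield from res.get("items", [])
        page_token = res.get("page_token")
        if not res.get("has_more") or not page_token:
            break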
+ zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_base/tools/read_records.py b/api/core/tools/provider/builtin/lark_base/tools/read_records.py new file mode 100644 index 00000000000000..5cf25aad848dfa --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/read_records.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class ReadRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + record_ids = tool_parameters.get("record_ids") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_records(app_token, table_id, table_name, record_ids, user_id_type) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/read_records.yaml b/api/core/tools/provider/builtin/lark_base/tools/read_records.yaml new file mode 100644 index 00000000000000..911e667cfc90ad --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/read_records.yaml @@ -0,0 +1,86 @@ +identity: + name: read_records + author: Doug Lea + label: + en_US: Read Records + zh_Hans: 批量获取记录 +description: + human: + en_US: Batch Retrieve Records from Multidimensional Table + zh_Hans: 批量获取多维表格数据表中的记录信息 + llm: A tool for batch retrieving records from a multidimensional table, supporting up to 100 records per call. (批量获取多维表格数据表中的记录信息,单次调用最多支持查询 100 条记录) + +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: record_ids + type: string + required: true + label: + en_US: record_ids + zh_Hans: 记录 ID 列表 + human_description: + en_US: List of record IDs, which can be obtained by calling the "Query Records API". 
+ zh_Hans: 记录 ID 列表,可以通过调用"查询记录接口"获取。 + llm_description: 记录 ID 列表,可以通过调用"查询记录接口"获取。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/lark_base/tools/search_records.py b/api/core/tools/provider/builtin/lark_base/tools/search_records.py new file mode 100644 index 00000000000000..9b0abcf067951a --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/search_records.py @@ -0,0 +1,39 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class SearchRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + view_id = tool_parameters.get("view_id") + field_names = tool_parameters.get("field_names") + sort = tool_parameters.get("sort") + filters = tool_parameters.get("filter") + page_token = tool_parameters.get("page_token") + automatic_fields = tool_parameters.get("automatic_fields", False) + user_id_type = tool_parameters.get("user_id_type", "open_id") + page_size = tool_parameters.get("page_size", 20) + + res = client.search_record( + app_token, + table_id, + table_name, + view_id, + field_names, + sort, + filters, + page_token, + automatic_fields, + user_id_type, + page_size, + ) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/search_records.yaml b/api/core/tools/provider/builtin/lark_base/tools/search_records.yaml new file mode 100644 index 00000000000000..edd86ab9d69686 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/search_records.yaml @@ -0,0 +1,163 @@ +identity: + name: search_records + author: Doug Lea + label: + en_US: Search Records + zh_Hans: 查询记录 +description: + human: + en_US: Query records in a multidimensional table, up to 500 rows per query. + zh_Hans: 查询多维表格数据表中的记录,单次最多查询 500 行记录。 + llm: A tool for querying records in a multidimensional table, up to 500 rows per query. (查询多维表格数据表中的记录,单次最多查询 500 行记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. 
+ zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: view_id + type: string + required: false + label: + en_US: view_id + zh_Hans: 视图唯一标识 + human_description: + en_US: | + Unique identifier for a view in a multidimensional table. It can be found in the URL's query parameter with the key 'view'. For example: https://lark-japan.jp.larksuite.com/base/XXX0bfYEraW5OWsbhcFjEqj6pxh?table=tbl5I6jqwz8wBRMv&view=vewW5zXVEU. + zh_Hans: 多维表格中视图的唯一标识,可在多维表格的 URL 地址栏中找到,query 参数中 key 为 view 的部分。例如:https://lark-japan.jp.larksuite.com/base/XXX0bfYEraW5OWsbhcFjEqj6pxh?table=tbl5I6jqwz8wBRMv&view=vewW5zXVEU。 + llm_description: 多维表格中视图的唯一标识,可在多维表格的 URL 地址栏中找到,query 参数中 key 为 view 的部分。例如:https://lark-japan.jp.larksuite.com/base/XXX0bfYEraW5OWsbhcFjEqj6pxh?table=tbl5I6jqwz8wBRMv&view=vewW5zXVEU。 + form: llm + + - name: field_names + type: string + required: false + label: + en_US: field_names + zh_Hans: 字段名称 + human_description: + en_US: | + Field names to specify which fields to include in the returned records. Example value: ["Field1", "Field2"]. + zh_Hans: 字段名称,用于指定本次查询返回记录中包含的字段。示例值:["字段1","字段2"]。 + llm_description: 字段名称,用于指定本次查询返回记录中包含的字段。示例值:["字段1","字段2"]。 + form: llm + + - name: sort + type: string + required: false + label: + en_US: sort + zh_Hans: 排序条件 + human_description: + en_US: | + Sorting conditions, for example: [{"field_name":"Multiline Text","desc":true}]. + zh_Hans: 排序条件,例如:[{"field_name":"多行文本","desc":true}]。 + llm_description: 排序条件,例如:[{"field_name":"多行文本","desc":true}]。 + form: llm + + - name: filter + type: string + required: false + label: + en_US: filter + zh_Hans: 筛选条件 + human_description: + en_US: Object containing filter information. For details on how to fill in the filter, refer to the record filter parameter guide (https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide). + zh_Hans: 包含条件筛选信息的对象。了解如何填写 filter,参考记录筛选参数填写指南(https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide)。 + llm_description: 包含条件筛选信息的对象。了解如何填写 filter,参考记录筛选参数填写指南(https://open.larkoffice.com/document/uAjLw4CM/ukTMukTMukTM/reference/bitable-v1/app-table-record/record-filter-guide)。 + form: llm + + - name: automatic_fields + type: boolean + required: false + label: + en_US: automatic_fields + zh_Hans: automatic_fields + human_description: + en_US: Whether to return automatically calculated fields. Default is false, meaning they are not returned. 
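The sort and filter parameters above are likewise JSON strings. A sketch of building them (sort follows the example value given above; the filter structure follows the linked record-filter guide, with placeholder field names and values):

import json

# Illustrative sort/filter payloads for SearchRecordsTool.
sort = json.dumps([{"field_name": "多行文本", "desc": True}], ensure_ascii=False)
filter_ = json.dumps(
    {
        "conjunction": "and",
        "conditions": [{"field_name": "单选", "operator": "is", "value": ["选项 1"]}],
    },
    ensure_ascii=False,
)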
+ zh_Hans: 是否返回自动计算的字段。默认为 false,表示不返回。 + llm_description: 是否返回自动计算的字段。默认为 false,表示不返回。 + form: form + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: | + Page size, default value: 20, maximum value: 500. + zh_Hans: 分页大小,默认值:20,最大值:500。 + llm_description: 分页大小,默认值:20,最大值:500。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: | + Page token, leave empty for the first request to start from the beginning; a new page_token will be returned if there are more items in the paginated query results, which can be used for the next traversal. Example value: "tblsRc9GRRXKqhvW". + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。示例值:"tblsRc9GRRXKqhvW"。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_base/tools/update_records.py b/api/core/tools/provider/builtin/lark_base/tools/update_records.py new file mode 100644 index 00000000000000..7c263df2bb031c --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/update_records.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class UpdateRecordsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + app_token = tool_parameters.get("app_token") + table_id = tool_parameters.get("table_id") + table_name = tool_parameters.get("table_name") + records = tool_parameters.get("records") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.update_records(app_token, table_id, table_name, records, user_id_type) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_base/tools/update_records.yaml b/api/core/tools/provider/builtin/lark_base/tools/update_records.yaml new file mode 100644 index 00000000000000..68117e71367892 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_base/tools/update_records.yaml @@ -0,0 +1,91 @@ +identity: + name: update_records + author: Doug Lea + label: + en_US: Update Records + zh_Hans: 更新多条记录 +description: + human: + en_US: Update Multiple Records in Multidimensional Table + zh_Hans: 更新多维表格数据表中的多条记录 + llm: A tool for updating multiple records in a multidimensional table. 
(更新多维表格数据表中的多条记录) +parameters: + - name: app_token + type: string + required: true + label: + en_US: app_token + zh_Hans: app_token + human_description: + en_US: Unique identifier for the multidimensional table, supports inputting document URL. + zh_Hans: 多维表格的唯一标识符,支持输入文档 URL。 + llm_description: 多维表格的唯一标识符,支持输入文档 URL。 + form: llm + + - name: table_id + type: string + required: false + label: + en_US: table_id + zh_Hans: table_id + human_description: + en_US: Unique identifier for the multidimensional table data, either table_id or table_name must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的唯一标识符,table_id 和 table_name 至少需要提供一个,不能同时为空。 + form: llm + + - name: table_name + type: string + required: false + label: + en_US: table_name + zh_Hans: table_name + human_description: + en_US: Name of the multidimensional table data, either table_name or table_id must be provided, cannot be empty simultaneously. + zh_Hans: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + llm_description: 多维表格数据表的名称,table_name 和 table_id 至少需要提供一个,不能同时为空。 + form: llm + + - name: records + type: string + required: true + label: + en_US: records + zh_Hans: 记录列表 + human_description: + en_US: | + List of records to be updated in this request. Example value: [{"fields":{"multi-line-text":"text content","single_select":"option 1","date":1674206443000},"record_id":"recupK4f4RM5RX"}]. + For supported field types, refer to the integration guide (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification). For data structures of different field types, refer to the data structure overview (https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure). + zh_Hans: | + 本次请求将要更新的记录列表,示例值:[{"fields":{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000},"record_id":"recupK4f4RM5RX"}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + llm_description: | + 本次请求将要更新的记录列表,示例值:[{"fields":{"多行文本":"文本内容","单选":"选项 1","日期":1674206443000},"record_id":"recupK4f4RM5RX"}]。 + 当前接口支持的字段类型请参考接入指南(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/notification),不同类型字段的数据结构请参考数据结构概述(https://open.larkoffice.com/document/server-docs/docs/bitable-v1/bitable-structure)。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
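The `records` parameter above follows the same convention: a JSON-encoded string shaped like the example value in its description. A hypothetical sketch (the record_id and field names are placeholders echoing the docs example):

```python
import json

# Sketch only: records for UpdateRecordsTool is a JSON string of objects, each
# carrying a record_id plus a fields mapping; dates are millisecond timestamps.
records = [
    {
        "record_id": "recupK4f4RM5RX",  # placeholder id from the docs example
        "fields": {"多行文本": "updated text", "日期": 1674206443000},
    }
]
tool_parameters = {
    "app_token": "PLACEHOLDER_APP_TOKEN",
    "table_name": "tasks",  # either table_id or table_name must be provided
    "records": json.dumps(records, ensure_ascii=False),
}
```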
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/lark_calendar/_assets/icon.png b/api/core/tools/provider/builtin/lark_calendar/_assets/icon.png new file mode 100644 index 00000000000000..2a934747a98c66 Binary files /dev/null and b/api/core/tools/provider/builtin/lark_calendar/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/lark_calendar/lark_calendar.py b/api/core/tools/provider/builtin/lark_calendar/lark_calendar.py new file mode 100644 index 00000000000000..871de69cc15b39 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/lark_calendar.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.lark_api_utils import lark_auth + + +class LarkCalendarProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + lark_auth(credentials) diff --git a/api/core/tools/provider/builtin/lark_calendar/lark_calendar.yaml b/api/core/tools/provider/builtin/lark_calendar/lark_calendar.yaml new file mode 100644 index 00000000000000..72c41e36c0ebd3 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/lark_calendar.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: lark_calendar + label: + en_US: Lark Calendar + zh_Hans: Lark 日历 + description: + en_US: | + Lark calendar, requires the following permissions: calendar:calendar:read、calendar:calendar、contact:user.id:readonly. + zh_Hans: | + Lark 日历,需要开通以下权限: calendar:calendar:read、calendar:calendar、contact:user.id:readonly。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your Lark app id + zh_Hans: 请输入你的 Lark app id + help: + en_US: Get your app_id and app_secret from Lark + zh_Hans: 从 Lark 获取您的 app_id 和 app_secret + url: https://open.larksuite.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的 Lark app secret diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/add_event_attendees.py b/api/core/tools/provider/builtin/lark_calendar/tools/add_event_attendees.py new file mode 100644 index 00000000000000..f5929893ddfe24 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/add_event_attendees.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class AddEventAttendeesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + event_id = tool_parameters.get("event_id") + attendee_phone_or_email = tool_parameters.get("attendee_phone_or_email") + need_notification = tool_parameters.get("need_notification", True) + + res = client.add_event_attendees(event_id, attendee_phone_or_email, need_notification) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/add_event_attendees.yaml b/api/core/tools/provider/builtin/lark_calendar/tools/add_event_attendees.yaml new file 
mode 100644 index 00000000000000..9d7a1319072d6f --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/add_event_attendees.yaml @@ -0,0 +1,54 @@ +identity: + name: add_event_attendees + author: Doug Lea + label: + en_US: Add Event Attendees + zh_Hans: 添加日程参会人 +description: + human: + en_US: Add Event Attendees + zh_Hans: 添加日程参会人 + llm: A tool for adding attendees to events in Lark. (在 Lark 中添加日程参会人) +parameters: + - name: event_id + type: string + required: true + label: + en_US: Event ID + zh_Hans: 日程 ID + human_description: + en_US: | + The ID of the event, which will be returned when the event is created. For example: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0. + zh_Hans: | + 创建日程时会返回日程 ID。例如: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0。 + llm_description: | + 日程 ID,创建日程时会返回日程 ID。例如: fb2a6406-26d6-4c8d-a487-6f0246c94d2f_0。 + form: llm + + - name: need_notification + type: boolean + required: false + default: true + label: + en_US: Need Notification + zh_Hans: 是否需要通知 + human_description: + en_US: | + Whether to send a Bot notification to attendees. true: send, false: do not send. + zh_Hans: | + 是否给参与人发送 Bot 通知,true: 发送,false: 不发送。 + llm_description: | + 是否给参与人发送 Bot 通知,true: 发送,false: 不发送。 + form: form + + - name: attendee_phone_or_email + type: string + required: true + label: + en_US: Attendee Phone or Email + zh_Hans: 参会人电话或邮箱 + human_description: + en_US: The list of attendee emails or phone numbers, separated by commas. + zh_Hans: 日程参会人邮箱或者手机号列表,使用逗号分隔。 + llm_description: 日程参会人邮箱或者手机号列表,使用逗号分隔。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/create_event.py b/api/core/tools/provider/builtin/lark_calendar/tools/create_event.py new file mode 100644 index 00000000000000..8a0726008c3f8b --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/create_event.py @@ -0,0 +1,26 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class CreateEventTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + summary = tool_parameters.get("summary") + description = tool_parameters.get("description") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + attendee_ability = tool_parameters.get("attendee_ability") + need_notification = tool_parameters.get("need_notification", True) + auto_record = tool_parameters.get("auto_record", False) + + res = client.create_event( + summary, description, start_time, end_time, attendee_ability, need_notification, auto_record + ) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/create_event.yaml b/api/core/tools/provider/builtin/lark_calendar/tools/create_event.yaml new file mode 100644 index 00000000000000..b738736e630fa5 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/create_event.yaml @@ -0,0 +1,119 @@ +identity: + name: create_event + author: Doug Lea + label: + en_US: Create Event + zh_Hans: 创建日程 +description: + human: + en_US: Create Event + zh_Hans: 创建日程 + llm: A tool for creating events in Lark.(创建 Lark 日程) +parameters: + - name: summary + type: string + required: false + label: + en_US: Summary + zh_Hans: 日程标题 + 
human_description: + en_US: The title of the event. If not filled, the event title will display (No Subject). + zh_Hans: 日程标题,若不填则日程标题显示 (无主题)。 + llm_description: 日程标题,若不填则日程标题显示 (无主题)。 + form: llm + + - name: description + type: string + required: false + label: + en_US: Description + zh_Hans: 日程描述 + human_description: + en_US: The description of the event. + zh_Hans: 日程描述。 + llm_description: 日程描述。 + form: llm + + - name: need_notification + type: boolean + required: false + default: true + label: + en_US: Need Notification + zh_Hans: 是否发送通知 + human_description: + en_US: | + Whether to send a bot message when the event is created, true: send, false: do not send. + zh_Hans: 创建日程时是否发送 bot 消息,true:发送,false:不发送。 + llm_description: 创建日程时是否发送 bot 消息,true:发送,false:不发送。 + form: form + + - name: start_time + type: string + required: true + label: + en_US: Start Time + zh_Hans: 开始时间 + human_description: + en_US: | + The start time of the event, format: 2006-01-02 15:04:05. + zh_Hans: 日程开始时间,格式:2006-01-02 15:04:05。 + llm_description: 日程开始时间,格式:2006-01-02 15:04:05。 + form: llm + + - name: end_time + type: string + required: true + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: | + The end time of the event, format: 2006-01-02 15:04:05. + zh_Hans: 日程结束时间,格式:2006-01-02 15:04:05。 + llm_description: 日程结束时间,格式:2006-01-02 15:04:05。 + form: llm + + - name: attendee_ability + type: select + required: false + options: + - value: none + label: + en_US: none + zh_Hans: 无 + - value: can_see_others + label: + en_US: can_see_others + zh_Hans: 可以查看参与人列表 + - value: can_invite_others + label: + en_US: can_invite_others + zh_Hans: 可以邀请其它参与人 + - value: can_modify_event + label: + en_US: can_modify_event + zh_Hans: 可以编辑日程 + default: "none" + label: + en_US: attendee_ability + zh_Hans: 参会人权限 + human_description: + en_US: Attendee ability, optional values are none, can_see_others, can_invite_others, can_modify_event, with a default value of none. + zh_Hans: 参会人权限,可选值有无、可以查看参与人列表、可以邀请其它参与人、可以编辑日程,默认值为无。 + llm_description: 参会人权限,可选值有无、可以查看参与人列表、可以邀请其它参与人、可以编辑日程,默认值为无。 + form: form + + - name: auto_record + type: boolean + required: false + default: false + label: + en_US: Auto Record + zh_Hans: 自动录制 + human_description: + en_US: | + Whether to enable automatic recording, true: enabled, automatically record when the meeting starts; false: not enabled. 
+ zh_Hans: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。 + llm_description: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。 + form: form diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/delete_event.py b/api/core/tools/provider/builtin/lark_calendar/tools/delete_event.py new file mode 100644 index 00000000000000..0e4ceac5e5d070 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/delete_event.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class DeleteEventTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + event_id = tool_parameters.get("event_id") + need_notification = tool_parameters.get("need_notification", True) + + res = client.delete_event(event_id, need_notification) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/delete_event.yaml b/api/core/tools/provider/builtin/lark_calendar/tools/delete_event.yaml new file mode 100644 index 00000000000000..cdd6d7e1bb024a --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/delete_event.yaml @@ -0,0 +1,38 @@ +identity: + name: delete_event + author: Doug Lea + label: + en_US: Delete Event + zh_Hans: 删除日程 +description: + human: + en_US: Delete Event + zh_Hans: 删除日程 + llm: A tool for deleting events in Lark.(在 Lark 中删除日程) +parameters: + - name: event_id + type: string + required: true + label: + en_US: Event ID + zh_Hans: 日程 ID + human_description: + en_US: | + The ID of the event, for example: e8b9791c-39ae-4908-8ad8-66b13159b9fb_0. + zh_Hans: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。 + llm_description: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。 + form: llm + + - name: need_notification + type: boolean + required: false + default: true + label: + en_US: Need Notification + zh_Hans: 是否需要通知 + human_description: + en_US: | + Indicates whether to send bot notifications to event participants upon deletion. true: send, false: do not send. 
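One detail worth flagging for the calendar tools: `start_time` and `end_time` are plain strings in the format `2006-01-02 15:04:05`, i.e. `%Y-%m-%d %H:%M:%S`. A small sketch for producing a valid create_event input (the title and times are placeholders):

```python
from datetime import datetime, timedelta

# Sketch only: the calendar YAMLs expect "%Y-%m-%d %H:%M:%S" strings.
start = datetime(2024, 7, 1, 14, 0, 0)
end = start + timedelta(hours=1)
tool_parameters = {
    "summary": "Design review",  # omitted titles render as (No Subject)
    "start_time": start.strftime("%Y-%m-%d %H:%M:%S"),
    "end_time": end.strftime("%Y-%m-%d %H:%M:%S"),
    "need_notification": True,  # matches the default in CreateEventTool._invoke
}
```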
+ zh_Hans: 删除日程是否给日程参与人发送 bot 通知,true:发送,false:不发送。 + llm_description: 删除日程是否给日程参与人发送 bot 通知,true:发送,false:不发送。 + form: form diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/get_primary_calendar.py b/api/core/tools/provider/builtin/lark_calendar/tools/get_primary_calendar.py new file mode 100644 index 00000000000000..d315bf35f05d98 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/get_primary_calendar.py @@ -0,0 +1,18 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class GetPrimaryCalendarTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.get_primary_calendar(user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/get_primary_calendar.yaml b/api/core/tools/provider/builtin/lark_calendar/tools/get_primary_calendar.yaml new file mode 100644 index 00000000000000..fe615947700995 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/get_primary_calendar.yaml @@ -0,0 +1,37 @@ +identity: + name: get_primary_calendar + author: Doug Lea + label: + en_US: Get Primary Calendar + zh_Hans: 查询主日历信息 +description: + human: + en_US: Get Primary Calendar + zh_Hans: 查询主日历信息 + llm: A tool for querying primary calendar information in Lark.(在 Lark 中查询主日历信息) +parameters: + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/list_events.py b/api/core/tools/provider/builtin/lark_calendar/tools/list_events.py new file mode 100644 index 00000000000000..d74cc049d34230 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/list_events.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class ListEventsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + page_token = tool_parameters.get("page_token") + page_size = tool_parameters.get("page_size") + + res = client.list_events(start_time, end_time, page_token, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/list_events.yaml b/api/core/tools/provider/builtin/lark_calendar/tools/list_events.yaml new file mode 100644 index 00000000000000..cef332f5272e55 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/list_events.yaml @@ -0,0 +1,62 @@ +identity: + name: list_events + author: Doug Lea + label: + en_US: List Events + zh_Hans: 获取日程列表 +description: + human: + en_US: List Events + zh_Hans: 获取日程列表 + llm: A tool for listing events in Lark.(在 Lark 中获取日程列表) +parameters: + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 开始时间 + human_description: + en_US: | + The start time, defaults to 0:00 of the current day if not provided, format: 2006-01-02 15:04:05. + zh_Hans: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。 + llm_description: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: | + The end time, defaults to 23:59 of the current day if not provided, format: 2006-01-02 15:04:05. + zh_Hans: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。 + llm_description: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。 + form: llm + + - name: page_size + type: number + required: false + default: 50 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 50, and the value range is [50,1000]. + zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 50,取值范围为 [50,1000]。 + llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 50,取值范围为 [50,1000]。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal. 
+ zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/search_events.py b/api/core/tools/provider/builtin/lark_calendar/tools/search_events.py new file mode 100644 index 00000000000000..a20038e47dd430 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/search_events.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class SearchEventsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + query = tool_parameters.get("query") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + page_token = tool_parameters.get("page_token") + user_id_type = tool_parameters.get("user_id_type", "open_id") + page_size = tool_parameters.get("page_size", 20) + + res = client.search_events(query, start_time, end_time, page_token, user_id_type, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/search_events.yaml b/api/core/tools/provider/builtin/lark_calendar/tools/search_events.yaml new file mode 100644 index 00000000000000..4d4f8819c11e4d --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/search_events.yaml @@ -0,0 +1,100 @@ +identity: + name: search_events + author: Doug Lea + label: + en_US: Search Events + zh_Hans: 搜索日程 +description: + human: + en_US: Search Events + zh_Hans: 搜索日程 + llm: A tool for searching events in Lark.(在 Lark 中搜索日程) +parameters: + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 搜索关键字 + human_description: + en_US: The search keyword used for fuzzy searching event names, with a maximum input of 200 characters. + zh_Hans: 用于模糊查询日程名称的搜索关键字,最大输入 200 字符。 + llm_description: 用于模糊查询日程名称的搜索关键字,最大输入 200 字符。 + form: llm + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 开始时间 + human_description: + en_US: | + The start time, defaults to 0:00 of the current day if not provided, format: 2006-01-02 15:04:05. + zh_Hans: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。 + llm_description: 开始时间,不传值时默认当天 0 点时间,格式为:2006-01-02 15:04:05。 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: | + The end time, defaults to 23:59 of the current day if not provided, format: 2006-01-02 15:04:05. 
+ zh_Hans: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。 + llm_description: 结束时间,不传值时默认当天 23:59 分时间,格式为:2006-01-02 15:04:05。 + form: llm + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [10,100]. + zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [10,100]。 + llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [10,100]。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/update_event.py b/api/core/tools/provider/builtin/lark_calendar/tools/update_event.py new file mode 100644 index 00000000000000..a04029377f6799 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/update_event.py @@ -0,0 +1,24 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class UpdateEventTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + event_id = tool_parameters.get("event_id") + summary = tool_parameters.get("summary") + description = tool_parameters.get("description") + need_notification = tool_parameters.get("need_notification", True) + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + auto_record = tool_parameters.get("auto_record", False) + + res = client.update_event(event_id, summary, description, need_notification, start_time, end_time, auto_record) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_calendar/tools/update_event.yaml b/api/core/tools/provider/builtin/lark_calendar/tools/update_event.yaml new file mode 100644 index 00000000000000..b9992e5b03f944 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_calendar/tools/update_event.yaml @@ -0,0 +1,100 @@ +identity: + name: update_event + author: Doug Lea + label: + en_US: Update Event + zh_Hans: 更新日程 +description: + human: + en_US: Update Event + zh_Hans: 更新日程 + llm: A tool for updating events in Lark.(更新 Lark 中的日程) +parameters: + - name: event_id + type: string + required: true + label: + en_US: Event ID + zh_Hans: 日程 ID + human_description: + en_US: | + The ID of the event, for example: e8b9791c-39ae-4908-8ad8-66b13159b9fb_0. + zh_Hans: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。 + llm_description: 日程 ID,例如:e8b9791c-39ae-4908-8ad8-66b13159b9fb_0。 + form: llm + + - name: summary + type: string + required: false + label: + en_US: Summary + zh_Hans: 日程标题 + human_description: + en_US: The title of the event. 
+ zh_Hans: 日程标题。 + llm_description: 日程标题。 + form: llm + + - name: description + type: string + required: false + label: + en_US: Description + zh_Hans: 日程描述 + human_description: + en_US: The description of the event. + zh_Hans: 日程描述。 + llm_description: 日程描述。 + form: llm + + - name: need_notification + type: boolean + required: false + label: + en_US: Need Notification + zh_Hans: 是否发送通知 + human_description: + en_US: | + Whether to send a bot message when the event is updated, true: send, false: do not send. + zh_Hans: 更新日程时是否发送 bot 消息,true:发送,false:不发送。 + llm_description: 更新日程时是否发送 bot 消息,true:发送,false:不发送。 + form: form + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 开始时间 + human_description: + en_US: | + The start time of the event, format: 2006-01-02 15:04:05. + zh_Hans: 日程开始时间,格式:2006-01-02 15:04:05。 + llm_description: 日程开始时间,格式:2006-01-02 15:04:05。 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: | + The end time of the event, format: 2006-01-02 15:04:05. + zh_Hans: 日程结束时间,格式:2006-01-02 15:04:05。 + llm_description: 日程结束时间,格式:2006-01-02 15:04:05。 + form: llm + + - name: auto_record + type: boolean + required: false + label: + en_US: Auto Record + zh_Hans: 自动录制 + human_description: + en_US: | + Whether to enable automatic recording, true: enabled, automatically record when the meeting starts; false: not enabled. + zh_Hans: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。 + llm_description: 是否开启自动录制,true:开启,会议开始后自动录制;false:不开启。 + form: form diff --git a/api/core/tools/provider/builtin/lark_document/_assets/icon.svg b/api/core/tools/provider/builtin/lark_document/_assets/icon.svg new file mode 100644 index 00000000000000..5a0a6416b3db32 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/_assets/icon.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/api/core/tools/provider/builtin/lark_document/lark_document.py b/api/core/tools/provider/builtin/lark_document/lark_document.py new file mode 100644 index 00000000000000..b1283276028361 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/lark_document.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.lark_api_utils import lark_auth + + +class LarkDocumentProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + lark_auth(credentials) diff --git a/api/core/tools/provider/builtin/lark_document/lark_document.yaml b/api/core/tools/provider/builtin/lark_document/lark_document.yaml new file mode 100644 index 00000000000000..0cb4ae1d62d3f8 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/lark_document.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: lark_document + label: + en_US: Lark Cloud Document + zh_Hans: Lark 云文档 + description: + en_US: | + Lark cloud document, requires the following permissions: docx:document、drive:drive、docs:document.content:read. 
+ zh_Hans: | + Lark 云文档,需要开通以下权限: docx:document、drive:drive、docs:document.content:read。 + icon: icon.svg + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your Lark app id + zh_Hans: 请输入你的 Lark app id + help: + en_US: Get your app_id and app_secret from Lark + zh_Hans: 从 Lark 获取您的 app_id 和 app_secret + url: https://open.larksuite.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的 Lark app secret diff --git a/api/core/tools/provider/builtin/lark_document/tools/create_document.py b/api/core/tools/provider/builtin/lark_document/tools/create_document.py new file mode 100644 index 00000000000000..2b1dae0db5578c --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/tools/create_document.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class CreateDocumentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + title = tool_parameters.get("title") + content = tool_parameters.get("content") + folder_token = tool_parameters.get("folder_token") + + res = client.create_document(title, content, folder_token) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_document/tools/create_document.yaml b/api/core/tools/provider/builtin/lark_document/tools/create_document.yaml new file mode 100644 index 00000000000000..37a1e23041c6c9 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/tools/create_document.yaml @@ -0,0 +1,48 @@ +identity: + name: create_document + author: Doug Lea + label: + en_US: Create Lark document + zh_Hans: 创建 Lark 文档 +description: + human: + en_US: Create Lark document + zh_Hans: 创建 Lark 文档,支持创建空文档和带内容的文档,支持 markdown 语法创建。应用需要开启机器人能力(https://open.larksuite.com/document/faq/trouble-shooting/how-to-enable-bot-ability)。 + llm: A tool for creating Lark documents. +parameters: + - name: title + type: string + required: false + label: + en_US: Document title + zh_Hans: 文档标题 + human_description: + en_US: Document title, only supports plain text content. + zh_Hans: 文档标题,只支持纯文本内容。 + llm_description: 文档标题,只支持纯文本内容,可以为空。 + form: llm + + - name: content + type: string + required: false + label: + en_US: Document content + zh_Hans: 文档内容 + human_description: + en_US: Document content, supports markdown syntax, can be empty. + zh_Hans: 文档内容,支持 markdown 语法,可以为空。 + llm_description: 文档内容,支持 markdown 语法,可以为空。 + form: llm + + - name: folder_token + type: string + required: false + label: + en_US: folder_token + zh_Hans: 文档所在文件夹的 Token + human_description: + en_US: | + The token of the folder where the document is located. If it is not passed or is empty, it means the root directory. 
For Example: https://lark-japan.jp.larksuite.com/drive/folder/Lf8uf6BoAlWkUfdGtpMjUV0PpZd + zh_Hans: 文档所在文件夹的 Token,不传或传空表示根目录。例如:https://lark-japan.jp.larksuite.com/drive/folder/Lf8uf6BoAlWkUfdGtpMjUV0PpZd。 + llm_description: 文档所在文件夹的 Token,不传或传空表示根目录。例如:https://lark-japan.jp.larksuite.com/drive/folder/Lf8uf6BoAlWkUfdGtpMjUV0PpZd。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_document/tools/get_document_content.py b/api/core/tools/provider/builtin/lark_document/tools/get_document_content.py new file mode 100644 index 00000000000000..d15211b57e7a76 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/tools/get_document_content.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class GetDocumentRawContentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + document_id = tool_parameters.get("document_id") + mode = tool_parameters.get("mode", "markdown") + lang = tool_parameters.get("lang", "0") + + res = client.get_document_content(document_id, mode, lang) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_document/tools/get_document_content.yaml b/api/core/tools/provider/builtin/lark_document/tools/get_document_content.yaml new file mode 100644 index 00000000000000..fd6a033bfd6947 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/tools/get_document_content.yaml @@ -0,0 +1,70 @@ +identity: + name: get_document_content + author: Doug Lea + label: + en_US: Get Lark Cloud Document Content + zh_Hans: 获取 Lark 云文档的内容 +description: + human: + en_US: Get lark cloud document content + zh_Hans: 获取 Lark 云文档的内容 + llm: A tool for retrieving content from Lark cloud documents. +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: Lark 文档的唯一标识 + human_description: + en_US: Unique identifier for a Lark document. You can also input the document's URL. + zh_Hans: Lark 文档的唯一标识,支持输入文档的 URL。 + llm_description: Lark 文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: mode + type: select + required: false + options: + - value: text + label: + en_US: text + zh_Hans: text + - value: markdown + label: + en_US: markdown + zh_Hans: markdown + default: "markdown" + label: + en_US: mode + zh_Hans: 文档返回格式 + human_description: + en_US: Format of the document return, optional values are text, markdown, can be empty, default is markdown. + zh_Hans: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。 + llm_description: 文档返回格式,可选值有 text、markdown,可以为空,默认值为 markdown。 + form: form + + - name: lang + type: select + required: false + options: + - value: "0" + label: + en_US: User's default name + zh_Hans: 用户的默认名称 + - value: "1" + label: + en_US: User's English name + zh_Hans: 用户的英文名称 + default: "0" + label: + en_US: lang + zh_Hans: 指定@用户的语言 + human_description: + en_US: | + Specifies the language for MentionUser, optional values are [0, 1]. 0: User's default name, 1: User's English name, default is 0. 
+ zh_Hans: | + 指定返回的 MentionUser,即@用户的语言,可选值有 [0,1]。0: 该用户的默认名称,1: 该用户的英文名称,默认值为 0。 + llm_description: | + 指定返回的 MentionUser,即@用户的语言,可选值有 [0,1]。0: 该用户的默认名称,1: 该用户的英文名称,默认值为 0。 + form: form diff --git a/api/core/tools/provider/builtin/lark_document/tools/list_document_blocks.py b/api/core/tools/provider/builtin/lark_document/tools/list_document_blocks.py new file mode 100644 index 00000000000000..b96a87489e055e --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/tools/list_document_blocks.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class ListDocumentBlockTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + document_id = tool_parameters.get("document_id") + page_token = tool_parameters.get("page_token", "") + user_id_type = tool_parameters.get("user_id_type", "open_id") + page_size = tool_parameters.get("page_size", 500) + + res = client.list_document_blocks(document_id, page_token, user_id_type, page_size) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_document/tools/list_document_blocks.yaml b/api/core/tools/provider/builtin/lark_document/tools/list_document_blocks.yaml new file mode 100644 index 00000000000000..08b673e0ae3ddc --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/tools/list_document_blocks.yaml @@ -0,0 +1,74 @@ +identity: + name: list_document_blocks + author: Doug Lea + label: + en_US: List Lark Document Blocks + zh_Hans: 获取 Lark 文档所有块 +description: + human: + en_US: List lark document blocks + zh_Hans: 获取 Lark 文档所有块的富文本内容并分页返回 + llm: A tool to get all blocks of Lark documents +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: Lark 文档的唯一标识 + human_description: + en_US: Unique identifier for a Lark document. You can also input the document's URL. + zh_Hans: Lark 文档的唯一标识,支持输入文档的 URL。 + llm_description: Lark 文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: page_size + type: number + required: false + default: 500 + label: + en_US: page_size + zh_Hans: 分页大小 + human_description: + en_US: Paging size, the default and maximum value is 500. + zh_Hans: 分页大小, 默认值和最大值为 500。 + llm_description: 分页大小, 表示一次请求最多返回多少条数据,默认值和最大值为 500。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: page_token + zh_Hans: 分页标记 + human_description: + en_US: Pagination token used to navigate through query results, allowing retrieval of additional items in subsequent requests. 
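The page_token contract described above (blank on the first request, echoed back while more items remain) recurs across these tools. A rough drain loop, assuming the client response exposes Lark's usual `items`/`has_more`/`page_token` fields — that response shape is an assumption, not something this diff shows:

```python
# Sketch only: iterate every block of a document via the page_token contract.
def iter_document_blocks(client, document_id: str):
    page_token = ""
    while True:
        res = client.list_document_blocks(document_id, page_token, "open_id", 500)
        yield from res.get("items", [])   # assumed response shape
        if not res.get("has_more"):       # assumed flag, per Lark's pagination docs
            break
        page_token = res.get("page_token", "")
```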
+ zh_Hans: 分页标记,用于分页查询结果,以便下次遍历时获取更多项。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_document/tools/write_document.py b/api/core/tools/provider/builtin/lark_document/tools/write_document.py new file mode 100644 index 00000000000000..888e0e39fce389 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/tools/write_document.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class WriteDocumentTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + document_id = tool_parameters.get("document_id") + content = tool_parameters.get("content") + position = tool_parameters.get("position", "end") + + res = client.write_document(document_id, content, position) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_document/tools/write_document.yaml b/api/core/tools/provider/builtin/lark_document/tools/write_document.yaml new file mode 100644 index 00000000000000..9cdf034ed08230 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_document/tools/write_document.yaml @@ -0,0 +1,57 @@ +identity: + name: write_document + author: Doug Lea + label: + en_US: Write Document + zh_Hans: 在 Lark 文档中新增内容 +description: + human: + en_US: Adding new content to Lark documents + zh_Hans: 在 Lark 文档中新增内容 + llm: A tool for adding new content to Lark documents. +parameters: + - name: document_id + type: string + required: true + label: + en_US: document_id + zh_Hans: Lark 文档的唯一标识 + human_description: + en_US: Unique identifier for a Lark document. You can also input the document's URL. + zh_Hans: Lark 文档的唯一标识,支持输入文档的 URL。 + llm_description: Lark 文档的唯一标识,支持输入文档的 URL。 + form: llm + + - name: content + type: string + required: true + label: + en_US: Plain text or Markdown content + zh_Hans: 纯文本或 Markdown 内容 + human_description: + en_US: Plain text or Markdown content. Note that embedded tables in the document should not have merged cells. + zh_Hans: 纯文本或 Markdown 内容。注意文档的内嵌套表格不允许有单元格合并。 + llm_description: 纯文本或 Markdown 内容,注意文档的内嵌套表格不允许有单元格合并。 + form: llm + + - name: position + type: select + required: false + options: + - value: start + label: + en_US: document start + zh_Hans: 文档开始 + - value: end + label: + en_US: document end + zh_Hans: 文档结束 + default: "end" + label: + en_US: position + zh_Hans: 内容添加位置 + human_description: + en_US: Content insertion position, optional values are start, end. 'start' means adding content at the beginning of the document; 'end' means adding content at the end of the document. The default value is end.
+ zh_Hans: 内容添加位置,可选值有 start、end。start 表示在文档开头添加内容;end 表示在文档结尾添加内容,默认值为 end。 + llm_description: 内容添加位置,可选值有 start、end。start 表示在文档开头添加内容;end 表示在文档结尾添加内容,默认值为 end。 + form: form diff --git a/api/core/tools/provider/builtin/lark_message_and_group/_assets/icon.png b/api/core/tools/provider/builtin/lark_message_and_group/_assets/icon.png new file mode 100644 index 00000000000000..0dfd58a9d512fd Binary files /dev/null and b/api/core/tools/provider/builtin/lark_message_and_group/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/lark_message_and_group/lark_message_and_group.py b/api/core/tools/provider/builtin/lark_message_and_group/lark_message_and_group.py new file mode 100644 index 00000000000000..de6997b0bf942f --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/lark_message_and_group.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.lark_api_utils import lark_auth + + +class LarkMessageAndGroupProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + lark_auth(credentials) diff --git a/api/core/tools/provider/builtin/lark_message_and_group/lark_message_and_group.yaml b/api/core/tools/provider/builtin/lark_message_and_group/lark_message_and_group.yaml new file mode 100644 index 00000000000000..ad3fe0f3619098 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/lark_message_and_group.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: lark_message_and_group + label: + en_US: Lark Message And Group + zh_Hans: Lark 消息和群组 + description: + en_US: | + Lark message and group, requires the following permissions: im:message、im:message.group_msg. + zh_Hans: | + Lark 消息和群组,需要开通以下权限: im:message、im:message.group_msg。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your Lark app id + zh_Hans: 请输入你的 Lark app id + help: + en_US: Get your app_id and app_secret from Lark + zh_Hans: 从 Lark 获取您的 app_id 和 app_secret + url: https://open.larksuite.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的 Lark app secret diff --git a/api/core/tools/provider/builtin/lark_message_and_group/tools/get_chat_messages.py b/api/core/tools/provider/builtin/lark_message_and_group/tools/get_chat_messages.py new file mode 100644 index 00000000000000..118bac7ab7d720 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/tools/get_chat_messages.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class GetChatMessagesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + container_id = tool_parameters.get("container_id") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + page_token = tool_parameters.get("page_token") + sort_type = tool_parameters.get("sort_type", "ByCreateTimeAsc") + page_size = tool_parameters.get("page_size", 20) + + res = 
client.get_chat_messages(container_id, start_time, end_time, page_token, sort_type, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_message_and_group/tools/get_chat_messages.yaml b/api/core/tools/provider/builtin/lark_message_and_group/tools/get_chat_messages.yaml new file mode 100644 index 00000000000000..965b45a5fbaec9 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/tools/get_chat_messages.yaml @@ -0,0 +1,96 @@ +identity: + name: get_chat_messages + author: Doug Lea + label: + en_US: Get Chat Messages + zh_Hans: 获取指定单聊、群聊的消息历史 +description: + human: + en_US: Get Chat Messages + zh_Hans: 获取指定单聊、群聊的消息历史 + llm: A tool for getting chat messages from specific one-on-one chats or group chats.(获取指定单聊、群聊的消息历史) +parameters: + - name: container_id + type: string + required: true + label: + en_US: Container Id + zh_Hans: 群聊或单聊的 ID + human_description: + en_US: The ID of the group chat or single chat. Refer to the group ID description for how to obtain it. https://open.larkoffice.com/document/server-docs/group/chat/chat-id-description + zh_Hans: 群聊或单聊的 ID,获取方式参见群 ID 说明。https://open.larkoffice.com/document/server-docs/group/chat/chat-id-description + llm_description: 群聊或单聊的 ID,获取方式参见群 ID 说明。https://open.larkoffice.com/document/server-docs/group/chat/chat-id-description + form: llm + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 起始时间 + human_description: + en_US: The start time for querying historical messages, formatted as "2006-01-02 15:04:05". + zh_Hans: 待查询历史信息的起始时间,格式为 "2006-01-02 15:04:05"。 + llm_description: 待查询历史信息的起始时间,格式为 "2006-01-02 15:04:05"。 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 结束时间 + human_description: + en_US: The end time for querying historical messages, formatted as "2006-01-02 15:04:05". + zh_Hans: 待查询历史信息的结束时间,格式为 "2006-01-02 15:04:05"。 + llm_description: 待查询历史信息的结束时间,格式为 "2006-01-02 15:04:05"。 + form: llm + + - name: sort_type + type: select + required: false + options: + - value: ByCreateTimeAsc + label: + en_US: ByCreateTimeAsc + zh_Hans: ByCreateTimeAsc + - value: ByCreateTimeDesc + label: + en_US: ByCreateTimeDesc + zh_Hans: ByCreateTimeDesc + default: "ByCreateTimeAsc" + label: + en_US: Sort Type + zh_Hans: 排序方式 + human_description: + en_US: | + The message sorting method. Optional values are ByCreateTimeAsc: sorted in ascending order by message creation time; ByCreateTimeDesc: sorted in descending order by message creation time. The default value is ByCreateTimeAsc. Note: When using page_token for pagination requests, the sorting method (sort_type) is consistent with the first request and cannot be changed midway. + zh_Hans: | + 消息排序方式,可选值有 ByCreateTimeAsc:按消息创建时间升序排列;ByCreateTimeDesc:按消息创建时间降序排列。默认值为:ByCreateTimeAsc。注意:使用 page_token 分页请求时,排序方式(sort_type)均与第一次请求一致,不支持中途改换排序方式。 + llm_description: 消息排序方式,可选值有 ByCreateTimeAsc:按消息创建时间升序排列;ByCreateTimeDesc:按消息创建时间降序排列。默认值为:ByCreateTimeAsc。注意:使用 page_token 分页请求时,排序方式(sort_type)均与第一次请求一致,不支持中途改换排序方式。 + form: form + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [1,50]. 
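As the sort_type note above stresses, the sort order chosen on the first request must be kept on every page_token follow-up. A hypothetical input for pulling one day of history (the container_id is a placeholder):

```python
from datetime import datetime, timedelta

# Sketch only: a one-day window for GetChatMessagesTool; keep sort_type
# identical across all paginated requests.
now = datetime(2024, 7, 1, 18, 0, 0)
tool_parameters = {
    "container_id": "oc_placeholder",  # see the chat-id-description docs
    "start_time": (now - timedelta(days=1)).strftime("%Y-%m-%d %H:%M:%S"),
    "end_time": now.strftime("%Y-%m-%d %H:%M:%S"),
    "sort_type": "ByCreateTimeAsc",
    "page_size": 50,  # range [1,50] per the YAML above
}
```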
+ zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。 + llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_message_and_group/tools/get_thread_messages.py b/api/core/tools/provider/builtin/lark_message_and_group/tools/get_thread_messages.py new file mode 100644 index 00000000000000..3509d9bbcfe437 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/tools/get_thread_messages.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class GetThreadMessagesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + container_id = tool_parameters.get("container_id") + page_token = tool_parameters.get("page_token") + sort_type = tool_parameters.get("sort_type", "ByCreateTimeAsc") + page_size = tool_parameters.get("page_size", 20) + + res = client.get_thread_messages(container_id, page_token, sort_type, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_message_and_group/tools/get_thread_messages.yaml b/api/core/tools/provider/builtin/lark_message_and_group/tools/get_thread_messages.yaml new file mode 100644 index 00000000000000..5f7a4f0902523e --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/tools/get_thread_messages.yaml @@ -0,0 +1,72 @@ +identity: + name: get_thread_messages + author: Doug Lea + label: + en_US: Get Thread Messages + zh_Hans: 获取指定话题的消息历史 +description: + human: + en_US: Get Thread Messages + zh_Hans: 获取指定话题的消息历史 + llm: A tool for getting chat messages from specific threads.(获取指定话题的消息历史) +parameters: + - name: container_id + type: string + required: true + label: + en_US: Thread Id + zh_Hans: 话题 ID + human_description: + en_US: The ID of the thread. Refer to the thread overview on how to obtain the thread_id.
https://open.larksuite.com/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/thread-introduction + zh_Hans: 话题 ID,获取方式参见话题概述的如何获取 thread_id 章节。https://open.larksuite.com/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/thread-introduction + llm_description: 话题 ID,获取方式参见话题概述的如何获取 thread_id 章节。https://open.larksuite.com/document/uAjLw4CM/ukTMukTMukTM/reference/im-v1/message/thread-introduction + form: llm + + - name: sort_type + type: select + required: false + options: + - value: ByCreateTimeAsc + label: + en_US: ByCreateTimeAsc + zh_Hans: ByCreateTimeAsc + - value: ByCreateTimeDesc + label: + en_US: ByCreateTimeDesc + zh_Hans: ByCreateTimeDesc + default: "ByCreateTimeAsc" + label: + en_US: Sort Type + zh_Hans: 排序方式 + human_description: + en_US: | + The message sorting method. Optional values are ByCreateTimeAsc: sorted in ascending order by message creation time; ByCreateTimeDesc: sorted in descending order by message creation time. The default value is ByCreateTimeAsc. Note: When using page_token for pagination requests, the sorting method (sort_type) is consistent with the first request and cannot be changed midway. + zh_Hans: | + 消息排序方式,可选值有 ByCreateTimeAsc:按消息创建时间升序排列;ByCreateTimeDesc:按消息创建时间降序排列。默认值为:ByCreateTimeAsc。注意:使用 page_token 分页请求时,排序方式(sort_type)均与第一次请求一致,不支持中途改换排序方式。 + llm_description: 消息排序方式,可选值有 ByCreateTimeAsc:按消息创建时间升序排列;ByCreateTimeDesc:按消息创建时间降序排列。默认值为:ByCreateTimeAsc。注意:使用 page_token 分页请求时,排序方式(sort_type)均与第一次请求一致,不支持中途改换排序方式。 + form: form + + - name: page_size + type: number + required: false + default: 20 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The page size, i.e., the number of data entries returned in a single request. The default value is 20, and the value range is [1,50]. + zh_Hans: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。 + llm_description: 分页大小,即单次请求所返回的数据条目数。默认值为 20,取值范围为 [1,50]。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. Leave it blank for the first request, indicating to start traversing from the beginning; when the pagination query result has more items, a new page_token will be returned simultaneously, which can be used to obtain the query result in the next traversal. 
+ zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_message_and_group/tools/send_bot_message.py b/api/core/tools/provider/builtin/lark_message_and_group/tools/send_bot_message.py new file mode 100644 index 00000000000000..b0a8df61e85f2e --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/tools/send_bot_message.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class SendBotMessageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + receive_id_type = tool_parameters.get("receive_id_type") + receive_id = tool_parameters.get("receive_id") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") + + res = client.send_bot_message(receive_id_type, receive_id, msg_type, content) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_message_and_group/tools/send_bot_message.yaml b/api/core/tools/provider/builtin/lark_message_and_group/tools/send_bot_message.yaml new file mode 100644 index 00000000000000..b949c5e01694ce --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/tools/send_bot_message.yaml @@ -0,0 +1,125 @@ +identity: + name: send_bot_message + author: Doug Lea + label: + en_US: Send Bot Message + zh_Hans: 发送 Lark 应用消息 +description: + human: + en_US: Send bot message + zh_Hans: 发送 Lark 应用消息 + llm: A tool for sending Lark application messages. +parameters: + - name: receive_id + type: string + required: true + label: + en_US: receive_id + zh_Hans: 消息接收者的 ID + human_description: + en_US: The ID of the message receiver, the ID type is consistent with the value of the query parameter receive_id_type. + zh_Hans: 消息接收者的 ID,ID 类型与查询参数 receive_id_type 的取值一致。 + llm_description: 消息接收者的 ID,ID 类型与查询参数 receive_id_type 的取值一致。 + form: llm + + - name: receive_id_type + type: select + required: true + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + - value: email + label: + en_US: email + zh_Hans: email + - value: chat_id + label: + en_US: chat_id + zh_Hans: chat_id + label: + en_US: receive_id_type + zh_Hans: 消息接收者的 ID 类型 + human_description: + en_US: The ID type of the message receiver, optional values are open_id, union_id, user_id, email, chat_id, with a default value of open_id. 
+ zh_Hans: 消息接收者的 ID 类型,可选值有 open_id、union_id、user_id、email、chat_id,默认值为 open_id。 + llm_description: 消息接收者的 ID 类型,可选值有 open_id、union_id、user_id、email、chat_id,默认值为 open_id。 + form: form + + - name: msg_type + type: select + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: interactive + label: + en_US: interactive + zh_Hans: 卡片 + - value: post + label: + en_US: post + zh_Hans: 富文本 + - value: image + label: + en_US: image + zh_Hans: 图片 + - value: file + label: + en_US: file + zh_Hans: 文件 + - value: audio + label: + en_US: audio + zh_Hans: 语音 + - value: media + label: + en_US: media + zh_Hans: 视频 + - value: sticker + label: + en_US: sticker + zh_Hans: 表情包 + - value: share_chat + label: + en_US: share_chat + zh_Hans: 分享群名片 + - value: share_user + label: + en_US: share_user + zh_Hans: 分享个人名片 + - value: system + label: + en_US: system + zh_Hans: 系统消息 + label: + en_US: msg_type + zh_Hans: 消息类型 + human_description: + en_US: Message type. Optional values are text, post, image, file, audio, media, sticker, interactive, share_chat, share_user, system. For detailed introduction of different message types, refer to the message content(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json). + zh_Hans: 消息类型。可选值有:text、post、image、file、audio、media、sticker、interactive、share_chat、share_user、system。不同消息类型的详细介绍,参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + llm_description: 消息类型。可选值有:text、post、image、file、audio、media、sticker、interactive、share_chat、share_user、system。不同消息类型的详细介绍,参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + form: form + + - name: content + type: string + required: true + label: + en_US: content + zh_Hans: 消息内容 + human_description: + en_US: Message content, a JSON structure serialized string. The value of this parameter corresponds to msg_type. For example, if msg_type is text, this parameter needs to pass in text type content. To understand the format and usage limitations of different message types, refer to the message content(https://open.larksuite.com/document/server-docs/im-v1/message-content-description/create_json). 
+ zh_Hans: 消息内容,JSON 结构序列化后的字符串。该参数的取值与 msg_type 对应,例如 msg_type 取值为 text,则该参数需要传入文本类型的内容。了解不同类型的消息内容格式、使用限制,可参见发送消息内容(https://open.larksuite.com/document/server-docs/im-v1/message-content-description/create_json)。 + llm_description: 消息内容,JSON 结构序列化后的字符串。该参数的取值与 msg_type 对应,例如 msg_type 取值为 text,则该参数需要传入文本类型的内容。了解不同类型的消息内容格式、使用限制,可参见发送消息内容(https://open.larksuite.com/document/server-docs/im-v1/message-content-description/create_json)。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_message_and_group/tools/send_webhook_message.py b/api/core/tools/provider/builtin/lark_message_and_group/tools/send_webhook_message.py new file mode 100644 index 00000000000000..18a605079fc950 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/tools/send_webhook_message.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class SendWebhookMessageTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + webhook = tool_parameters.get("webhook") + msg_type = tool_parameters.get("msg_type") + content = tool_parameters.get("content") + + res = client.send_webhook_message(webhook, msg_type, content) + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_message_and_group/tools/send_webhook_message.yaml b/api/core/tools/provider/builtin/lark_message_and_group/tools/send_webhook_message.yaml new file mode 100644 index 00000000000000..ea13cae52ba997 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_message_and_group/tools/send_webhook_message.yaml @@ -0,0 +1,68 @@ +identity: + name: send_webhook_message + author: Doug Lea + label: + en_US: Send Webhook Message + zh_Hans: 使用自定义机器人发送 Lark 消息 +description: + human: + en_US: Send webhook message + zh_Hans: 使用自定义机器人发送 Lark 消息 + llm: A tool for sending Lark messages using a custom robot. +parameters: + - name: webhook + type: string + required: true + label: + en_US: webhook + zh_Hans: webhook + human_description: + en_US: | + The address of the webhook, the format of the webhook address corresponding to the bot is as follows: https://open.larksuite.com/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx. For details, please refer to: Lark Custom Bot Usage Guide(https://open.larkoffice.com/document/client-docs/bot-v3/add-custom-bot) + zh_Hans: | + webhook 的地址,机器人对应的 webhook 地址格式如下: https://open.larksuite.com/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx,详情可参考: Lark 自定义机器人使用指南(https://open.larksuite.com/document/client-docs/bot-v3/add-custom-bot) + llm_description: | + webhook 的地址,机器人对应的 webhook 地址格式如下: https://open.larksuite.com/open-apis/bot/v2/hook/xxxxxxxxxxxxxxxxx,详情可参考: Lark 自定义机器人使用指南(https://open.larksuite.com/document/client-docs/bot-v3/add-custom-bot) + form: llm + + - name: msg_type + type: select + required: true + options: + - value: text + label: + en_US: text + zh_Hans: 文本 + - value: interactive + label: + en_US: interactive + zh_Hans: 卡片 + - value: image + label: + en_US: image + zh_Hans: 图片 + - value: share_chat + label: + en_US: share_chat + zh_Hans: 分享群名片 + label: + en_US: msg_type + zh_Hans: 消息类型 + human_description: + en_US: Message type. Optional values are text, image, interactive, share_chat. 
For detailed introduction of different message types, refer to the message content(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json). + zh_Hans: 消息类型。可选值有:text、image、interactive、share_chat。不同消息类型的详细介绍,参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + llm_description: 消息类型。可选值有:text、image、interactive、share_chat。不同消息类型的详细介绍,参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + form: form + + + - name: content + type: string + required: true + label: + en_US: content + zh_Hans: 消息内容 + human_description: + en_US: Message content, a JSON structure serialized string. The value of this parameter corresponds to msg_type. For example, if msg_type is text, this parameter needs to pass in text type content. To understand the format and usage limitations of different message types, refer to the message content(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json). + zh_Hans: 消息内容,JSON 结构序列化后的字符串。该参数的取值与 msg_type 对应,例如 msg_type 取值为 text,则该参数需要传入文本类型的内容。了解不同类型的消息内容格式、使用限制,可参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + llm_description: 消息内容,JSON 结构序列化后的字符串。该参数的取值与 msg_type 对应,例如 msg_type 取值为 text,则该参数需要传入文本类型的内容。了解不同类型的消息内容格式、使用限制,可参见发送消息内容(https://open.larkoffice.com/document/server-docs/im-v1/message-content-description/create_json)。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/_assets/icon.png b/api/core/tools/provider/builtin/lark_spreadsheet/_assets/icon.png new file mode 100644 index 00000000000000..258b361261d4e3 Binary files /dev/null and b/api/core/tools/provider/builtin/lark_spreadsheet/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/lark_spreadsheet.py b/api/core/tools/provider/builtin/lark_spreadsheet/lark_spreadsheet.py new file mode 100644 index 00000000000000..c791363f21fbe1 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/lark_spreadsheet.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.lark_api_utils import lark_auth + + +class LarkSpreadsheetProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + lark_auth(credentials) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/lark_spreadsheet.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/lark_spreadsheet.yaml new file mode 100644 index 00000000000000..030b5c9063227a --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/lark_spreadsheet.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: lark_spreadsheet + label: + en_US: Lark Spreadsheet + zh_Hans: Lark 电子表格 + description: + en_US: | + Lark Spreadsheet, requires the following permissions: sheets:spreadsheet.
+ zh_Hans: | + Lark 电子表格,需要开通以下权限: sheets:spreadsheet。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your Lark app id + zh_Hans: 请输入你的 Lark app id + help: + en_US: Get your app_id and app_secret from Lark + zh_Hans: 从 Lark 获取您的 app_id 和 app_secret + url: https://open.larksuite.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的 Lark app secret diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_cols.py b/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_cols.py new file mode 100644 index 00000000000000..deeb5a1ecf6f7d --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_cols.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class AddColsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + length = tool_parameters.get("length") + values = tool_parameters.get("values") + + res = client.add_cols(spreadsheet_token, sheet_id, sheet_name, length, values) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_cols.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_cols.yaml new file mode 100644 index 00000000000000..b73335f405c20c --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_cols.yaml @@ -0,0 +1,72 @@ +identity: + name: add_cols + author: Doug Lea + label: + en_US: Add Cols + zh_Hans: 新增多列至工作表最后 +description: + human: + en_US: Add Cols + zh_Hans: 新增多列至工作表最后 + llm: A tool for adding multiple columns to the end of a spreadsheet. (新增多列至工作表最后) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: length + type: number + required: true + label: + en_US: length + zh_Hans: 要增加的列数 + human_description: + en_US: Number of columns to add, range (0-5000]. 
+ zh_Hans: 要增加的列数,范围(0-5000]。 + llm_description: 要增加的列数,范围(0-5000]。 + form: form + + - name: values + type: string + required: false + label: + en_US: values + zh_Hans: 新增列的单元格内容 + human_description: + en_US: | + Content of the new columns, array of objects in string format, each array represents a row of table data, format like: [ [ "ID","Name","Age" ],[ 1,"Zhang San",10 ],[ 2,"Li Si",11 ] ]. + zh_Hans: 新增列的单元格内容,数组对象字符串,每个数组一行表格数据,格式:[["编号","姓名","年龄"],[1,"张三",10],[2,"李四",11]]。 + llm_description: 新增列的单元格内容,数组对象字符串,每个数组一行表格数据,格式:[["编号","姓名","年龄"],[1,"张三",10],[2,"李四",11]]。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_rows.py b/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_rows.py new file mode 100644 index 00000000000000..f434b1c60341f3 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_rows.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class AddRowsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + length = tool_parameters.get("length") + values = tool_parameters.get("values") + + res = client.add_rows(spreadsheet_token, sheet_id, sheet_name, length, values) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_rows.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_rows.yaml new file mode 100644 index 00000000000000..6bce305b9825ec --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/add_rows.yaml @@ -0,0 +1,72 @@ +identity: + name: add_rows + author: Doug Lea + label: + en_US: Add Rows + zh_Hans: 新增多行至工作表最后 +description: + human: + en_US: Add Rows + zh_Hans: 新增多行至工作表最后 + llm: A tool for adding multiple rows to the end of a spreadsheet. (新增多行至工作表最后) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: length + type: number + required: true + label: + en_US: length + zh_Hans: 要增加行数 + human_description: + en_US: Number of rows to add, range (0-5000]. 
+ zh_Hans: 要增加的行数,范围(0-5000]。 + llm_description: 要增加的行数,范围(0-5000]。 + form: form + + - name: values + type: string + required: false + label: + en_US: values + zh_Hans: 新增行的表格内容 + human_description: + en_US: | + Content of the new rows, array of objects in string format, each array represents a row of table data, format like: [ [ "ID","Name","Age" ],[ 1,"Zhang San",10 ],[ 2,"Li Si",11 ] ]. + zh_Hans: 新增行的表格内容,数组对象字符串,每个数组一行表格数据,格式如:[["编号","姓名","年龄"],[1,"张三",10],[2,"李四",11]]。 + llm_description: 新增行的表格内容,数组对象字符串,每个数组一行表格数据,格式如:[["编号","姓名","年龄"],[1,"张三",10],[2,"李四",11]]。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/create_spreadsheet.py b/api/core/tools/provider/builtin/lark_spreadsheet/tools/create_spreadsheet.py new file mode 100644 index 00000000000000..74b20ac2c838f8 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/create_spreadsheet.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class CreateSpreadsheetTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + title = tool_parameters.get("title") + folder_token = tool_parameters.get("folder_token") + + res = client.create_spreadsheet(title, folder_token) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/create_spreadsheet.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/tools/create_spreadsheet.yaml new file mode 100644 index 00000000000000..931310e63172d4 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/create_spreadsheet.yaml @@ -0,0 +1,35 @@ +identity: + name: create_spreadsheet + author: Doug Lea + label: + en_US: Create Spreadsheet + zh_Hans: 创建电子表格 +description: + human: + en_US: Create Spreadsheet + zh_Hans: 创建电子表格 + llm: A tool for creating spreadsheets.
(创建电子表格) +parameters: + - name: title + type: string + required: false + label: + en_US: Spreadsheet Title + zh_Hans: 电子表格标题 + human_description: + en_US: The title of the spreadsheet + zh_Hans: 电子表格的标题 + llm_description: 电子表格的标题 + form: llm + + - name: folder_token + type: string + required: false + label: + en_US: Folder Token + zh_Hans: 文件夹 token + human_description: + en_US: The token of the folder, supports folder URL input, e.g., https://bytedance.larkoffice.com/drive/folder/CxHEf4DCSlNkL2dUTCJcPRgentg + zh_Hans: 文件夹 token,支持文件夹 URL 输入,如:https://bytedance.larkoffice.com/drive/folder/CxHEf4DCSlNkL2dUTCJcPRgentg + llm_description: 文件夹 token,支持文件夹 URL 输入,如:https://bytedance.larkoffice.com/drive/folder/CxHEf4DCSlNkL2dUTCJcPRgentg + form: llm diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/get_spreadsheet.py b/api/core/tools/provider/builtin/lark_spreadsheet/tools/get_spreadsheet.py new file mode 100644 index 00000000000000..0fe35b6dc645b8 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/get_spreadsheet.py @@ -0,0 +1,19 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class GetSpreadsheetTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.get_spreadsheet(spreadsheet_token, user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/get_spreadsheet.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/tools/get_spreadsheet.yaml new file mode 100644 index 00000000000000..c519938617ba8c --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/get_spreadsheet.yaml @@ -0,0 +1,49 @@ +identity: + name: get_spreadsheet + author: Doug Lea + label: + en_US: Get Spreadsheet + zh_Hans: 获取电子表格信息 +description: + human: + en_US: Get Spreadsheet + zh_Hans: 获取电子表格信息 + llm: A tool for getting information from spreadsheets. (获取电子表格信息) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: Spreadsheet Token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 URL。 + llm_description: 电子表格 token,支持输入电子表格 URL。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/list_spreadsheet_sheets.py b/api/core/tools/provider/builtin/lark_spreadsheet/tools/list_spreadsheet_sheets.py new file mode 100644 index 00000000000000..e711c23780e5e3 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/list_spreadsheet_sheets.py @@ -0,0 +1,18 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class ListSpreadsheetSheetsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + + res = client.list_spreadsheet_sheets(spreadsheet_token) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/list_spreadsheet_sheets.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/tools/list_spreadsheet_sheets.yaml new file mode 100644 index 00000000000000..c6a7ef45d46589 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/list_spreadsheet_sheets.yaml @@ -0,0 +1,23 @@ +identity: + name: list_spreadsheet_sheets + author: Doug Lea + label: + en_US: List Spreadsheet Sheets + zh_Hans: 列出电子表格所有工作表 +description: + human: + en_US: List Spreadsheet Sheets + zh_Hans: 列出电子表格所有工作表 + llm: A tool for listing all sheets in a spreadsheet. (列出电子表格所有工作表) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: Spreadsheet Token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. 
+ zh_Hans: 电子表格 token,支持输入电子表格 URL。 + llm_description: 电子表格 token,支持输入电子表格 URL。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_cols.py b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_cols.py new file mode 100644 index 00000000000000..1df289c1d71b01 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_cols.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class ReadColsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + start_col = tool_parameters.get("start_col") + num_cols = tool_parameters.get("num_cols") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_cols(spreadsheet_token, sheet_id, sheet_name, start_col, num_cols, user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_cols.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_cols.yaml new file mode 100644 index 00000000000000..34da74592d5898 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_cols.yaml @@ -0,0 +1,97 @@ +identity: + name: read_cols + author: Doug Lea + label: + en_US: Read Cols + zh_Hans: 读取工作表列数据 +description: + human: + en_US: Read Cols + zh_Hans: 读取工作表列数据 + llm: A tool for reading column data from a spreadsheet. (读取工作表列数据) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: start_col + type: number + required: false + label: + en_US: start_col + zh_Hans: 起始列号 + human_description: + en_US: Starting column number, starting from 1. 
+ zh_Hans: 起始列号,从 1 开始。 + llm_description: 起始列号,从 1 开始。 + form: form + + - name: num_cols + type: number + required: true + label: + en_US: num_cols + zh_Hans: 读取列数 + human_description: + en_US: Number of columns to read. + zh_Hans: 读取列数 + llm_description: 读取列数 + form: form diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_rows.py b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_rows.py new file mode 100644 index 00000000000000..1cab38a4545269 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_rows.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class ReadRowsTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + start_row = tool_parameters.get("start_row") + num_rows = tool_parameters.get("num_rows") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_rows(spreadsheet_token, sheet_id, sheet_name, start_row, num_rows, user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_rows.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_rows.yaml new file mode 100644 index 00000000000000..5dfa8d58354125 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_rows.yaml @@ -0,0 +1,97 @@ +identity: + name: read_rows + author: Doug Lea + label: + en_US: Read Rows + zh_Hans: 读取工作表行数据 +description: + human: + en_US: Read Rows + zh_Hans: 读取工作表行数据 + llm: A tool for reading row data from a spreadsheet. (读取工作表行数据) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. 
+ zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: start_row + type: number + required: false + label: + en_US: start_row + zh_Hans: 起始行号 + human_description: + en_US: Starting row number, starting from 1. + zh_Hans: 起始行号,从 1 开始。 + llm_description: 起始行号,从 1 开始。 + form: form + + - name: num_rows + type: number + required: true + label: + en_US: num_rows + zh_Hans: 读取行数 + human_description: + en_US: Number of rows to read. + zh_Hans: 读取行数 + llm_description: 读取行数 + form: form diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_table.py b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_table.py new file mode 100644 index 00000000000000..0f05249004ee20 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_table.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class ReadTableTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + spreadsheet_token = tool_parameters.get("spreadsheet_token") + sheet_id = tool_parameters.get("sheet_id") + sheet_name = tool_parameters.get("sheet_name") + num_range = tool_parameters.get("range")  # the parameter is exposed as "range" in read_table.yaml + query = tool_parameters.get("query") + user_id_type = tool_parameters.get("user_id_type", "open_id") + + res = client.read_table(spreadsheet_token, sheet_id, sheet_name, num_range, query, user_id_type) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_table.yaml b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_table.yaml new file mode 100644 index 00000000000000..10534436d66e7a --- /dev/null +++ b/api/core/tools/provider/builtin/lark_spreadsheet/tools/read_table.yaml @@ -0,0 +1,122 @@ +identity: + name: read_table + author: Doug Lea + label: + en_US: Read Table + zh_Hans: 自定义读取电子表格行列数据 +description: + human: + en_US: Read Table + zh_Hans: 自定义读取电子表格行列数据 + llm: A tool for custom reading of row and column data from a spreadsheet. (自定义读取电子表格行列数据) +parameters: + - name: spreadsheet_token + type: string + required: true + label: + en_US: spreadsheet_token + zh_Hans: 电子表格 token + human_description: + en_US: Spreadsheet token, supports input of spreadsheet URL. + zh_Hans: 电子表格 token,支持输入电子表格 url。 + llm_description: 电子表格 token,支持输入电子表格 url。 + form: llm + + - name: sheet_id + type: string + required: false + label: + en_US: sheet_id + zh_Hans: 工作表 ID + human_description: + en_US: Sheet ID, either sheet_id or sheet_name must be filled. + zh_Hans: 工作表 ID,与 sheet_name 二者其一必填。 + llm_description: 工作表 ID,与 sheet_name 二者其一必填。 + form: llm + + - name: sheet_name + type: string + required: false + label: + en_US: sheet_name + zh_Hans: 工作表名称 + human_description: + en_US: Sheet name, either sheet_id or sheet_name must be filled.
+ zh_Hans: 工作表名称,与 sheet_id 二者其一必填。 + llm_description: 工作表名称,与 sheet_id 二者其一必填。 + form: llm + + - name: user_id_type + type: select + required: false + options: + - value: open_id + label: + en_US: open_id + zh_Hans: open_id + - value: union_id + label: + en_US: union_id + zh_Hans: union_id + - value: user_id + label: + en_US: user_id + zh_Hans: user_id + default: "open_id" + label: + en_US: user_id_type + zh_Hans: 用户 ID 类型 + human_description: + en_US: User ID type, optional values are open_id, union_id, user_id, with a default value of open_id. + zh_Hans: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + llm_description: 用户 ID 类型,可选值有 open_id、union_id、user_id,默认值为 open_id。 + form: form + + - name: start_row + type: number + required: false + label: + en_US: start_row + zh_Hans: 起始行号 + human_description: + en_US: Starting row number, starting from 1. + zh_Hans: 起始行号,从 1 开始。 + llm_description: 起始行号,从 1 开始。 + form: form + + - name: num_rows + type: number + required: false + label: + en_US: num_rows + zh_Hans: 读取行数 + human_description: + en_US: Number of rows to read. + zh_Hans: 读取行数 + llm_description: 读取行数 + form: form + + - name: range + type: string + required: false + label: + en_US: range + zh_Hans: 取数范围 + human_description: + en_US: | + Data range, format like: A1:B2, can be empty when query=all. + zh_Hans: 取数范围,格式如:A1:B2,query=all 时可为空。 + llm_description: 取数范围,格式如:A1:B2,query=all 时可为空。 + form: llm + + - name: query + type: string + required: false + label: + en_US: query + zh_Hans: 查询 + human_description: + en_US: Pass "all" to query all data in the table, but no more than 100 columns. + zh_Hans: 传 all,表示查询表格所有数据,但最多查询 100 列数据。 + llm_description: 传 all,表示查询表格所有数据,但最多查询 100 列数据。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_task/_assets/icon.png b/api/core/tools/provider/builtin/lark_task/_assets/icon.png new file mode 100644 index 00000000000000..26ea6a2eefa5be Binary files /dev/null and b/api/core/tools/provider/builtin/lark_task/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/lark_task/lark_task.py b/api/core/tools/provider/builtin/lark_task/lark_task.py new file mode 100644 index 00000000000000..02cf009f017e61 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/lark_task.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.lark_api_utils import lark_auth + + +class LarkTaskProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + lark_auth(credentials) diff --git a/api/core/tools/provider/builtin/lark_task/lark_task.yaml b/api/core/tools/provider/builtin/lark_task/lark_task.yaml new file mode 100644 index 00000000000000..ada068b0aab3ce --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/lark_task.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: lark_task + label: + en_US: Lark Task + zh_Hans: Lark 任务 + description: + en_US: | + Lark Task, requires the following permissions: task:task:write、contact:user.id:readonly. 
+ zh_Hans: | + Lark 任务,需要开通以下权限: task:task:write、contact:user.id:readonly。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your Lark app id + zh_Hans: 请输入你的 Lark app id + help: + en_US: Get your app_id and app_secret from Lark + zh_Hans: 从 Lark 获取您的 app_id 和 app_secret + url: https://open.larksuite.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的 Lark app secret diff --git a/api/core/tools/provider/builtin/lark_task/tools/add_members.py b/api/core/tools/provider/builtin/lark_task/tools/add_members.py new file mode 100644 index 00000000000000..9b8e4d68f394a8 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/tools/add_members.py @@ -0,0 +1,20 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class AddMembersTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + task_guid = tool_parameters.get("task_guid") + member_phone_or_email = tool_parameters.get("member_phone_or_email") + member_role = tool_parameters.get("member_role", "follower") + + res = client.add_members(task_guid, member_phone_or_email, member_role) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_task/tools/add_members.yaml b/api/core/tools/provider/builtin/lark_task/tools/add_members.yaml new file mode 100644 index 00000000000000..0b12172e0b85e7 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/tools/add_members.yaml @@ -0,0 +1,58 @@ +identity: + name: add_members + author: Doug Lea + label: + en_US: Add Lark Members + zh_Hans: 添加 Lark 任务成员 +description: + human: + en_US: Add Lark Members + zh_Hans: 添加 Lark 任务成员 + llm: A tool for adding members to a Lark task.(添加 Lark 任务成员) +parameters: + - name: task_guid + type: string + required: true + label: + en_US: Task GUID + zh_Hans: 任务 GUID + human_description: + en_US: | + The GUID of the task to be added, supports passing either the Task ID or the Task link URL. Example of Task ID: 8b5425ec-9f2a-43bd-a3ab-01912f50282b; Example of Task link URL: https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + zh_Hans: 要添加的任务的 GUID,支持传任务 ID 和任务链接 URL。任务 ID 示例:8b5425ec-9f2a-43bd-a3ab-01912f50282b;任务链接 URL 示例:https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + llm_description: 要添加的任务的 GUID,支持传任务 ID 和任务链接 URL。任务 ID 示例:8b5425ec-9f2a-43bd-a3ab-01912f50282b;任务链接 URL 示例:https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + form: llm + + - name: member_phone_or_email + type: string + required: true + label: + en_US: Task Member Phone Or Email + zh_Hans: 任务成员的电话或邮箱 + human_description: + en_US: A list of member emails or phone numbers, separated by commas. 
+ zh_Hans: 任务成员邮箱或者手机号列表,使用逗号分隔。 + llm_description: 任务成员邮箱或者手机号列表,使用逗号分隔。 + form: llm + + - name: member_role + type: select + required: true + options: + - value: assignee + label: + en_US: assignee + zh_Hans: 负责人 + - value: follower + label: + en_US: follower + zh_Hans: 关注人 + default: "follower" + label: + en_US: member_role + zh_Hans: 成员的角色 + human_description: + en_US: Member role, optional values are "assignee" (responsible person) and "follower" (observer), with a default value of "follower". + zh_Hans: 成员的角色,可选值有 "assignee"(负责人)和 "follower"(关注人),默认值为 "follower"。 + llm_description: 成员的角色,可选值有 "assignee"(负责人)和 "follower"(关注人),默认值为 "follower"。 + form: form diff --git a/api/core/tools/provider/builtin/lark_task/tools/create_task.py b/api/core/tools/provider/builtin/lark_task/tools/create_task.py new file mode 100644 index 00000000000000..ff37593fbe3a12 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/tools/create_task.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class CreateTaskTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + summary = tool_parameters.get("summary") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + completed_time = tool_parameters.get("completed_time") + description = tool_parameters.get("description") + + res = client.create_task(summary, start_time, end_time, completed_time, description) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_task/tools/create_task.yaml b/api/core/tools/provider/builtin/lark_task/tools/create_task.yaml new file mode 100644 index 00000000000000..4303763a1dd406 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/tools/create_task.yaml @@ -0,0 +1,74 @@ +identity: + name: create_task + author: Doug Lea + label: + en_US: Create Lark Task + zh_Hans: 创建 Lark 任务 +description: + human: + en_US: Create Lark Task + zh_Hans: 创建 Lark 任务 + llm: A tool for creating tasks in Lark.(创建 Lark 任务) +parameters: + - name: summary + type: string + required: true + label: + en_US: Task Title + zh_Hans: 任务标题 + human_description: + en_US: The title of the task. + zh_Hans: 任务标题 + llm_description: 任务标题 + form: llm + + - name: description + type: string + required: false + label: + en_US: Task Description + zh_Hans: 任务备注 + human_description: + en_US: The description or notes for the task.
+ zh_Hans: 任务备注 + llm_description: 任务备注 + form: llm + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 任务开始时间 + human_description: + en_US: | + The start time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务开始时间,格式为:2006-01-02 15:04:05 + llm_description: 任务开始时间,格式为:2006-01-02 15:04:05 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 任务结束时间 + human_description: + en_US: | + The end time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务结束时间,格式为:2006-01-02 15:04:05 + llm_description: 任务结束时间,格式为:2006-01-02 15:04:05 + form: llm + + - name: completed_time + type: string + required: false + label: + en_US: Completed Time + zh_Hans: 任务完成时间 + human_description: + en_US: | + The completion time of the task, in the format: 2006-01-02 15:04:05. Leave empty to create an incomplete task; fill in a specific time to create a completed task. + zh_Hans: 任务完成时间,格式为:2006-01-02 15:04:05,不填写表示创建一个未完成任务;填写一个具体的时间表示创建一个已完成任务。 + llm_description: 任务完成时间,格式为:2006-01-02 15:04:05,不填写表示创建一个未完成任务;填写一个具体的时间表示创建一个已完成任务。 + form: llm diff --git a/api/core/tools/provider/builtin/lark_task/tools/delete_task.py b/api/core/tools/provider/builtin/lark_task/tools/delete_task.py new file mode 100644 index 00000000000000..eca381be2c185e --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/tools/delete_task.py @@ -0,0 +1,18 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class DeleteTaskTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + task_guid = tool_parameters.get("task_guid") + + res = client.delete_task(task_guid) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_task/tools/delete_task.yaml b/api/core/tools/provider/builtin/lark_task/tools/delete_task.yaml new file mode 100644 index 00000000000000..bc0154d9dc5c77 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/tools/delete_task.yaml @@ -0,0 +1,24 @@ +identity: + name: delete_task + author: Doug Lea + label: + en_US: Delete Lark Task + zh_Hans: 删除 Lark 任务 +description: + human: + en_US: Delete Lark Task + zh_Hans: 删除 Lark 任务 + llm: A tool for deleting tasks in Lark.(删除 Lark 任务) +parameters: + - name: task_guid + type: string + required: true + label: + en_US: Task GUID + zh_Hans: 任务 GUID + human_description: + en_US: | + The GUID of the task to be deleted, supports passing either the Task ID or the Task link URL.
Example of Task ID: 8b5425ec-9f2a-43bd-a3ab-01912f50282b; Example of Task link URL: https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + zh_Hans: 要删除的任务的 GUID,支持传任务 ID 和任务链接 URL。任务 ID 示例:8b5425ec-9f2a-43bd-a3ab-01912f50282b;任务链接 URL 示例:https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + llm_description: 要删除的任务的 GUID,支持传任务 ID 和任务链接 URL。任务 ID 示例:8b5425ec-9f2a-43bd-a3ab-01912f50282b;任务链接 URL 示例:https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + form: llm diff --git a/api/core/tools/provider/builtin/lark_task/tools/update_task.py b/api/core/tools/provider/builtin/lark_task/tools/update_task.py new file mode 100644 index 00000000000000..0d3469c91a01bc --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/tools/update_task.py @@ -0,0 +1,23 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class UpdateTaskTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + task_guid = tool_parameters.get("task_guid") + summary = tool_parameters.get("summary") + start_time = tool_parameters.get("start_time") + end_time = tool_parameters.get("end_time") + completed_time = tool_parameters.get("completed_time") + description = tool_parameters.get("description") + + res = client.update_task(task_guid, summary, start_time, end_time, completed_time, description) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_task/tools/update_task.yaml b/api/core/tools/provider/builtin/lark_task/tools/update_task.yaml new file mode 100644 index 00000000000000..a98f037f211c9a --- /dev/null +++ b/api/core/tools/provider/builtin/lark_task/tools/update_task.yaml @@ -0,0 +1,89 @@ +identity: + name: update_task + author: Doug Lea + label: + en_US: Update Lark Task + zh_Hans: 更新 Lark 任务 +description: + human: + en_US: Update Lark Task + zh_Hans: 更新 Lark 任务 + llm: A tool for updating tasks in Lark.(更新 Lark 任务) +parameters: + - name: task_guid + type: string + required: true + label: + en_US: Task GUID + zh_Hans: 任务 GUID + human_description: + en_US: | + The task ID, supports inputting either the Task ID or the Task link URL. Example of Task ID: 42cad8a0-f8c8-4344-9be2-d1d7e8e91b64; Example of Task link URL: https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + zh_Hans: | + 任务ID,支持传入任务 ID 和任务链接 URL。任务 ID 示例: 42cad8a0-f8c8-4344-9be2-d1d7e8e91b64;任务链接 URL 示例: https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + llm_description: | + 任务ID,支持传入任务 ID 和任务链接 URL。任务 ID 示例: 42cad8a0-f8c8-4344-9be2-d1d7e8e91b64;任务链接 URL 示例: https://applink.larksuite.com/client/todo/detail?guid=1b066afa-96de-406c-90a3-dfd30159a571&suite_entity_num=t100805 + form: llm + + - name: summary + type: string + required: true + label: + en_US: Task Title + zh_Hans: 任务标题 + human_description: + en_US: The title of the task. 
+ zh_Hans: 任务标题 + llm_description: 任务标题 + form: llm + + - name: description + type: string + required: false + label: + en_US: Task Description + zh_Hans: 任务备注 + human_description: + en_US: The description or notes for the task. + zh_Hans: 任务备注 + llm_description: 任务备注 + form: llm + + - name: start_time + type: string + required: false + label: + en_US: Start Time + zh_Hans: 任务开始时间 + human_description: + en_US: | + The start time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务开始时间,格式为:2006-01-02 15:04:05 + llm_description: 任务开始时间,格式为:2006-01-02 15:04:05 + form: llm + + - name: end_time + type: string + required: false + label: + en_US: End Time + zh_Hans: 任务结束时间 + human_description: + en_US: | + The end time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务结束时间,格式为:2006-01-02 15:04:05 + llm_description: 任务结束时间,格式为:2006-01-02 15:04:05 + form: llm + + - name: completed_time + type: string + required: false + label: + en_US: Completed Time + zh_Hans: 任务完成时间 + human_description: + en_US: | + The completion time of the task, in the format: 2006-01-02 15:04:05 + zh_Hans: 任务完成时间,格式为:2006-01-02 15:04:05 + llm_description: 任务完成时间,格式为:2006-01-02 15:04:05 + form: llm diff --git a/api/core/tools/provider/builtin/lark_wiki/_assets/icon.png b/api/core/tools/provider/builtin/lark_wiki/_assets/icon.png new file mode 100644 index 00000000000000..47f6b8c30ea0cf Binary files /dev/null and b/api/core/tools/provider/builtin/lark_wiki/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/lark_wiki/lark_wiki.py b/api/core/tools/provider/builtin/lark_wiki/lark_wiki.py new file mode 100644 index 00000000000000..e6941206ee7618 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_wiki/lark_wiki.py @@ -0,0 +1,7 @@ +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.utils.lark_api_utils import lark_auth + + +class LarkWikiProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + lark_auth(credentials) diff --git a/api/core/tools/provider/builtin/lark_wiki/lark_wiki.yaml b/api/core/tools/provider/builtin/lark_wiki/lark_wiki.yaml new file mode 100644 index 00000000000000..86bef000868d68 --- /dev/null +++ b/api/core/tools/provider/builtin/lark_wiki/lark_wiki.yaml @@ -0,0 +1,36 @@ +identity: + author: Doug Lea + name: lark_wiki + label: + en_US: Lark Wiki + zh_Hans: Lark 知识库 + description: + en_US: | + Lark Wiki, requires the following permissions: wiki:wiki:readonly. 
+ zh_Hans: | + Lark 知识库,需要开通以下权限: wiki:wiki:readonly。 + icon: icon.png + tags: + - social + - productivity +credentials_for_provider: + app_id: + type: text-input + required: true + label: + en_US: APP ID + placeholder: + en_US: Please input your Lark app id + zh_Hans: 请输入你的 Lark app id + help: + en_US: Get your app_id and app_secret from Lark + zh_Hans: 从 Lark 获取您的 app_id 和 app_secret + url: https://open.larksuite.com/app + app_secret: + type: secret-input + required: true + label: + en_US: APP Secret + placeholder: + en_US: Please input your app secret + zh_Hans: 请输入你的 Lark app secret diff --git a/api/core/tools/provider/builtin/lark_wiki/tools/get_wiki_nodes.py b/api/core/tools/provider/builtin/lark_wiki/tools/get_wiki_nodes.py new file mode 100644 index 00000000000000..a05f300755962f --- /dev/null +++ b/api/core/tools/provider/builtin/lark_wiki/tools/get_wiki_nodes.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.lark_api_utils import LarkRequest + + +class GetWikiNodesTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage: + app_id = self.runtime.credentials.get("app_id") + app_secret = self.runtime.credentials.get("app_secret") + client = LarkRequest(app_id, app_secret) + + space_id = tool_parameters.get("space_id") + parent_node_token = tool_parameters.get("parent_node_token") + page_token = tool_parameters.get("page_token") + page_size = tool_parameters.get("page_size") + + res = client.get_wiki_nodes(space_id, parent_node_token, page_token, page_size) + + return self.create_json_message(res) diff --git a/api/core/tools/provider/builtin/lark_wiki/tools/get_wiki_nodes.yaml b/api/core/tools/provider/builtin/lark_wiki/tools/get_wiki_nodes.yaml new file mode 100644 index 00000000000000..a8c242a2e909df --- /dev/null +++ b/api/core/tools/provider/builtin/lark_wiki/tools/get_wiki_nodes.yaml @@ -0,0 +1,63 @@ +identity: + name: get_wiki_nodes + author: Doug Lea + label: + en_US: Get Wiki Nodes + zh_Hans: 获取知识空间子节点列表 +description: + human: + en_US: | + Get the list of child nodes in Wiki, make sure the app/bot is a member of the wiki space. See How to add an app as a wiki base administrator (member). https://open.larksuite.com/document/server-docs/docs/wiki-v2/wiki-qa + zh_Hans: | + 获取知识库全部子节点列表,请确保应用/机器人为知识空间成员。参阅如何将应用添加为知识库管理员(成员)。https://open.larksuite.com/document/server-docs/docs/wiki-v2/wiki-qa + llm: A tool for getting all sub-nodes of a knowledge base.(获取知识空间子节点列表) +parameters: + - name: space_id + type: string + required: true + label: + en_US: Space Id + zh_Hans: 知识空间 ID + human_description: + en_US: | + The ID of the knowledge space. Supports space link URL, for example: https://lark-japan.jp.larksuite.com/wiki/settings/7431084851517718561 + zh_Hans: 知识空间 ID,支持空间链接 URL,例如:https://lark-japan.jp.larksuite.com/wiki/settings/7431084851517718561 + llm_description: 知识空间 ID,支持空间链接 URL,例如:https://lark-japan.jp.larksuite.com/wiki/settings/7431084851517718561 + form: llm + + - name: page_size + type: number + required: false + default: 10 + label: + en_US: Page Size + zh_Hans: 分页大小 + human_description: + en_US: The size of each page, with a maximum value of 50. + zh_Hans: 分页大小,最大值 50。 + llm_description: 分页大小,最大值 50。 + form: form + + - name: page_token + type: string + required: false + label: + en_US: Page Token + zh_Hans: 分页标记 + human_description: + en_US: The pagination token. 
Leave empty for the first request to start from the beginning; if the paginated query result has more items, a new page_token will be returned, which can be used to get the next set of results. + zh_Hans: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + llm_description: 分页标记,第一次请求不填,表示从头开始遍历;分页查询结果还有更多项时会同时返回新的 page_token,下次遍历可采用该 page_token 获取查询结果。 + form: llm + + - name: parent_node_token + type: string + required: false + label: + en_US: Parent Node Token + zh_Hans: 父节点 token + human_description: + en_US: The token of the parent node. + zh_Hans: 父节点 token + llm_description: 父节点 token + form: llm diff --git a/api/core/tools/provider/builtin/maths/maths.py b/api/core/tools/provider/builtin/maths/maths.py index 7226a5c1686feb..d4b449ec87a18a 100644 --- a/api/core/tools/provider/builtin/maths/maths.py +++ b/api/core/tools/provider/builtin/maths/maths.py @@ -9,9 +9,9 @@ class MathsProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: EvaluateExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'expression': '1+(2+3)*4', + "expression": "1+(2+3)*4", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index bf73ed69181eaa..0c5b5e41cbe1e1 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -8,22 +8,23 @@ class EvaluateExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - expression = tool_parameters.get('expression', '').strip() + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid expression") try: result = ne.evaluate(expression) result_str = str(result) except Exception as e: - logging.exception(f'Error evaluating expression: {expression}') - return self.create_text_message(f'Invalid expression: {expression}, error: {str(e)}') - return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') \ No newline at end of file + logging.exception(f"Error evaluating expression: {expression}") + return self.create_text_message(f"Invalid expression: {expression}, error: {str(e)}") + return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') diff --git a/api/core/tools/provider/builtin/nominatim/nominatim.py b/api/core/tools/provider/builtin/nominatim/nominatim.py index b6f29b5feb4e44..5a24bed7507eb6 100644 --- a/api/core/tools/provider/builtin/nominatim/nominatim.py +++ b/api/core/tools/provider/builtin/nominatim/nominatim.py @@ -8,16 +8,20 @@ class NominatimProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - result = NominatimSearchTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - 'query': 'London', - 'limit': 1, - }, + result = ( + NominatimSearchTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + 
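# credential check: a cheap one-result search ("London", limit 1) is enough to exercise auth and request handling +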
tool_parameters={ + "query": "London", + "limit": 1, + }, + ) ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py index e21ce14f542161..ffa8ad0fcc02e0 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_lookup.py @@ -8,40 +8,33 @@ class NominatimLookupTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - osm_ids = tool_parameters.get('osm_ids', '') - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + osm_ids = tool_parameters.get("osm_ids", "") + if not osm_ids: - return self.create_text_message('Please provide OSM IDs') + return self.create_text_message("Please provide OSM IDs") + + params = {"osm_ids": osm_ids, "format": "json", "addressdetails": 1} - params = { - 'osm_ids': osm_ids, - 'format': 'json', - 'addressdetails': 1 - } - - return self._make_request(user_id, 'lookup', params) + return self._make_request(user_id, "lookup", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py index 438d5219e97887..f46691e1a3ebb4 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_reverse.py @@ -8,42 +8,34 @@ class NominatimReverseTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - lat = tool_parameters.get('lat') - lon = tool_parameters.get('lon') - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + lat = tool_parameters.get("lat") + lon = tool_parameters.get("lon") + if lat is None or lon is None: - return self.create_text_message('Please provide both latitude and longitude') + return self.create_text_message("Please provide both latitude and longitude") + + params = {"lat": lat, "lon": lon, "format": 
"json", "addressdetails": 1} - params = { - 'lat': lat, - 'lon': lon, - 'format': 'json', - 'addressdetails': 1 - } - - return self._make_request(user_id, 'reverse', params) + return self._make_request(user_id, "reverse", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py index 983cbc0e346577..34851d86dcaa5f 100644 --- a/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py +++ b/api/core/tools/provider/builtin/nominatim/tools/nominatim_search.py @@ -8,42 +8,34 @@ class NominatimSearchTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - query = tool_parameters.get('query', '') - limit = tool_parameters.get('limit', 10) - + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + query = tool_parameters.get("query", "") + limit = tool_parameters.get("limit", 10) + if not query: - return self.create_text_message('Please input a search query') + return self.create_text_message("Please input a search query") + + params = {"q": query, "format": "json", "limit": limit, "addressdetails": 1} - params = { - 'q': query, - 'format': 'json', - 'limit': limit, - 'addressdetails': 1 - } - - return self._make_request(user_id, 'search', params) + return self._make_request(user_id, "search", params) def _make_request(self, user_id: str, endpoint: str, params: dict) -> ToolInvokeMessage: - base_url = self.runtime.credentials.get('base_url', 'https://nominatim.openstreetmap.org') - + base_url = self.runtime.credentials.get("base_url", "https://nominatim.openstreetmap.org") + try: - headers = { - "User-Agent": "DifyNominatimTool/1.0" - } + headers = {"User-Agent": "DifyNominatimTool/1.0"} s = requests.session() - response = s.request( - method='GET', - headers=headers, - url=f"{base_url}/{endpoint}", - params=params - ) + response = s.request(method="GET", headers=headers, url=f"{base_url}/{endpoint}", params=params) response_data = response.json() - + if response.status_code == 200: s.close() - return self.create_text_message(self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False))) + return 
self.create_text_message( + self.summary(user_id=user_id, content=json.dumps(response_data, ensure_ascii=False)) + ) else: return self.create_text_message(f"Error: {response.status_code} - {response.text}") except Exception as e: - return self.create_text_message(f"An error occurred: {str(e)}") \ No newline at end of file + return self.create_text_message(f"An error occurred: {str(e)}") diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py new file mode 100644 index 00000000000000..762e158459cc2a --- /dev/null +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -0,0 +1,69 @@ +from novita_client import ( + Txt2ImgV3Embedding, + Txt2ImgV3HiresFix, + Txt2ImgV3LoRA, + Txt2ImgV3Refiner, + V3TaskImage, +) + + +class NovitaAiToolBase: + def _extract_loras(self, loras_str: str): + if not loras_str: + return [] + + loras_ori_list = loras_str.strip().split(";") + result_list = [] + for lora_str in loras_ori_list: + lora_info = lora_str.strip().split(",") + lora = Txt2ImgV3LoRA( + model_name=lora_info[0].strip(), + strength=float(lora_info[1]), + ) + result_list.append(lora) + + return result_list + + def _extract_embeddings(self, embeddings_str: str): + if not embeddings_str: + return [] + + embeddings_ori_list = embeddings_str.strip().split(";") + result_list = [] + for embedding_str in embeddings_ori_list: + embedding = Txt2ImgV3Embedding(model_name=embedding_str.strip()) + result_list.append(embedding) + + return result_list + + def _extract_hires_fix(self, hires_fix_str: str): + hires_fix_info = hires_fix_str.strip().split(",") + if "upscaler" in hires_fix_info: + hires_fix = Txt2ImgV3HiresFix( + target_width=int(hires_fix_info[0]), + target_height=int(hires_fix_info[1]), + strength=float(hires_fix_info[2]), + upscaler=hires_fix_info[3].strip(), + ) + else: + hires_fix = Txt2ImgV3HiresFix( + target_width=int(hires_fix_info[0]), + target_height=int(hires_fix_info[1]), + strength=float(hires_fix_info[2]), + ) + + return hires_fix + + def _extract_refiner(self, switch_at: str): + refiner = Txt2ImgV3Refiner(switch_at=float(switch_at)) + return refiner + + def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: + """ + is hit nsfw + """ + if image.nsfw_detection_result is None: + return False + if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold: + return True + return False diff --git a/api/core/tools/provider/builtin/novitaai/novitaai.py b/api/core/tools/provider/builtin/novitaai/novitaai.py index 1e7d9757c3d81e..d5e32eff29373a 100644 --- a/api/core/tools/provider/builtin/novitaai/novitaai.py +++ b/api/core/tools/provider/builtin/novitaai/novitaai.py @@ -8,23 +8,27 @@ class NovitaAIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - result = NovitaAiTxt2ImgTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).invoke( - user_id='', - tool_parameters={ - 'model_name': 'cinenautXLATRUE_cinenautV10_392434.safetensors', - 'prompt': 'a futuristic city with flying cars', - 'negative_prompt': '', - 'width': 128, - 'height': 128, - 'image_num': 1, - 'guidance_scale': 7.5, - 'seed': -1, - 'steps': 1, - }, + result = ( + NovitaAiTxt2ImgTool() + .fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ) + .invoke( + user_id="", + tool_parameters={ + "model_name": "cinenautXLATRUE_cinenautV10_392434.safetensors", + "prompt": "a
futuristic city with flying cars", + "negative_prompt": "", + "width": 128, + "height": 128, + "image_num": 1, + "guidance_scale": 7.5, + "seed": -1, + "steps": 1, + }, + ) ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py index e63c8919575620..0b4f2edff3607f 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_createtile.py @@ -12,17 +12,18 @@ class NovitaAiCreateTileTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -30,21 +31,23 @@ def _invoke(self, results = [] results.append( - self.create_blob_message(blob=b64decode(client_result.image_file), - meta={'mime_type': f'image/{client_result.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(client_result.image_file), + meta={"mime_type": f"image/{client_result.image_type}"}, + save_as=self.VariableKey.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index ec2927675e14b0..a200ee81231f00 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -12,127 +12,137 @@ class NovitaAiModelQueryTool(BuiltinTool): - _model_query_endpoint = 'https://api.novita.ai/v3/model' + _model_query_endpoint = "https://api.novita.ai/v3/model" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') - headers = { - 'Content-Type': 'application/json', - 'Authorization': "Bearer " + api_key - } + api_key = self.runtime.credentials.get("api_key") + 
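# Bearer-token auth for the Novita v3 model endpoint; the headers built on the next line look like + # {"Content-Type": "application/json", "Authorization": "Bearer <api_key>"} (placeholder shown for illustration) +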
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + api_key} params = self._process_parameters(tool_parameters) - result_type = params.get('result_type') - del params['result_type'] + result_type = params.get("result_type") + del params["result_type"] models_data = self._query_models( models_data=[], headers=headers, params=params, - recursive=False if result_type == 'first sd_name' or result_type == 'first name sd_name pair' else True + recursive=result_type not in {"first sd_name", "first name sd_name pair"}, ) - result_str = '' - if result_type == 'first sd_name': - result_str = models_data[0]['sd_name_in_api'] if len(models_data) > 0 else '' - elif result_type == 'first name sd_name pair': - result_str = json.dumps({'name': models_data[0]['name'], 'sd_name': models_data[0]['sd_name_in_api']}) if len(models_data) > 0 else '' - elif result_type == 'sd_name array': - sd_name_array = [model['sd_name_in_api'] for model in models_data] if len(models_data) > 0 else [] + result_str = "" + if result_type == "first sd_name": + result_str = models_data[0]["sd_name_in_api"] if len(models_data) > 0 else "" + elif result_type == "first name sd_name pair": + result_str = ( + json.dumps({"name": models_data[0]["name"], "sd_name": models_data[0]["sd_name_in_api"]}) + if len(models_data) > 0 + else "" + ) + elif result_type == "sd_name array": + sd_name_array = [model["sd_name_in_api"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(sd_name_array) - elif result_type == 'name array': - name_array = [model['name'] for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name array": + name_array = [model["name"] for model in models_data] if len(models_data) > 0 else [] result_str = json.dumps(name_array) - elif result_type == 'name sd_name pair array': - name_sd_name_pair_array = [{'name': model['name'], 'sd_name': model['sd_name_in_api']} - for model in models_data] if len(models_data) > 0 else [] + elif result_type == "name sd_name pair array": + name_sd_name_pair_array = ( + [{"name": model["name"], "sd_name": model["sd_name_in_api"]} for model in models_data] + if len(models_data) > 0 + else [] + ) result_str = json.dumps(name_sd_name_pair_array) - elif result_type == 'whole info array': + elif result_type == "whole info array": result_str = json.dumps(models_data) else: raise NotImplementedError return self.create_text_message(result_str) - def _query_models(self, models_data: list, headers: dict[str, Any], - params: dict[str, Any], pagination_cursor: str = '', recursive: bool = True) -> list: + def _query_models( + self, + models_data: list, + headers: dict[str, Any], + params: dict[str, Any], + pagination_cursor: str = "", + recursive: bool = True, + ) -> list: """ - query models + query models """ inside_params = deepcopy(params) - if pagination_cursor != '': - inside_params['pagination.cursor'] = pagination_cursor + if pagination_cursor != "": + inside_params["pagination.cursor"] = pagination_cursor response = ssrf_proxy.get( - url=str(URL(self._model_query_endpoint)), - headers=headers, - params=params, - timeout=(10, 60) + url=str(URL(self._model_query_endpoint)), headers=headers, params=params, timeout=(10, 60) ) res_data = response.json() - models_data.extend(res_data['models']) + models_data.extend(res_data["models"]) - res_data_len = len(res_data['models']) - if res_data_len == 0 or res_data_len < int(params['pagination.limit']) or recursive is False: + res_data_len = len(res_data["models"]) + if res_data_len 
== 0 or res_data_len < int(params["pagination.limit"]) or recursive is False: # deduplicate df = DataFrame.from_dict(models_data) - df_unique = df.drop_duplicates(subset=['id']) - models_data = df_unique.to_dict('records') + df_unique = df.drop_duplicates(subset=["id"]) + models_data = df_unique.to_dict("records") return models_data return self._query_models( models_data=models_data, headers=headers, params=inside_params, - pagination_cursor=res_data['pagination']['next_cursor'] + pagination_cursor=res_data["pagination"]["next_cursor"], ) def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ process_parameters = deepcopy(parameters) res_parameters = {} # delete none or empty - keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in process_parameters.items() if v is None or v == ""] for k in keys_to_delete: del process_parameters[k] - if 'query' in process_parameters and process_parameters.get('query') != 'unspecified': - res_parameters['filter.query'] = process_parameters['query'] + if "query" in process_parameters and process_parameters.get("query") != "unspecified": + res_parameters["filter.query"] = process_parameters["query"] - if 'visibility' in process_parameters and process_parameters.get('visibility') != 'unspecified': - res_parameters['filter.visibility'] = process_parameters['visibility'] + if "visibility" in process_parameters and process_parameters.get("visibility") != "unspecified": + res_parameters["filter.visibility"] = process_parameters["visibility"] - if 'source' in process_parameters and process_parameters.get('source') != 'unspecified': - res_parameters['filter.source'] = process_parameters['source'] + if "source" in process_parameters and process_parameters.get("source") != "unspecified": + res_parameters["filter.source"] = process_parameters["source"] - if 'type' in process_parameters and process_parameters.get('type') != 'unspecified': - res_parameters['filter.types'] = process_parameters['type'] + if "type" in process_parameters and process_parameters.get("type") != "unspecified": + res_parameters["filter.types"] = process_parameters["type"] - if 'is_sdxl' in process_parameters: - if process_parameters['is_sdxl'] == 'true': - res_parameters['filter.is_sdxl'] = True - elif process_parameters['is_sdxl'] == 'false': - res_parameters['filter.is_sdxl'] = False + if "is_sdxl" in process_parameters: + if process_parameters["is_sdxl"] == "true": + res_parameters["filter.is_sdxl"] = True + elif process_parameters["is_sdxl"] == "false": + res_parameters["filter.is_sdxl"] = False - res_parameters['result_type'] = process_parameters.get('result_type', 'first sd_name') + res_parameters["result_type"] = process_parameters.get("result_type", "first sd_name") - res_parameters['pagination.limit'] = 1 \ - if res_parameters.get('result_type') == 'first sd_name' \ - or res_parameters.get('result_type') == 'first name sd_name pair'\ + res_parameters["pagination.limit"] = ( + 1 + if res_parameters.get("result_type") == "first sd_name" + or res_parameters.get("result_type") == "first name sd_name pair" else 100 + ) return res_parameters diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py index c9524d6a66f4d8..9c61eab9f95784 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py +++ 
b/api/core/tools/provider/builtin/novitaai/tools/novitaai_txt2img.py @@ -4,30 +4,27 @@ from novita_client import ( NovitaClient, - Txt2ImgV3Embedding, - Txt2ImgV3HiresFix, - Txt2ImgV3LoRA, - Txt2ImgV3Refiner, - V3TaskImage, ) from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase from core.tools.tool.builtin_tool import BuiltinTool -class NovitaAiTxt2ImgTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: +class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - if 'api_key' not in self.runtime.credentials or not self.runtime.credentials.get('api_key'): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): raise ToolProviderCredentialValidationError("Novita AI API Key is required.") - api_key = self.runtime.credentials.get('api_key') + api_key = self.runtime.credentials.get("api_key") client = NovitaClient(api_key=api_key) param = self._process_parameters(tool_parameters) @@ -36,102 +33,58 @@ def _invoke(self, results = [] for image_encoded, image in zip(client_result.images_encoded, client_result.images): if self._is_hit_nsfw_detection(image, 0.8): - results = self.create_text_message(text='NSFW detected!') + results = self.create_text_message(text="NSFW detected!") break results.append( - self.create_blob_message(blob=b64decode(image_encoded), - meta={'mime_type': f'image/{image.image_type}'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + self.create_blob_message( + blob=b64decode(image_encoded), + meta={"mime_type": f"image/{image.image_type}"}, + save_as=self.VariableKey.IMAGE.value, + ) ) return results def _process_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ - process parameters + process parameters """ res_parameters = deepcopy(parameters) # delete none and empty - keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ''] + keys_to_delete = [k for k, v in res_parameters.items() if v is None or v == ""] for k in keys_to_delete: del res_parameters[k] - if 'clip_skip' in res_parameters and res_parameters.get('clip_skip') == 0: - del res_parameters['clip_skip'] + if "clip_skip" in res_parameters and res_parameters.get("clip_skip") == 0: + del res_parameters["clip_skip"] - if 'refiner_switch_at' in res_parameters and res_parameters.get('refiner_switch_at') == 0: - del res_parameters['refiner_switch_at'] + if "refiner_switch_at" in res_parameters and res_parameters.get("refiner_switch_at") == 0: + del res_parameters["refiner_switch_at"] - if 'enabled_enterprise_plan' in res_parameters: - res_parameters['enterprise_plan'] = {'enabled': res_parameters['enabled_enterprise_plan']} - del res_parameters['enabled_enterprise_plan'] + if "enabled_enterprise_plan" in res_parameters: + res_parameters["enterprise_plan"] = {"enabled": res_parameters["enabled_enterprise_plan"]} + del res_parameters["enabled_enterprise_plan"] - if 'nsfw_detection_level' in res_parameters: - res_parameters['nsfw_detection_level'] = int(res_parameters['nsfw_detection_level']) + if "nsfw_detection_level" in res_parameters: + res_parameters["nsfw_detection_level"] = int(res_parameters["nsfw_detection_level"]) # process 
loras - if 'loras' in res_parameters: - loras_ori_list = res_parameters.get('loras').strip().split(';') - locals_list = [] - for lora_str in loras_ori_list: - lora_info = lora_str.strip().split(',') - lora = Txt2ImgV3LoRA( - model_name=lora_info[0].strip(), - strength=float(lora_info[1]), - ) - locals_list.append(lora) - - res_parameters['loras'] = locals_list + if "loras" in res_parameters: + res_parameters["loras"] = self._extract_loras(res_parameters.get("loras")) # process embeddings - if 'embeddings' in res_parameters: - embeddings_ori_list = res_parameters.get('embeddings').strip().split(';') - locals_list = [] - for embedding_str in embeddings_ori_list: - embedding = Txt2ImgV3Embedding( - model_name=embedding_str.strip() - ) - locals_list.append(embedding) - - res_parameters['embeddings'] = locals_list + if "embeddings" in res_parameters: + res_parameters["embeddings"] = self._extract_embeddings(res_parameters.get("embeddings")) # process hires_fix - if 'hires_fix' in res_parameters: - hires_fix_ori = res_parameters.get('hires_fix') - hires_fix_info = hires_fix_ori.strip().split(',') - if 'upscaler' in hires_fix_info: - hires_fix = Txt2ImgV3HiresFix( - target_width=int(hires_fix_info[0]), - target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]), - upscaler=hires_fix_info[3].strip() - ) - else: - hires_fix = Txt2ImgV3HiresFix( - target_width=int(hires_fix_info[0]), - target_height=int(hires_fix_info[1]), - strength=float(hires_fix_info[2]) - ) - - res_parameters['hires_fix'] = hires_fix + if "hires_fix" in res_parameters: + res_parameters["hires_fix"] = self._extract_hires_fix(res_parameters.get("hires_fix")) - if 'refiner_switch_at' in res_parameters: - refiner = Txt2ImgV3Refiner( - switch_at=float(res_parameters.get('refiner_switch_at')) - ) - del res_parameters['refiner_switch_at'] - res_parameters['refiner'] = refiner + # process refiner + if "refiner_switch_at" in res_parameters: + res_parameters["refiner"] = self._extract_refiner(res_parameters.get("refiner_switch_at")) + del res_parameters["refiner_switch_at"] return res_parameters - - def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool: - """ - is hit nsfw - """ - if image.nsfw_detection_result is None: - return False - if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold: - return True - return False diff --git a/api/core/tools/provider/builtin/onebot/_assets/icon.ico b/api/core/tools/provider/builtin/onebot/_assets/icon.ico new file mode 100644 index 00000000000000..1b07e965b9910b Binary files /dev/null and b/api/core/tools/provider/builtin/onebot/_assets/icon.ico differ diff --git a/api/core/tools/provider/builtin/onebot/onebot.py b/api/core/tools/provider/builtin/onebot/onebot.py new file mode 100644 index 00000000000000..b8e5ed24d6b43f --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/onebot.py @@ -0,0 +1,10 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class OneBotProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + if not credentials.get("ob11_http_url"): + raise ToolProviderCredentialValidationError("OneBot HTTP URL is required.") diff --git a/api/core/tools/provider/builtin/onebot/onebot.yaml b/api/core/tools/provider/builtin/onebot/onebot.yaml new file mode 100644 index 00000000000000..1922adc4de4d56 
--- /dev/null +++ b/api/core/tools/provider/builtin/onebot/onebot.yaml @@ -0,0 +1,35 @@ +identity: + author: RockChinQ + name: onebot + label: + en_US: OneBot v11 Protocol + zh_Hans: OneBot v11 协议 + description: + en_US: Unofficial OneBot v11 Protocol Tool + zh_Hans: 非官方 OneBot v11 协议工具 + icon: icon.ico +credentials_for_provider: + ob11_http_url: + type: text-input + required: true + label: + en_US: HTTP URL + zh_Hans: HTTP URL + description: + en_US: Forward HTTP URL of OneBot v11 + zh_Hans: OneBot v11 正向 HTTP URL + help: + en_US: Fill this with the HTTP URL of your OneBot server + zh_Hans: 请在你的 OneBot 协议端开启 正向 HTTP 并填写其 URL + access_token: + type: secret-input + required: false + label: + en_US: Access Token + zh_Hans: 访问令牌 + description: + en_US: Access Token for OneBot v11 Protocol + zh_Hans: OneBot 协议访问令牌 + help: + en_US: Fill this if you set an access token in your OneBot server + zh_Hans: 如果你在 OneBot 服务器中设置了 access token,请填写此项 diff --git a/api/core/tools/provider/builtin/onebot/tools/__init__.py b/api/core/tools/provider/builtin/onebot/tools/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py new file mode 100644 index 00000000000000..9c95bbc2ae8d2d --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.py @@ -0,0 +1,39 @@ +from typing import Any, Union + +import requests +from yarl import URL + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SendGroupMsg(BuiltinTool): + """OneBot v11 Tool: Send Group Message""" + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + # Get parameters + send_group_id = tool_parameters.get("group_id", "") + + message = tool_parameters.get("message", "") + if not message: + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) + + try: + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_group_msg" + + resp = requests.post( + url, + json={"group_id": send_group_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials.get("access_token", "")}, + ) + + if resp.status_code != 200: + return self.create_json_message({"error": f"Failed to send group message: {resp.text}"}) + + return self.create_json_message({"response": resp.json()}) + except Exception as e: + return self.create_json_message({"error": f"Failed to send group message: {e}"}) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml new file mode 100644 index 00000000000000..64beaa85457a3a --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_group_msg.yaml @@ -0,0 +1,46 @@ +identity: + name: send_group_msg + author: RockChinQ + label: + en_US: Send Group Message + zh_Hans: 发送群消息 +description: + human: + en_US: Send a message to a group + zh_Hans: 发送消息到群聊 + llm: A tool for sending a message segment to a group +parameters: + - name: group_id + type: number + required: true + label: + en_US: Target Group ID + zh_Hans: 目标群 ID + human_description: + en_US: The group ID of the target group + zh_Hans: 目标群的群 ID + llm_description: The group ID of the target group + form: llm + - name: message + type: string + required: true + label: +
en_US: Message + zh_Hans: 消息 + human_description: + en_US: The message to send + zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true) + llm_description: The message to send + form: llm + - name: auto_escape + type: boolean + required: false + default: false + label: + en_US: Auto Escape + zh_Hans: 自动转义 + human_description: + en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes. + zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。 + llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. + form: form diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py new file mode 100644 index 00000000000000..1174c7f07d002f --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.py @@ -0,0 +1,39 @@ +from typing import Any, Union + +import requests +from yarl import URL + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SendPrivateMsg(BuiltinTool): + """OneBot v11 Tool: Send Private Message""" + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + # Get parameters + send_user_id = tool_parameters.get("user_id", "") + + message = tool_parameters.get("message", "") + if not message: + return self.create_json_message({"error": "Message is empty."}) + + auto_escape = tool_parameters.get("auto_escape", False) + + try: + url = URL(self.runtime.credentials["ob11_http_url"]) / "send_private_msg" + + resp = requests.post( + url, + json={"user_id": send_user_id, "message": message, "auto_escape": auto_escape}, + headers={"Authorization": "Bearer " + self.runtime.credentials.get("access_token", "")}, + ) + + if resp.status_code != 200: + return self.create_json_message({"error": f"Failed to send private message: {resp.text}"}) + + return self.create_json_message({"response": resp.json()}) + except Exception as e: + return self.create_json_message({"error": f"Failed to send private message: {e}"}) diff --git a/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml new file mode 100644 index 00000000000000..8200ce4a83f4e2 --- /dev/null +++ b/api/core/tools/provider/builtin/onebot/tools/send_private_msg.yaml @@ -0,0 +1,46 @@ +identity: + name: send_private_msg + author: RockChinQ + label: + en_US: Send Private Message + zh_Hans: 发送私聊消息 +description: + human: + en_US: Send a private message to a user + zh_Hans: 发送私聊消息给用户 + llm: A tool for sending a message segment to a user in private chat +parameters: + - name: user_id + type: number + required: true + label: + en_US: Target User ID + zh_Hans: 目标用户 ID + human_description: + en_US: The user ID of the target user + zh_Hans: 目标用户的用户 ID + llm_description: The user ID of the target user + form: llm + - name: message + type: string + required: true + label: + en_US: Message + zh_Hans: 消息 + human_description: + en_US: The message to send + zh_Hans: 要发送的消息。支持 CQ码(需要同时设置 auto_escape 为 true) + llm_description: The message to send + form: llm + - name: auto_escape + type: boolean +
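# Illustrative CQ-code segment: "[CQ:at,qq=10001]hello" combines an @-mention of user 10001 with plain text in one message +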
required: false + default: false + label: + en_US: Auto Escape + zh_Hans: 自动转义 + human_description: + en_US: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. Since Dify currently does not support passing Object-format message chains, developers can send complex message components through CQ codes. + zh_Hans: 若为 true 则会把 message 视为 CQ 码解析,否则视为 纯文本 直接发送。由于 Dify 目前不支持传入 Object格式 的消息,故开发者可以通过 CQ 码来发送复杂消息组件。 + llm_description: If true, the message will be treated as a CQ code for parsing, otherwise it will be treated as plain text for direct sending. + form: form diff --git a/api/core/tools/provider/builtin/openweather/openweather.py b/api/core/tools/provider/builtin/openweather/openweather.py index a2827177a370c9..9e40249aba6b40 100644 --- a/api/core/tools/provider/builtin/openweather/openweather.py +++ b/api/core/tools/provider/builtin/openweather/openweather.py @@ -5,7 +5,6 @@ def query_weather(city="Beijing", units="metric", language="zh_cn", api_key=None): - url = "https://api.openweathermap.org/data/2.5/weather" params = {"q": city, "appid": api_key, "units": units, "lang": language} @@ -16,21 +15,15 @@ class OpenweatherProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: if "api_key" not in credentials or not credentials.get("api_key"): - raise ToolProviderCredentialValidationError( - "Open weather API key is required." - ) + raise ToolProviderCredentialValidationError("Open weather API key is required.") apikey = credentials.get("api_key") try: response = query_weather(api_key=apikey) if response.status_code == 200: pass else: - raise ToolProviderCredentialValidationError( - (response.json()).get("info") - ) + raise ToolProviderCredentialValidationError((response.json()).get("info")) except Exception as e: - raise ToolProviderCredentialValidationError( - "Open weather API Key is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("Open weather API Key is invalid. 
{}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/openweather/tools/weather.py b/api/core/tools/provider/builtin/openweather/tools/weather.py index 536a3511f463d2..ed4ec487fa984a 100644 --- a/api/core/tools/provider/builtin/openweather/tools/weather.py +++ b/api/core/tools/provider/builtin/openweather/tools/weather.py @@ -17,10 +17,7 @@ def _invoke( city = tool_parameters.get("city", "") if not city: return self.create_text_message("Please tell me your city") - if ( - "api_key" not in self.runtime.credentials - or not self.runtime.credentials.get("api_key") - ): + if "api_key" not in self.runtime.credentials or not self.runtime.credentials.get("api_key"): return self.create_text_message("OpenWeather API key is required.") units = tool_parameters.get("units", "metric") @@ -29,7 +26,7 @@ def _invoke( # request URL url = "https://api.openweathermap.org/data/2.5/weather" - # request parmas + # request params params = { "q": city, "appid": self.runtime.credentials.get("api_key"), @@ -39,12 +36,9 @@ def _invoke( response = requests.get(url, params=params) if response.status_code == 200: - data = response.json() return self.create_text_message( - self.summary( - user_id=user_id, content=json.dumps(data, ensure_ascii=False) - ) + self.summary(user_id=user_id, content=json.dumps(data, ensure_ascii=False)) ) else: error_message = { @@ -55,6 +49,4 @@ def _invoke( return json.dumps(error_message) except Exception as e: - return self.create_text_message( - "Openweather API Key is invalid. {}".format(e) - ) + return self.create_text_message("Openweather API Key is invalid. {}".format(e)) diff --git a/api/core/tools/provider/builtin/perplexity/_assets/icon.svg b/api/core/tools/provider/builtin/perplexity/_assets/icon.svg new file mode 100644 index 00000000000000..c2974c142fc622 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/_assets/icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.py b/api/core/tools/provider/builtin/perplexity/perplexity.py new file mode 100644 index 00000000000000..80518853fb4a4b --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/perplexity.py @@ -0,0 +1,38 @@ +from typing import Any + +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.perplexity.tools.perplexity_search import PERPLEXITY_API_URL +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class PerplexityProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + headers = { + "Authorization": f"Bearer {credentials.get('perplexity_api_key')}", + "Content-Type": "application/json", + } + + payload = { + "model": "llama-3.1-sonar-small-128k-online", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ], + "max_tokens": 5, + "temperature": 0.1, + "top_p": 0.9, + "stream": False, + } + + try: + response = requests.post(PERPLEXITY_API_URL, json=payload, headers=headers) + response.raise_for_status() + except requests.RequestException as e: + raise ToolProviderCredentialValidationError(f"Failed to validate Perplexity API key: {str(e)}") + + if response.status_code != 200: + raise ToolProviderCredentialValidationError( + f"Perplexity API key is invalid. 
Status code: {response.status_code}" + ) diff --git a/api/core/tools/provider/builtin/perplexity/perplexity.yaml b/api/core/tools/provider/builtin/perplexity/perplexity.yaml new file mode 100644 index 00000000000000..c0b504f300c45a --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/perplexity.yaml @@ -0,0 +1,26 @@ +identity: + author: Dify + name: perplexity + label: + en_US: Perplexity + zh_Hans: Perplexity + description: + en_US: Perplexity.AI + zh_Hans: Perplexity.AI + icon: icon.svg + tags: + - search +credentials_for_provider: + perplexity_api_key: + type: secret-input + required: true + label: + en_US: Perplexity API key + zh_Hans: Perplexity API key + placeholder: + en_US: Please input your Perplexity API key + zh_Hans: 请输入你的 Perplexity API key + help: + en_US: Get your Perplexity API key from Perplexity + zh_Hans: 从 Perplexity 获取您的 Perplexity API key + url: https://www.perplexity.ai/settings/api diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py new file mode 100644 index 00000000000000..5ed4b9ca993483 --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.py @@ -0,0 +1,67 @@ +import json +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +PERPLEXITY_API_URL = "https://api.perplexity.ai/chat/completions" + + +class PerplexityAITool(BuiltinTool): + def _parse_response(self, response: dict) -> dict: + """Parse the response from Perplexity AI API""" + if "choices" in response and len(response["choices"]) > 0: + message = response["choices"][0]["message"] + return { + "content": message.get("content", ""), + "role": message.get("role", ""), + "citations": response.get("citations", []), + } + else: + return {"content": "Unable to get a valid response", "role": "assistant", "citations": []} + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "Authorization": f"Bearer {self.runtime.credentials['perplexity_api_key']}", + "Content-Type": "application/json", + } + + payload = { + "model": tool_parameters.get("model", "llama-3.1-sonar-small-128k-online"), + "messages": [ + {"role": "system", "content": "Be precise and concise."}, + {"role": "user", "content": tool_parameters["query"]}, + ], + "max_tokens": tool_parameters.get("max_tokens", 4096), + "temperature": tool_parameters.get("temperature", 0.7), + "top_p": tool_parameters.get("top_p", 1), + "top_k": tool_parameters.get("top_k", 5), + "presence_penalty": tool_parameters.get("presence_penalty", 0), + "frequency_penalty": tool_parameters.get("frequency_penalty", 1), + "stream": False, + } + + if "search_recency_filter" in tool_parameters: + payload["search_recency_filter"] = tool_parameters["search_recency_filter"] + if "return_citations" in tool_parameters: + payload["return_citations"] = tool_parameters["return_citations"] + if "search_domain_filter" in tool_parameters: + if isinstance(tool_parameters["search_domain_filter"], str): + payload["search_domain_filter"] = [tool_parameters["search_domain_filter"]] + elif isinstance(tool_parameters["search_domain_filter"], list): + payload["search_domain_filter"] = tool_parameters["search_domain_filter"] + + response = requests.post(url=PERPLEXITY_API_URL, json=payload, headers=headers) + response.raise_for_status() + valuable_res 
= self._parse_response(response.json()) + + return [ + self.create_json_message(valuable_res), + self.create_text_message(json.dumps(valuable_res, ensure_ascii=False, indent=2)), + ] diff --git a/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml new file mode 100644 index 00000000000000..02a645df335aaf --- /dev/null +++ b/api/core/tools/provider/builtin/perplexity/tools/perplexity_search.yaml @@ -0,0 +1,178 @@ +identity: + name: perplexity + author: Dify + label: + en_US: Perplexity Search +description: + human: + en_US: Search information using Perplexity AI's language models. + llm: This tool is used to search information using Perplexity AI's language models. +parameters: + - name: query + type: string + required: true + label: + en_US: Query + zh_Hans: 查询 + human_description: + en_US: The text query to be processed by the AI model. + zh_Hans: 要由 AI 模型处理的文本查询。 + form: llm + - name: model + type: select + required: false + label: + en_US: Model Name + zh_Hans: 模型名称 + human_description: + en_US: The Perplexity AI model to use for generating the response. + zh_Hans: 用于生成响应的 Perplexity AI 模型。 + form: form + default: "llama-3.1-sonar-small-128k-online" + options: + - value: llama-3.1-sonar-small-128k-online + label: + en_US: llama-3.1-sonar-small-128k-online + zh_Hans: llama-3.1-sonar-small-128k-online + - value: llama-3.1-sonar-large-128k-online + label: + en_US: llama-3.1-sonar-large-128k-online + zh_Hans: llama-3.1-sonar-large-128k-online + - value: llama-3.1-sonar-huge-128k-online + label: + en_US: llama-3.1-sonar-huge-128k-online + zh_Hans: llama-3.1-sonar-huge-128k-online + - name: max_tokens + type: number + required: false + label: + en_US: Max Tokens + zh_Hans: 最大令牌数 + pt_BR: Máximo de Tokens + human_description: + en_US: The maximum number of tokens to generate in the response. + zh_Hans: 在响应中生成的最大令牌数。 + pt_BR: O número máximo de tokens a serem gerados na resposta. + form: form + default: 4096 + min: 1 + max: 4096 + - name: temperature + type: number + required: false + label: + en_US: Temperature + zh_Hans: 温度 + pt_BR: Temperatura + human_description: + en_US: Controls randomness in the output. Lower values make the output more focused and deterministic. + zh_Hans: 控制输出的随机性。较低的值使输出更加集中和确定。 + form: form + default: 0.7 + min: 0 + max: 1 + - name: top_k + type: number + required: false + label: + en_US: Top K + zh_Hans: 取样数量 + human_description: + en_US: The number of top results to consider for response generation. + zh_Hans: 用于生成响应的顶部结果数量。 + form: form + default: 5 + min: 1 + max: 100 + - name: top_p + type: number + required: false + label: + en_US: Top P + zh_Hans: Top P + human_description: + en_US: Controls diversity via nucleus sampling. + zh_Hans: 通过核心采样控制多样性。 + form: form + default: 1 + min: 0.1 + max: 1 + step: 0.1 + - name: presence_penalty + type: number + required: false + label: + en_US: Presence Penalty + zh_Hans: 存在惩罚 + human_description: + en_US: Positive values penalize new tokens based on whether they appear in the text so far. + zh_Hans: 正值会根据新词元是否已经出现在文本中来对其进行惩罚。 + form: form + default: 0 + min: -1.0 + max: 1.0 + step: 0.1 + - name: frequency_penalty + type: number + required: false + label: + en_US: Frequency Penalty + zh_Hans: 频率惩罚 + human_description: + en_US: Positive values penalize new tokens based on their existing frequency in the text so far. 
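+ # Rough intuition (assuming OpenAI-style sampling; not confirmed for Perplexity): logit(token) -= frequency_penalty * count(token in output so far)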
+ zh_Hans: 正值会根据新词元在文本中已经出现的频率来对其进行惩罚。 + form: form + default: 1 + min: 0.1 + max: 1.0 + step: 0.1 + - name: return_citations + type: boolean + required: false + label: + en_US: Return Citations + zh_Hans: 返回引用 + human_description: + en_US: Whether to return citations in the response. + zh_Hans: 是否在响应中返回引用。 + form: form + default: true + - name: search_domain_filter + type: string + required: false + label: + en_US: Search Domain Filter + zh_Hans: 搜索域过滤器 + human_description: + en_US: Domain to filter the search results. + zh_Hans: 用于过滤搜索结果的域名。 + form: form + default: "" + - name: search_recency_filter + type: select + required: false + label: + en_US: Search Recency Filter + zh_Hans: 搜索时间过滤器 + human_description: + en_US: Filter for search results based on recency. + zh_Hans: 基于时间筛选搜索结果。 + form: form + default: "month" + options: + - value: day + label: + en_US: Day + zh_Hans: 天 + - value: week + label: + en_US: Week + zh_Hans: 周 + - value: month + label: + en_US: Month + zh_Hans: 月 + - value: year + label: + en_US: Year + zh_Hans: 年 diff --git a/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg b/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg new file mode 100644 index 00000000000000..01743c9cd31120 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/_assets/icon.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py new file mode 100644 index 00000000000000..0b9c025834d6f9 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.py @@ -0,0 +1,33 @@ +from typing import Any + +import openai + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class PodcastGeneratorProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + tts_service = credentials.get("tts_service") + api_key = credentials.get("api_key") + + if not tts_service: + raise ToolProviderCredentialValidationError("TTS service is not specified") + + if not api_key: + raise ToolProviderCredentialValidationError("API key is missing") + + if tts_service == "openai": + self._validate_openai_credentials(api_key) + else: + raise ToolProviderCredentialValidationError(f"Unsupported TTS service: {tts_service}") + + def _validate_openai_credentials(self, api_key: str) -> None: + client = openai.OpenAI(api_key=api_key) + try: + # We're using a simple API call to validate the credentials + client.models.list() + except openai.AuthenticationError: + raise ToolProviderCredentialValidationError("Invalid OpenAI API key") + except Exception as e: + raise ToolProviderCredentialValidationError(f"Error validating OpenAI API key: {str(e)}") diff --git a/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml new file mode 100644 index 00000000000000..bd02b32020a85e --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/podcast_generator.yaml @@ -0,0 +1,34 @@ +identity: + author: Dify + name: podcast_generator + label: + en_US: Podcast Generator + zh_Hans: 播客生成器 + description: + en_US: Generate podcast audio using Text-to-Speech services + zh_Hans: 使用文字转语音服务生成播客音频 + icon: icon.svg +credentials_for_provider: + tts_service: + type: 
select + required: true + label: + en_US: TTS Service + zh_Hans: TTS 服务 + placeholder: + en_US: Select a TTS service + zh_Hans: 选择一个 TTS 服务 + options: + - label: + en_US: OpenAI TTS + zh_Hans: OpenAI TTS + value: openai + api_key: + type: secret-input + required: true + label: + en_US: API Key + zh_Hans: API 密钥 + placeholder: + en_US: Enter your TTS service API key + zh_Hans: 输入您的 TTS 服务 API 密钥 diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py new file mode 100644 index 00000000000000..476e2d01e1d107 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py @@ -0,0 +1,104 @@ +import concurrent.futures +import io +import random +import warnings +from typing import Any, Literal, Optional, Union + +import openai + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from pydub import AudioSegment + + +class PodcastAudioGeneratorTool(BuiltinTool): + @staticmethod + def _generate_silence(duration: float): + # Generate silent WAV data using pydub + silence = AudioSegment.silent(duration=int(duration * 1000)) # pydub uses milliseconds + return silence + + @staticmethod + def _generate_audio_segment( + client: openai.OpenAI, + line: str, + voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], + index: int, + ) -> tuple[int, Union[AudioSegment, str], Optional[AudioSegment]]: + try: + response = client.audio.speech.create(model="tts-1", voice=voice, input=line.strip(), response_format="wav") + audio = AudioSegment.from_wav(io.BytesIO(response.content)) + silence_duration = random.uniform(0.1, 1.5) + silence = PodcastAudioGeneratorTool._generate_silence(silence_duration) + return index, audio, silence + except Exception as e: + return index, f"Error generating audio: {str(e)}", None + + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + # Extract parameters + script = tool_parameters.get("script", "") + host1_voice = tool_parameters.get("host1_voice") + host2_voice = tool_parameters.get("host2_voice") + + # Split the script into lines + script_lines = [line for line in script.split("\n") if line.strip()] + + # Ensure voices are provided + if not host1_voice or not host2_voice: + raise ToolParameterValidationError("Host voices are required") + + # Get OpenAI API key from credentials + if not self.runtime or not self.runtime.credentials: + raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing") + api_key = self.runtime.credentials.get("api_key") + if not api_key: + raise ToolProviderCredentialValidationError("OpenAI API key is missing") + + # Initialize OpenAI client + client = openai.OpenAI(api_key=api_key) + + # Create a thread pool + max_workers = 5 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for i, line in enumerate(script_lines): + voice = host1_voice if i % 2 == 0 else host2_voice + future = executor.submit(self._generate_audio_segment, client, line, voice, i) + futures.append(future) + + # Collect results + audio_segments: list[Any] = [None] * len(script_lines) + for future in 
concurrent.futures.as_completed(futures): + index, audio, silence = future.result() + if isinstance(audio, str): # Error occurred + return self.create_text_message(audio) + audio_segments[index] = (audio, silence) + + # Combine audio segments in the correct order + combined_audio = AudioSegment.empty() + for i, (audio, silence) in enumerate(audio_segments): + if audio: + combined_audio += audio + if i < len(audio_segments) - 1 and silence: + combined_audio += silence + + # Export the combined audio to a WAV file in memory + buffer = io.BytesIO() + combined_audio.export(buffer, format="wav") + wav_bytes = buffer.getvalue() + + # Create a blob message with the combined audio + return [ + self.create_text_message("Audio generated successfully"), + self.create_blob_message( + blob=wav_bytes, + meta={"mime_type": "audio/x-wav"}, + save_as=self.VariableKey.AUDIO, + ), + ] diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml new file mode 100644 index 00000000000000..d6ae98f59522c5 --- /dev/null +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.yaml @@ -0,0 +1,95 @@ +identity: + name: podcast_audio_generator + author: Dify + label: + en_US: Podcast Audio Generator + zh_Hans: 播客音频生成器 +description: + human: + en_US: Generate a podcast audio file from a script with two alternating voices using OpenAI's TTS service. + zh_Hans: 使用 OpenAI 的 TTS 服务,从包含两个交替声音的脚本生成播客音频文件。 + llm: This tool converts a prepared podcast script into an audio file using OpenAI's Text-to-Speech service, with two specified voices for alternating hosts. +parameters: + - name: script + type: string + required: true + label: + en_US: Podcast Script + zh_Hans: 播客脚本 + human_description: + en_US: A string containing alternating lines for two hosts, separated by newline characters. + zh_Hans: 包含两位主持人交替台词的字符串,每行用换行符分隔。 + llm_description: A string representing the script, with alternating lines for two hosts separated by newline characters. + form: llm + - name: host1_voice + type: select + required: true + label: + en_US: Host 1 Voice + zh_Hans: 主持人1 音色 + human_description: + en_US: The voice for the first host. + zh_Hans: 第一位主持人的音色。 + llm_description: The voice identifier for the first host's voice. + options: + - label: + en_US: Alloy + zh_Hans: Alloy + value: alloy + - label: + en_US: Echo + zh_Hans: Echo + value: echo + - label: + en_US: Fable + zh_Hans: Fable + value: fable + - label: + en_US: Onyx + zh_Hans: Onyx + value: onyx + - label: + en_US: Nova + zh_Hans: Nova + value: nova + - label: + en_US: Shimmer + zh_Hans: Shimmer + value: shimmer + form: form + - name: host2_voice + type: select + required: true + label: + en_US: Host 2 Voice + zh_Hans: 主持人2 音色 + human_description: + en_US: The voice for the second host. + zh_Hans: 第二位主持人的音色。 + llm_description: The voice identifier for the second host's voice. 
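
The `_invoke` method above submits one TTS job per script line to a five-worker pool, yet the finished episode must follow script order. Since `as_completed` yields futures in completion order, each worker carries its index back and results are slotted into a pre-sized list. A minimal sketch of just that pattern (names are illustrative):

```python
import concurrent.futures
from collections.abc import Callable, Sequence
from typing import Any


def run_in_order(
    items: Sequence[Any],
    worker: Callable[[int, Any], tuple[int, Any]],
    max_workers: int = 5,
) -> list[Any]:
    """Run worker(i, item) concurrently; return results in input order."""
    results: list[Any] = [None] * len(items)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(worker, i, item) for i, item in enumerate(items)]
        for future in concurrent.futures.as_completed(futures):
            index, value = future.result()  # completion order is arbitrary...
            results[index] = value          # ...so the index restores script order
    return results
```
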
+ options: + - label: + en_US: Alloy + zh_Hans: Alloy + value: alloy + - label: + en_US: Echo + zh_Hans: Echo + value: echo + - label: + en_US: Fable + zh_Hans: Fable + value: fable + - label: + en_US: Onyx + zh_Hans: Onyx + value: onyx + - label: + en_US: Nova + zh_Hans: Nova + value: nova + - label: + en_US: Shimmer + zh_Hans: Shimmer + value: shimmer + form: form diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py index 05cd171b873327..ea3a477c30178d 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.py +++ b/api/core/tools/provider/builtin/pubmed/pubmed.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "John Doe", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py index 58811d65e6fe9a..3a4f374ea0b0bc 100644 --- a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py +++ b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py @@ -51,17 +51,12 @@ def run(self, query: str) -> str: try: # Retrieve the top-k results for the query docs = [ - f"Published: {result['pub_date']}\nTitle: {result['title']}\n" - f"Summary: {result['summary']}" + f"Published: {result['pub_date']}\nTitle: {result['title']}\nSummary: {result['summary']}" for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH]) ] # Join the results and limit the character count - return ( - "\n\n".join(docs)[:self.doc_content_chars_max] - if docs - else "No good PubMed Result was found" - ) + return "\n\n".join(docs)[: self.doc_content_chars_max] if docs else "No good PubMed Result was found" except Exception as ex: return f"PubMed exception: {ex}" @@ -91,13 +86,7 @@ def load(self, query: str) -> list[dict]: return articles def retrieve_article(self, uid: str, webenv: str) -> dict: - url = ( - self.base_url_efetch - + "db=pubmed&retmode=xml&id=" - + uid - + "&webenv=" - + webenv - ) + url = self.base_url_efetch + "db=pubmed&retmode=xml&id=" + uid + "&webenv=" + webenv retry = 0 while True: @@ -108,10 +97,7 @@ def retrieve_article(self, uid: str, webenv: str) -> dict: if e.code == 429 and retry < self.max_retry: # Too Many Requests error # wait for an exponentially increasing amount of time - print( - f"Too Many Requests, " - f"waiting for {self.sleep_time:.2f} seconds..." 
- ) + print(f"Too Many Requests, waiting for {self.sleep_time:.2f} seconds...") time.sleep(self.sleep_time) self.sleep_time *= 2 retry += 1 @@ -125,27 +111,21 @@ def retrieve_article(self, uid: str, webenv: str) -> dict: if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - title = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + title = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Get abstract abstract = "" if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - abstract = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + abstract = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Get publication date pub_date = "" if "" in xml_text and "" in xml_text: start_tag = "" end_tag = "" - pub_date = xml_text[ - xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag) - ] + pub_date = xml_text[xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)] # Return article as dictionary article = { @@ -182,6 +162,7 @@ def _run( class PubMedInput(BaseModel): query: str = Field(..., description="Search query.") + class PubMedSearchTool(BuiltinTool): """ Tool for performing a search using PubMed search engine. @@ -198,14 +179,13 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe Returns: ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') + return self.create_text_message("Please input query") tool = PubmedQueryRun(args_schema=PubMedInput) result = tool._run(query) return self.create_text_message(self.summary(user_id=user_id, content=result)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/qrcode/qrcode.py b/api/core/tools/provider/builtin/qrcode/qrcode.py index 9fa7d012657fd8..8466b9a26b42b6 100644 --- a/api/core/tools/provider/builtin/qrcode/qrcode.py +++ b/api/core/tools/provider/builtin/qrcode/qrcode.py @@ -8,9 +8,6 @@ class QRCodeProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - QRCodeGeneratorTool().invoke(user_id='', - tool_parameters={ - 'content': 'Dify 123 😊' - }) + QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"}) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index 5eede98f5eed6c..d8ca20bde6ffc9 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -13,43 +13,44 @@ class QRCodeGeneratorTool(BuiltinTool): error_correction_levels: dict[str, int] = { - 'L': ERROR_CORRECT_L, # <=7% - 'M': ERROR_CORRECT_M, # <=15% - 'Q': ERROR_CORRECT_Q, # <=25% - 'H': ERROR_CORRECT_H, # <=30% + "L": ERROR_CORRECT_L, # <=7% + "M": ERROR_CORRECT_M, # <=15% + "Q": ERROR_CORRECT_Q, # <=25% + "H": ERROR_CORRECT_H, # <=30% } - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # 
get text content - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") # get border size - border = tool_parameters.get('border', 0) + border = tool_parameters.get("border", 0) if border < 0 or border > 100: - return self.create_text_message('Invalid parameter border') + return self.create_text_message("Invalid parameter border") # get error_correction - error_correction = tool_parameters.get('error_correction', '') - if error_correction not in self.error_correction_levels.keys(): - return self.create_text_message('Invalid parameter error_correction') + error_correction = tool_parameters.get("error_correction", "") + if error_correction not in self.error_correction_levels: + return self.create_text_message("Invalid parameter error_correction") try: image = self._generate_qrcode(content, border, error_correction) image_bytes = self._image_to_byte_array(image) - return self.create_blob_message(blob=image_bytes, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + return self.create_blob_message( + blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) except Exception: - logging.exception(f'Failed to generate QR code for content: {content}') - return self.create_text_message('Failed to generate QR code') + logging.exception(f"Failed to generate QR code for content: {content}") + return self.create_text_message("Failed to generate QR code") def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage: qr = QRCode( diff --git a/api/core/tools/provider/builtin/regex/regex.py b/api/core/tools/provider/builtin/regex/regex.py index d38ae1b292675f..c498105979f13e 100644 --- a/api/core/tools/provider/builtin/regex/regex.py +++ b/api/core/tools/provider/builtin/regex/regex.py @@ -9,10 +9,10 @@ class RegexProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: RegexExpressionTool().invoke( - user_id='', + user_id="", tool_parameters={ - 'content': '1+(2+3)*4', - 'expression': r'(\d+)', + "content": "1+(2+3)*4", + "expression": r"(\d+)", }, ) except Exception as e: diff --git a/api/core/tools/provider/builtin/regex/tools/regex_extract.py b/api/core/tools/provider/builtin/regex/tools/regex_extract.py index 5d8f013d0d012c..786b4694040030 100644 --- a/api/core/tools/provider/builtin/regex/tools/regex_extract.py +++ b/api/core/tools/provider/builtin/regex/tools/regex_extract.py @@ -6,22 +6,23 @@ class RegexExpressionTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get expression - content = tool_parameters.get('content', '').strip() + content = tool_parameters.get("content", "").strip() if not content: - return self.create_text_message('Invalid content') - expression = tool_parameters.get('expression', '').strip() + return self.create_text_message("Invalid content") + expression = tool_parameters.get("expression", "").strip() if not expression: - return self.create_text_message('Invalid expression') + return self.create_text_message("Invalid expression") try: result = re.findall(expression, content) return 
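
The QR tool above validates the border (0 to 100) and the error-correction level against its `error_correction_levels` mapping, renders a PNG, and returns it as a blob message. A self-contained sketch of the render-and-serialize path with the `qrcode` package (defaults here are illustrative, not the tool's):

```python
import io

import qrcode
from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q

LEVELS = {"L": ERROR_CORRECT_L, "M": ERROR_CORRECT_M, "Q": ERROR_CORRECT_Q, "H": ERROR_CORRECT_H}


def make_qr_png(content: str, border: int = 2, error_correction: str = "M") -> bytes:
    if not 0 <= border <= 100:
        raise ValueError("border must be between 0 and 100")
    if error_correction not in LEVELS:
        raise ValueError("error_correction must be one of L/M/Q/H")
    qr = qrcode.QRCode(error_correction=LEVELS[error_correction], border=border)
    qr.add_data(content)
    qr.make(fit=True)  # choose the smallest QR version that fits the payload
    buf = io.BytesIO()
    qr.make_image().save(buf)  # the PIL backend writes PNG by default
    return buf.getvalue()
```
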
self.create_text_message(str(result)) except Exception as e: - return self.create_text_message(f'Failed to extract result, error: {str(e)}') \ No newline at end of file + return self.create_text_message(f"Failed to extract result, error: {str(e)}") diff --git a/api/core/tools/provider/builtin/searchapi/searchapi.py b/api/core/tools/provider/builtin/searchapi/searchapi.py index 6fa4f05acd7b9d..109bba8b2d8f79 100644 --- a/api/core/tools/provider/builtin/searchapi/searchapi.py +++ b/api/core/tools/provider/builtin/searchapi/searchapi.py @@ -13,11 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearchApi dify", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "SearchApi dify", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index dd780aeadcf36c..17e2978194c6a3 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -7,6 +7,7 @@ SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. @@ -37,42 +38,45 @@ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: return { "engine": "google", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): + if "answer_box" in res and "answer" in res["answer_box"]: toret += res["answer_box"]["answer"] + "\n" - if "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): + if "answer_box" in res and "snippet" in res["answer_box"]: toret += res["answer_box"]["snippet"] + "\n" - if "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): + if "knowledge_graph" in res and "description" in res["knowledge_graph"]: toret += res["knowledge_graph"]["description"] + "\n" - if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + if "organic_results" in res and "snippet" in res["organic_results"][0]: for item in res["organic_results"]: toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" if toret == "": toret = "No good search result found" elif type == "link": - if "answer_box" in res.keys() and "organic_result" in res["answer_box"].keys(): - if "title" in res["answer_box"]["organic_result"].keys(): - toret = f"[{res['answer_box']['organic_result']['title']}]({res['answer_box']['organic_result']['link']})\n" - elif "organic_results" in res.keys() and "link" in res["organic_results"][0].keys(): + if "answer_box" in res and "organic_result" in res["answer_box"]: + if "title" in res["answer_box"]["organic_result"]: + toret = ( + f"[{res['answer_box']['organic_result']['title']}]" + f"({res['answer_box']['organic_result']['link']})\n" + ) + elif "organic_results" in res and "link" in res["organic_results"][0]: toret = "" for item in res["organic_results"]: toret += 
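
The `get_params` method above folds optional kwargs into the query while dropping `None` and empty strings, so unset tool parameters never reach the API; this diff's change from a list to a set literal makes that membership test constant-time. The pattern in isolation (how the API key is transmitted is an assumption, not taken from this diff):

```python
from typing import Any

import requests

SEARCH_API_URL = "https://www.searchapi.io/api/v1/search"


def build_params(query: str, **kwargs: Any) -> dict[str, Any]:
    return {
        "engine": "google",
        "q": query,
        # {None, ""} is a set literal: skip parameters the caller left unset.
        **{k: v for k, v in kwargs.items() if v not in {None, ""}},
    }


def search(api_key: str, query: str, **kwargs: Any) -> dict:
    params = build_params(query, api_key=api_key, **kwargs)  # key-in-params: assumed
    return requests.get(SEARCH_API_URL, params=params).json()
```
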
f"[{item['title']}]({item['link']})\n" - elif "related_questions" in res.keys() and "link" in res["related_questions"][0].keys(): + elif "related_questions" in res and "link" in res["related_questions"][0]: toret = "" for item in res["related_questions"]: toret += f"[{item['title']}]({item['link']})\n" - elif "related_searches" in res.keys() and "link" in res["related_searches"][0].keys(): + elif "related_searches" in res and "link" in res["related_searches"][0]: toret = "" for item in res["related_searches"]: toret += f"[{item['title']}]({item['link']})\n" @@ -80,25 +84,29 @@ def _process_response(res: dict, type: str) -> str: toret = "No good search result found" return toret + class GoogleTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.yaml b/api/core/tools/provider/builtin/searchapi/tools/google.yaml index b69a0e1d3e706b..0dc1b6672436cd 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.yaml +++ b/api/core/tools/provider/builtin/searchapi/tools/google.yaml @@ -65,206 +65,1206 @@ parameters: form: form default: US options: + - value: AF + label: + en_US: Afghanistan + zh_Hans: 阿富汗 + pt_BR: Afeganistão + - value: AL + label: + en_US: Albania + zh_Hans: 阿尔巴尼亚 + pt_BR: Albânia + - value: DZ + label: + en_US: Algeria + zh_Hans: 阿尔及利亚 + pt_BR: Argélia + - value: AS + label: + en_US: American Samoa + zh_Hans: 美属萨摩亚 + pt_BR: Samoa Americana + - value: AD + label: + en_US: Andorra + zh_Hans: 安道尔 + pt_BR: Andorra + - value: AO + label: + en_US: Angola + zh_Hans: 安哥拉 + pt_BR: Angola + - value: AI + label: + en_US: Anguilla + zh_Hans: 安圭拉 + pt_BR: Anguilla + - value: AQ + label: + en_US: Antarctica + zh_Hans: 南极洲 + pt_BR: Antártica + - value: AG + label: + en_US: Antigua and Barbuda + zh_Hans: 安提瓜和巴布达 + pt_BR: Antígua e Barbuda - value: AR label: en_US: Argentina zh_Hans: 阿根廷 pt_BR: Argentina + - value: AM + label: + en_US: Armenia + zh_Hans: 亚美尼亚 + pt_BR: Armênia + - value: AW + label: + en_US: Aruba + zh_Hans: 阿鲁巴 + pt_BR: Aruba - value: AU label: en_US: Australia zh_Hans: 澳大利亚 - pt_BR: Australia + pt_BR: Austrália - value: AT label: en_US: Austria zh_Hans: 奥地利 - pt_BR: Austria + pt_BR: Áustria + - value: AZ + label: + en_US: Azerbaijan + zh_Hans: 阿塞拜疆 + pt_BR: Azerbaijão + - value: BS + label: + en_US: Bahamas + zh_Hans: 巴哈马 + pt_BR: 
Bahamas + - value: BH + label: + en_US: Bahrain + zh_Hans: 巴林 + pt_BR: Bahrein + - value: BD + label: + en_US: Bangladesh + zh_Hans: 孟加拉国 + pt_BR: Bangladesh + - value: BB + label: + en_US: Barbados + zh_Hans: 巴巴多斯 + pt_BR: Barbados + - value: BY + label: + en_US: Belarus + zh_Hans: 白俄罗斯 + pt_BR: Bielorrússia - value: BE label: en_US: Belgium zh_Hans: 比利时 - pt_BR: Belgium + pt_BR: Bélgica + - value: BZ + label: + en_US: Belize + zh_Hans: 伯利兹 + pt_BR: Belize + - value: BJ + label: + en_US: Benin + zh_Hans: 贝宁 + pt_BR: Benim + - value: BM + label: + en_US: Bermuda + zh_Hans: 百慕大 + pt_BR: Bermudas + - value: BT + label: + en_US: Bhutan + zh_Hans: 不丹 + pt_BR: Butão + - value: BO + label: + en_US: Bolivia + zh_Hans: 玻利维亚 + pt_BR: Bolívia + - value: BA + label: + en_US: Bosnia and Herzegovina + zh_Hans: 波斯尼亚和黑塞哥维那 + pt_BR: Bósnia e Herzegovina + - value: BW + label: + en_US: Botswana + zh_Hans: 博茨瓦纳 + pt_BR: Botsuana + - value: BV + label: + en_US: Bouvet Island + zh_Hans: 布韦岛 + pt_BR: Ilha Bouvet - value: BR label: en_US: Brazil zh_Hans: 巴西 - pt_BR: Brazil + pt_BR: Brasil + - value: IO + label: + en_US: British Indian Ocean Territory + zh_Hans: 英属印度洋领地 + pt_BR: Território Britânico do Oceano Índico + - value: BN + label: + en_US: Brunei Darussalam + zh_Hans: 文莱 + pt_BR: Brunei Darussalam + - value: BG + label: + en_US: Bulgaria + zh_Hans: 保加利亚 + pt_BR: Bulgária + - value: BF + label: + en_US: Burkina Faso + zh_Hans: 布基纳法索 + pt_BR: Burkina Faso + - value: BI + label: + en_US: Burundi + zh_Hans: 布隆迪 + pt_BR: Burundi + - value: KH + label: + en_US: Cambodia + zh_Hans: 柬埔寨 + pt_BR: Camboja + - value: CM + label: + en_US: Cameroon + zh_Hans: 喀麦隆 + pt_BR: Camarões - value: CA label: en_US: Canada zh_Hans: 加拿大 - pt_BR: Canada + pt_BR: Canadá + - value: CV + label: + en_US: Cape Verde + zh_Hans: 佛得角 + pt_BR: Cabo Verde + - value: KY + label: + en_US: Cayman Islands + zh_Hans: 开曼群岛 + pt_BR: Ilhas Cayman + - value: CF + label: + en_US: Central African Republic + zh_Hans: 中非共和国 + pt_BR: República Centro-Africana + - value: TD + label: + en_US: Chad + zh_Hans: 乍得 + pt_BR: Chade - value: CL label: en_US: Chile zh_Hans: 智利 pt_BR: Chile - - value: CO - label: - en_US: Colombia - zh_Hans: 哥伦比亚 - pt_BR: Colombia - value: CN label: en_US: China zh_Hans: 中国 pt_BR: China + - value: CX + label: + en_US: Christmas Island + zh_Hans: 圣诞岛 + pt_BR: Ilha do Natal + - value: CC + label: + en_US: Cocos (Keeling) Islands + zh_Hans: 科科斯(基林)群岛 + pt_BR: Ilhas Cocos (Keeling) + - value: CO + label: + en_US: Colombia + zh_Hans: 哥伦比亚 + pt_BR: Colômbia + - value: KM + label: + en_US: Comoros + zh_Hans: 科摩罗 + pt_BR: Comores + - value: CG + label: + en_US: Congo + zh_Hans: 刚果 + pt_BR: Congo + - value: CD + label: + en_US: Congo, the Democratic Republic of the + zh_Hans: 刚果民主共和国 + pt_BR: Congo, República Democrática do + - value: CK + label: + en_US: Cook Islands + zh_Hans: 库克群岛 + pt_BR: Ilhas Cook + - value: CR + label: + en_US: Costa Rica + zh_Hans: 哥斯达黎加 + pt_BR: Costa Rica + - value: CI + label: + en_US: Cote D'ivoire + zh_Hans: 科特迪瓦 + pt_BR: Costa do Marfim + - value: HR + label: + en_US: Croatia + zh_Hans: 克罗地亚 + pt_BR: Croácia + - value: CU + label: + en_US: Cuba + zh_Hans: 古巴 + pt_BR: Cuba + - value: CY + label: + en_US: Cyprus + zh_Hans: 塞浦路斯 + pt_BR: Chipre - value: CZ label: en_US: Czech Republic zh_Hans: 捷克共和国 - pt_BR: Czech Republic + pt_BR: República Tcheca - value: DK label: en_US: Denmark zh_Hans: 丹麦 - pt_BR: Denmark + pt_BR: Dinamarca + - value: DJ + label: + en_US: Djibouti + zh_Hans: 吉布提 + pt_BR: Djibuti + - value: 
DM + label: + en_US: Dominica + zh_Hans: 多米尼克 + pt_BR: Dominica + - value: DO + label: + en_US: Dominican Republic + zh_Hans: 多米尼加共和国 + pt_BR: República Dominicana + - value: EC + label: + en_US: Ecuador + zh_Hans: 厄瓜多尔 + pt_BR: Equador + - value: EG + label: + en_US: Egypt + zh_Hans: 埃及 + pt_BR: Egito + - value: SV + label: + en_US: El Salvador + zh_Hans: 萨尔瓦多 + pt_BR: El Salvador + - value: GQ + label: + en_US: Equatorial Guinea + zh_Hans: 赤道几内亚 + pt_BR: Guiné Equatorial + - value: ER + label: + en_US: Eritrea + zh_Hans: 厄立特里亚 + pt_BR: Eritreia + - value: EE + label: + en_US: Estonia + zh_Hans: 爱沙尼亚 + pt_BR: Estônia + - value: ET + label: + en_US: Ethiopia + zh_Hans: 埃塞俄比亚 + pt_BR: Etiópia + - value: FK + label: + en_US: Falkland Islands (Malvinas) + zh_Hans: 福克兰群岛(马尔维纳斯) + pt_BR: Ilhas Falkland (Malvinas) + - value: FO + label: + en_US: Faroe Islands + zh_Hans: 法罗群岛 + pt_BR: Ilhas Faroe + - value: FJ + label: + en_US: Fiji + zh_Hans: 斐济 + pt_BR: Fiji - value: FI label: en_US: Finland zh_Hans: 芬兰 - pt_BR: Finland + pt_BR: Finlândia - value: FR label: en_US: France zh_Hans: 法国 - pt_BR: France + pt_BR: França + - value: GF + label: + en_US: French Guiana + zh_Hans: 法属圭亚那 + pt_BR: Guiana Francesa + - value: PF + label: + en_US: French Polynesia + zh_Hans: 法属波利尼西亚 + pt_BR: Polinésia Francesa + - value: TF + label: + en_US: French Southern Territories + zh_Hans: 法属南部领地 + pt_BR: Territórios Franceses do Sul + - value: GA + label: + en_US: Gabon + zh_Hans: 加蓬 + pt_BR: Gabão + - value: GM + label: + en_US: Gambia + zh_Hans: 冈比亚 + pt_BR: Gâmbia + - value: GE + label: + en_US: Georgia + zh_Hans: 格鲁吉亚 + pt_BR: Geórgia - value: DE label: en_US: Germany zh_Hans: 德国 - pt_BR: Germany + pt_BR: Alemanha + - value: GH + label: + en_US: Ghana + zh_Hans: 加纳 + pt_BR: Gana + - value: GI + label: + en_US: Gibraltar + zh_Hans: 直布罗陀 + pt_BR: Gibraltar + - value: GR + label: + en_US: Greece + zh_Hans: 希腊 + pt_BR: Grécia + - value: GL + label: + en_US: Greenland + zh_Hans: 格陵兰 + pt_BR: Groenlândia + - value: GD + label: + en_US: Grenada + zh_Hans: 格林纳达 + pt_BR: Granada + - value: GP + label: + en_US: Guadeloupe + zh_Hans: 瓜德罗普 + pt_BR: Guadalupe + - value: GU + label: + en_US: Guam + zh_Hans: 关岛 + pt_BR: Guam + - value: GT + label: + en_US: Guatemala + zh_Hans: 危地马拉 + pt_BR: Guatemala + - value: GN + label: + en_US: Guinea + zh_Hans: 几内亚 + pt_BR: Guiné + - value: GW + label: + en_US: Guinea-Bissau + zh_Hans: 几内亚比绍 + pt_BR: Guiné-Bissau + - value: GY + label: + en_US: Guyana + zh_Hans: 圭亚那 + pt_BR: Guiana + - value: HT + label: + en_US: Haiti + zh_Hans: 海地 + pt_BR: Haiti + - value: HM + label: + en_US: Heard Island and McDonald Islands + zh_Hans: 赫德岛和麦克唐纳群岛 + pt_BR: Ilha Heard e Ilhas McDonald + - value: VA + label: + en_US: Holy See (Vatican City State) + zh_Hans: 教廷(梵蒂冈城国) + pt_BR: Santa Sé (Estado da Cidade do Vaticano) + - value: HN + label: + en_US: Honduras + zh_Hans: 洪都拉斯 + pt_BR: Honduras - value: HK label: en_US: Hong Kong zh_Hans: 香港 pt_BR: Hong Kong + - value: HU + label: + en_US: Hungary + zh_Hans: 匈牙利 + pt_BR: Hungria + - value: IS + label: + en_US: Iceland + zh_Hans: 冰岛 + pt_BR: Islândia - value: IN label: en_US: India zh_Hans: 印度 - pt_BR: India + pt_BR: Índia - value: ID label: en_US: Indonesia zh_Hans: 印度尼西亚 - pt_BR: Indonesia + pt_BR: Indonésia + - value: IR + label: + en_US: Iran, Islamic Republic of + zh_Hans: 伊朗 + pt_BR: Irã + - value: IQ + label: + en_US: Iraq + zh_Hans: 伊拉克 + pt_BR: Iraque + - value: IE + label: + en_US: Ireland + zh_Hans: 爱尔兰 + pt_BR: Irlanda + - value: IL + label: + en_US: 
Israel + zh_Hans: 以色列 + pt_BR: Israel - value: IT label: en_US: Italy zh_Hans: 意大利 - pt_BR: Italy + pt_BR: Itália + - value: JM + label: + en_US: Jamaica + zh_Hans: 牙买加 + pt_BR: Jamaica - value: JP label: en_US: Japan zh_Hans: 日本 - pt_BR: Japan + pt_BR: Japão + - value: JO + label: + en_US: Jordan + zh_Hans: 约旦 + pt_BR: Jordânia + - value: KZ + label: + en_US: Kazakhstan + zh_Hans: 哈萨克斯坦 + pt_BR: Cazaquistão + - value: KE + label: + en_US: Kenya + zh_Hans: 肯尼亚 + pt_BR: Quênia + - value: KI + label: + en_US: Kiribati + zh_Hans: 基里巴斯 + pt_BR: Kiribati + - value: KP + label: + en_US: Korea, Democratic People's Republic of + zh_Hans: 朝鲜 + pt_BR: Coreia, República Democrática Popular da - value: KR label: - en_US: Korea + en_US: Korea, Republic of zh_Hans: 韩国 - pt_BR: Korea + pt_BR: Coreia, República da + - value: KW + label: + en_US: Kuwait + zh_Hans: 科威特 + pt_BR: Kuwait + - value: KG + label: + en_US: Kyrgyzstan + zh_Hans: 吉尔吉斯斯坦 + pt_BR: Quirguistão + - value: LA + label: + en_US: Lao People's Democratic Republic + zh_Hans: 老挝 + pt_BR: República Democrática Popular do Laos + - value: LV + label: + en_US: Latvia + zh_Hans: 拉脱维亚 + pt_BR: Letônia + - value: LB + label: + en_US: Lebanon + zh_Hans: 黎巴嫩 + pt_BR: Líbano + - value: LS + label: + en_US: Lesotho + zh_Hans: 莱索托 + pt_BR: Lesoto + - value: LR + label: + en_US: Liberia + zh_Hans: 利比里亚 + pt_BR: Libéria + - value: LY + label: + en_US: Libyan Arab Jamahiriya + zh_Hans: 利比亚 + pt_BR: Líbia + - value: LI + label: + en_US: Liechtenstein + zh_Hans: 列支敦士登 + pt_BR: Liechtenstein + - value: LT + label: + en_US: Lithuania + zh_Hans: 立陶宛 + pt_BR: Lituânia + - value: LU + label: + en_US: Luxembourg + zh_Hans: 卢森堡 + pt_BR: Luxemburgo + - value: MO + label: + en_US: Macao + zh_Hans: 澳门 + pt_BR: Macau + - value: MK + label: + en_US: Macedonia, the Former Yugoslav Republic of + zh_Hans: 前南斯拉夫马其顿共和国 + pt_BR: Macedônia, Ex-República Iugoslava da + - value: MG + label: + en_US: Madagascar + zh_Hans: 马达加斯加 + pt_BR: Madagascar + - value: MW + label: + en_US: Malawi + zh_Hans: 马拉维 + pt_BR: Malaui - value: MY label: en_US: Malaysia zh_Hans: 马来西亚 - pt_BR: Malaysia + pt_BR: Malásia + - value: MV + label: + en_US: Maldives + zh_Hans: 马尔代夫 + pt_BR: Maldivas + - value: ML + label: + en_US: Mali + zh_Hans: 马里 + pt_BR: Mali + - value: MT + label: + en_US: Malta + zh_Hans: 马耳他 + pt_BR: Malta + - value: MH + label: + en_US: Marshall Islands + zh_Hans: 马绍尔群岛 + pt_BR: Ilhas Marshall + - value: MQ + label: + en_US: Martinique + zh_Hans: 马提尼克 + pt_BR: Martinica + - value: MR + label: + en_US: Mauritania + zh_Hans: 毛里塔尼亚 + pt_BR: Mauritânia + - value: MU + label: + en_US: Mauritius + zh_Hans: 毛里求斯 + pt_BR: Maurício + - value: YT + label: + en_US: Mayotte + zh_Hans: 马约特 + pt_BR: Mayotte - value: MX label: en_US: Mexico zh_Hans: 墨西哥 - pt_BR: Mexico + pt_BR: México + - value: FM + label: + en_US: Micronesia, Federated States of + zh_Hans: 密克罗尼西亚联邦 + pt_BR: Micronésia, Estados Federados da + - value: MD + label: + en_US: Moldova, Republic of + zh_Hans: 摩尔多瓦共和国 + pt_BR: Moldávia, República da + - value: MC + label: + en_US: Monaco + zh_Hans: 摩纳哥 + pt_BR: Mônaco + - value: MN + label: + en_US: Mongolia + zh_Hans: 蒙古 + pt_BR: Mongólia + - value: MS + label: + en_US: Montserrat + zh_Hans: 蒙特塞拉特 + pt_BR: Montserrat + - value: MA + label: + en_US: Morocco + zh_Hans: 摩洛哥 + pt_BR: Marrocos + - value: MZ + label: + en_US: Mozambique + zh_Hans: 莫桑比克 + pt_BR: Moçambique + - value: MM + label: + en_US: Myanmar + zh_Hans: 缅甸 + pt_BR: Mianmar + - value: NA + label: + en_US: Namibia + zh_Hans:
纳米比亚 + pt_BR: Namíbia + - value: NR + label: + en_US: Nauru + zh_Hans: 瑙鲁 + pt_BR: Nauru + - value: NP + label: + en_US: Nepal + zh_Hans: 尼泊尔 + pt_BR: Nepal - value: NL label: en_US: Netherlands zh_Hans: 荷兰 - pt_BR: Netherlands + pt_BR: Países Baixos + - value: AN + label: + en_US: Netherlands Antilles + zh_Hans: 荷属安的列斯 + pt_BR: Antilhas Holandesas + - value: NC + label: + en_US: New Caledonia + zh_Hans: 新喀里多尼亚 + pt_BR: Nova Caledônia - value: NZ label: en_US: New Zealand zh_Hans: 新西兰 - pt_BR: New Zealand - - value: 'NO' + pt_BR: Nova Zelândia + - value: NI + label: + en_US: Nicaragua + zh_Hans: 尼加拉瓜 + pt_BR: Nicarágua + - value: NE + label: + en_US: Niger + zh_Hans: 尼日尔 + pt_BR: Níger + - value: NG + label: + en_US: Nigeria + zh_Hans: 尼日利亚 + pt_BR: Nigéria + - value: NU + label: + en_US: Niue + zh_Hans: 纽埃 + pt_BR: Niue + - value: NF + label: + en_US: Norfolk Island + zh_Hans: 诺福克岛 + pt_BR: Ilha Norfolk + - value: MP + label: + en_US: Northern Mariana Islands + zh_Hans: 北马里亚纳群岛 + pt_BR: Ilhas Marianas do Norte + - value: "NO" label: en_US: Norway zh_Hans: 挪威 - pt_BR: Norway + pt_BR: Noruega + - value: OM + label: + en_US: Oman + zh_Hans: 阿曼 + pt_BR: Omã + - value: PK + label: + en_US: Pakistan + zh_Hans: 巴基斯坦 + pt_BR: Paquistão + - value: PW + label: + en_US: Palau + zh_Hans: 帕劳 + pt_BR: Palau + - value: PS + label: + en_US: Palestinian Territory, Occupied + zh_Hans: 巴勒斯坦领土 + pt_BR: Palestina, Território Ocupado + - value: PA + label: + en_US: Panama + zh_Hans: 巴拿马 + pt_BR: Panamá + - value: PG + label: + en_US: Papua New Guinea + zh_Hans: 巴布亚新几内亚 + pt_BR: Papua Nova Guiné + - value: PY + label: + en_US: Paraguay + zh_Hans: 巴拉圭 + pt_BR: Paraguai + - value: PE + label: + en_US: Peru + zh_Hans: 秘鲁 + pt_BR: Peru - value: PH label: en_US: Philippines zh_Hans: 菲律宾 - pt_BR: Philippines + pt_BR: Filipinas + - value: PN + label: + en_US: Pitcairn + zh_Hans: 皮特凯恩岛 + pt_BR: Pitcairn - value: PL label: en_US: Poland zh_Hans: 波兰 - pt_BR: Poland + pt_BR: Polônia - value: PT label: en_US: Portugal zh_Hans: 葡萄牙 pt_BR: Portugal + - value: PR + label: + en_US: Puerto Rico + zh_Hans: 波多黎各 + pt_BR: Porto Rico + - value: QA + label: + en_US: Qatar + zh_Hans: 卡塔尔 + pt_BR: Catar + - value: RE + label: + en_US: Reunion + zh_Hans: 留尼旺 + pt_BR: Reunião + - value: RO + label: + en_US: Romania + zh_Hans: 罗马尼亚 + pt_BR: Romênia - value: RU label: - en_US: Russia - zh_Hans: 俄罗斯 - pt_BR: Russia + en_US: Russian Federation + zh_Hans: 俄罗斯联邦 + pt_BR: Rússia + - value: RW + label: + en_US: Rwanda + zh_Hans: 卢旺达 + pt_BR: Ruanda + - value: SH + label: + en_US: Saint Helena + zh_Hans: 圣赫勒拿 + pt_BR: Santa Helena + - value: KN + label: + en_US: Saint Kitts and Nevis + zh_Hans: 圣基茨和尼维斯 + pt_BR: São Cristóvão e Nevis + - value: LC + label: + en_US: Saint Lucia + zh_Hans: 圣卢西亚 + pt_BR: Santa Lúcia + - value: PM + label: + en_US: Saint Pierre and Miquelon + zh_Hans: 圣皮埃尔和密克隆 + pt_BR: São Pedro e Miquelon + - value: VC + label: + en_US: Saint Vincent and the Grenadines + zh_Hans: 圣文森特和格林纳丁斯 + pt_BR: São Vicente e Granadinas + - value: WS + label: + en_US: Samoa + zh_Hans: 萨摩亚 + pt_BR: Samoa + - value: SM + label: + en_US: San Marino + zh_Hans: 圣马力诺 + pt_BR: San Marino + - value: ST + label: + en_US: Sao Tome and Principe + zh_Hans: 圣多美和普林西比 + pt_BR: São Tomé e Príncipe - value: SA label: en_US: Saudi Arabia zh_Hans: 沙特阿拉伯 - pt_BR: Saudi Arabia + pt_BR: Arábia Saudita + - value: SN + label: + en_US: Senegal + zh_Hans: 塞内加尔 + pt_BR: Senegal + - value: RS + label: + en_US: Serbia and Montenegro + zh_Hans: 塞尔维亚和黑山 + pt_BR: Sérvia e 
Montenegro + - value: SC + label: + en_US: Seychelles + zh_Hans: 塞舌尔 + pt_BR: Seicheles + - value: SL + label: + en_US: Sierra Leone + zh_Hans: 塞拉利昂 + pt_BR: Serra Leoa - value: SG label: en_US: Singapore zh_Hans: 新加坡 - pt_BR: Singapore + pt_BR: Singapura + - value: SK + label: + en_US: Slovakia + zh_Hans: 斯洛伐克 + pt_BR: Eslováquia + - value: SI + label: + en_US: Slovenia + zh_Hans: 斯洛文尼亚 + pt_BR: Eslovênia + - value: SB + label: + en_US: Solomon Islands + zh_Hans: 所罗门群岛 + pt_BR: Ilhas Salomão + - value: SO + label: + en_US: Somalia + zh_Hans: 索马里 + pt_BR: Somália - value: ZA label: en_US: South Africa zh_Hans: 南非 - pt_BR: South Africa + pt_BR: África do Sul + - value: GS + label: + en_US: South Georgia and the South Sandwich Islands + zh_Hans: 南乔治亚和南桑威奇群岛 + pt_BR: Geórgia do Sul e Ilhas Sandwich do Sul - value: ES label: en_US: Spain zh_Hans: 西班牙 - pt_BR: Spain + pt_BR: Espanha + - value: LK + label: + en_US: Sri Lanka + zh_Hans: 斯里兰卡 + pt_BR: Sri Lanka + - value: SD + label: + en_US: Sudan + zh_Hans: 苏丹 + pt_BR: Sudão + - value: SR + label: + en_US: Suriname + zh_Hans: 苏里南 + pt_BR: Suriname + - value: SJ + label: + en_US: Svalbard and Jan Mayen + zh_Hans: 斯瓦尔巴特和扬马延岛 + pt_BR: Svalbard e Jan Mayen + - value: SZ + label: + en_US: Swaziland + zh_Hans: 斯威士兰 + pt_BR: Essuatíni - value: SE label: en_US: Sweden zh_Hans: 瑞典 - pt_BR: Sweden + pt_BR: Suécia - value: CH label: en_US: Switzerland zh_Hans: 瑞士 - pt_BR: Switzerland + pt_BR: Suíça + - value: SY + label: + en_US: Syrian Arab Republic + zh_Hans: 叙利亚 + pt_BR: Síria - value: TW label: - en_US: Taiwan + en_US: Taiwan, Province of China zh_Hans: 台湾 pt_BR: Taiwan + - value: TJ + label: + en_US: Tajikistan + zh_Hans: 塔吉克斯坦 + pt_BR: Tajiquistão + - value: TZ + label: + en_US: Tanzania, United Republic of + zh_Hans: 坦桑尼亚联合共和国 + pt_BR: Tanzânia - value: TH label: en_US: Thailand zh_Hans: 泰国 - pt_BR: Thailand + pt_BR: Tailândia + - value: TL + label: + en_US: Timor-Leste + zh_Hans: 东帝汶 + pt_BR: Timor-Leste + - value: TG + label: + en_US: Togo + zh_Hans: 多哥 + pt_BR: Togo + - value: TK + label: + en_US: Tokelau + zh_Hans: 托克劳 + pt_BR: Toquelau + - value: TO + label: + en_US: Tonga + zh_Hans: 汤加 + pt_BR: Tonga + - value: TT + label: + en_US: Trinidad and Tobago + zh_Hans: 特立尼达和多巴哥 + pt_BR: Trindade e Tobago + - value: TN + label: + en_US: Tunisia + zh_Hans: 突尼斯 + pt_BR: Tunísia - value: TR label: en_US: Turkey zh_Hans: 土耳其 - pt_BR: Turkey + pt_BR: Turquia + - value: TM + label: + en_US: Turkmenistan + zh_Hans: 土库曼斯坦 + pt_BR: Turcomenistão + - value: TC + label: + en_US: Turks and Caicos Islands + zh_Hans: 特克斯和凯科斯群岛 + pt_BR: Ilhas Turks e Caicos + - value: TV + label: + en_US: Tuvalu + zh_Hans: 图瓦卢 + pt_BR: Tuvalu + - value: UG + label: + en_US: Uganda + zh_Hans: 乌干达 + pt_BR: Uganda + - value: UA + label: + en_US: Ukraine + zh_Hans: 乌克兰 + pt_BR: Ucrânia + - value: AE + label: + en_US: United Arab Emirates + zh_Hans: 阿联酋 + pt_BR: Emirados Árabes Unidos + - value: UK + label: + en_US: United Kingdom + zh_Hans: 英国 + pt_BR: Reino Unido - value: GB label: en_US: United Kingdom zh_Hans: 英国 - pt_BR: United Kingdom + pt_BR: Reino Unido - value: US label: en_US: United States zh_Hans: 美国 - pt_BR: United States + pt_BR: Estados Unidos + - value: UM + label: + en_US: United States Minor Outlying Islands + zh_Hans: 美国本土外小岛屿 + pt_BR: Ilhas Menores Distantes dos Estados Unidos + - value: UY + label: + en_US: Uruguay + zh_Hans: 乌拉圭 + pt_BR: Uruguai + - value: UZ + label: + en_US: Uzbekistan + zh_Hans: 乌兹别克斯坦 + pt_BR: Uzbequistão + - value: VU + label: + en_US: Vanuatu 
+ zh_Hans: 瓦努阿图 + pt_BR: Vanuatu + - value: VE + label: + en_US: Venezuela + zh_Hans: 委内瑞拉 + pt_BR: Venezuela + - value: VN + label: + en_US: Viet Nam + zh_Hans: 越南 + pt_BR: Vietnã + - value: VG + label: + en_US: Virgin Islands, British + zh_Hans: 英属维尔京群岛 + pt_BR: Ilhas Virgens Britânicas + - value: VI + label: + en_US: Virgin Islands, U.S. + zh_Hans: 美属维尔京群岛 + pt_BR: Ilhas Virgens dos EUA + - value: WF + label: + en_US: Wallis and Futuna + zh_Hans: 瓦利斯和富图纳群岛 + pt_BR: Wallis e Futuna + - value: EH + label: + en_US: Western Sahara + zh_Hans: 西撒哈拉 + pt_BR: Saara Ocidental + - value: YE + label: + en_US: Yemen + zh_Hans: 也门 + pt_BR: Iémen + - value: ZM + label: + en_US: Zambia + zh_Hans: 赞比亚 + pt_BR: Zâmbia + - value: ZW + label: + en_US: Zimbabwe + zh_Hans: 津巴布韦 + pt_BR: Zimbábue - name: hl type: select label: @@ -277,18 +1277,94 @@ parameters: default: en form: form options: + - value: af + label: + en_US: Afrikaans + zh_Hans: 南非语 + - value: ak + label: + en_US: Akan + zh_Hans: 阿坎语 + - value: sq + label: + en_US: Albanian + zh_Hans: 阿尔巴尼亚语 + - value: ws + label: + en_US: Samoa + zh_Hans: 萨摩亚语 + - value: am + label: + en_US: Amharic + zh_Hans: 阿姆哈拉语 - value: ar label: en_US: Arabic zh_Hans: 阿拉伯语 + - value: hy + label: + en_US: Armenian + zh_Hans: 亚美尼亚语 + - value: az + label: + en_US: Azerbaijani + zh_Hans: 阿塞拜疆语 + - value: eu + label: + en_US: Basque + zh_Hans: 巴斯克语 + - value: be + label: + en_US: Belarusian + zh_Hans: 白俄罗斯语 + - value: bem + label: + en_US: Bemba + zh_Hans: 班巴语 + - value: bn + label: + en_US: Bengali + zh_Hans: 孟加拉语 + - value: bh + label: + en_US: Bihari + zh_Hans: 比哈尔语 + - value: xx-bork + label: + en_US: Bork, bork, bork! + zh_Hans: 博克语 + - value: bs + label: + en_US: Bosnian + zh_Hans: 波斯尼亚语 + - value: br + label: + en_US: Breton + zh_Hans: 布列塔尼语 - value: bg label: en_US: Bulgarian zh_Hans: 保加利亚语 + - value: bt + label: + en_US: Bhutanese + zh_Hans: 不丹语 + - value: km + label: + en_US: Cambodian + zh_Hans: 高棉语 - value: ca label: en_US: Catalan zh_Hans: 加泰罗尼亚语 + - value: chr + label: + en_US: Cherokee + zh_Hans: 切罗基语 + - value: ny + label: + en_US: Chichewa + zh_Hans: 齐切瓦语 - value: zh-cn label: en_US: Chinese (Simplified) @@ -297,6 +1373,14 @@ parameters: label: en_US: Chinese (Traditional) zh_Hans: 中文(繁体) + - value: co + label: + en_US: Corsican + zh_Hans: 科西嘉语 + - value: hr + label: + en_US: Croatian + zh_Hans: 克罗地亚语 - value: cs label: en_US: Czech @@ -309,14 +1393,34 @@ parameters: label: en_US: Dutch zh_Hans: 荷兰语 + - value: xx-elmer + label: + en_US: Elmer Fudd + zh_Hans: 艾尔默福德语 - value: en label: en_US: English zh_Hans: 英语 + - value: eo + label: + en_US: Esperanto + zh_Hans: 世界语 - value: et label: en_US: Estonian zh_Hans: 爱沙尼亚语 + - value: ee + label: + en_US: Ewe + zh_Hans: 埃维语 + - value: fo + label: + en_US: Faroese + zh_Hans: 法罗语 + - value: tl + label: + en_US: Filipino + zh_Hans: 菲律宾语 - value: fi label: en_US: Finnish @@ -325,6 +1429,22 @@ parameters: label: en_US: French zh_Hans: 法语 + - value: fy + label: + en_US: Frisian + zh_Hans: 弗里西亚语 + - value: gaa + label: + en_US: Ga + zh_Hans: 加语 + - value: gl + label: + en_US: Galician + zh_Hans: 加利西亚语 + - value: ka + label: + en_US: Georgian + zh_Hans: 格鲁吉亚语 - value: de label: en_US: German @@ -333,6 +1453,34 @@ parameters: label: en_US: Greek zh_Hans: 希腊语 + - value: kl + label: + en_US: Greenlandic + zh_Hans: 格陵兰语 + - value: gn + label: + en_US: Guarani + zh_Hans: 瓜拉尼语 + - value: gu + label: + en_US: Gujarati + zh_Hans: 古吉拉特语 + - value: xx-hacker + label: + en_US: Hacker + zh_Hans: 黑客语 + - value: ht + label: + en_US: 
Haitian Creole + zh_Hans: 海地克里奥尔语 + - value: ha + label: + en_US: Hausa + zh_Hans: 豪萨语 + - value: haw + label: + en_US: Hawaiian + zh_Hans: 夏威夷语 - value: iw label: en_US: Hebrew @@ -345,10 +1493,26 @@ parameters: label: en_US: Hungarian zh_Hans: 匈牙利语 + - value: is + label: + en_US: Icelandic + zh_Hans: 冰岛语 + - value: ig + label: + en_US: Igbo + zh_Hans: 伊博语 - value: id label: en_US: Indonesian zh_Hans: 印尼语 + - value: ia + label: + en_US: Interlingua + zh_Hans: 国际语 + - value: ga + label: + en_US: Irish + zh_Hans: 爱尔兰语 - value: it label: en_US: Italian @@ -357,22 +1521,94 @@ parameters: label: en_US: Japanese zh_Hans: 日语 + - value: jw + label: + en_US: Javanese + zh_Hans: 爪哇语 - value: kn label: en_US: Kannada zh_Hans: 卡纳达语 + - value: kk + label: + en_US: Kazakh + zh_Hans: 哈萨克语 + - value: rw + label: + en_US: Kinyarwanda + zh_Hans: 基尼亚卢旺达语 + - value: rn + label: + en_US: Kirundi + zh_Hans: 基隆迪语 + - value: xx-klingon + label: + en_US: Klingon + zh_Hans: 克林贡语 + - value: kg + label: + en_US: Kongo + zh_Hans: 刚果语 - value: ko label: en_US: Korean zh_Hans: 韩语 + - value: kri + label: + en_US: Krio (Sierra Leone) + zh_Hans: 塞拉利昂克里奥尔语 + - value: ku + label: + en_US: Kurdish + zh_Hans: 库尔德语 + - value: ckb + label: + en_US: Kurdish (Soranî) + zh_Hans: 库尔德语(索拉尼) + - value: ky + label: + en_US: Kyrgyz + zh_Hans: 吉尔吉斯语 + - value: lo + label: + en_US: Laothian + zh_Hans: 老挝语 + - value: la + label: + en_US: Latin + zh_Hans: 拉丁语 - value: lv label: en_US: Latvian zh_Hans: 拉脱维亚语 + - value: ln + label: + en_US: Lingala + zh_Hans: 林加拉语 - value: lt label: en_US: Lithuanian zh_Hans: 立陶宛语 + - value: loz + label: + en_US: Lozi + zh_Hans: 洛齐语 + - value: lg + label: + en_US: Luganda + zh_Hans: 卢干达语 + - value: ach + label: + en_US: Luo + zh_Hans: 卢奥语 + - value: mk + label: + en_US: Macedonian + zh_Hans: 马其顿语 + - value: mg + label: + en_US: Malagasy + zh_Hans: 马尔加什语 - value: my label: en_US: Malay @@ -381,18 +1617,90 @@ parameters: label: en_US: Malayalam zh_Hans: 马拉雅拉姆语 + - value: mt + label: + en_US: Maltese + zh_Hans: 马耳他语 + - value: mv + label: + en_US: Maldives + zh_Hans: 马尔代夫语 + - value: mi + label: + en_US: Maori + zh_Hans: 毛利语 - value: mr label: en_US: Marathi zh_Hans: 马拉地语 + - value: mfe + label: + en_US: Mauritian Creole + zh_Hans: 毛里求斯克里奥尔语 + - value: mo + label: + en_US: Moldavian + zh_Hans: 摩尔达维亚语 + - value: mn + label: + en_US: Mongolian + zh_Hans: 蒙古语 + - value: sr-me + label: + en_US: Montenegrin + zh_Hans: 黑山语 + - value: ne + label: + en_US: Nepali + zh_Hans: 尼泊尔语 + - value: pcm + label: + en_US: Nigerian Pidgin + zh_Hans: 尼日利亚皮钦语 + - value: nso + label: + en_US: Northern Sotho + zh_Hans: 北索托语 - value: "no" label: en_US: Norwegian zh_Hans: 挪威语 + - value: nn + label: + en_US: Norwegian (Nynorsk) + zh_Hans: 挪威语(尼诺斯克语) + - value: oc + label: + en_US: Occitan + zh_Hans: 奥克语 + - value: or + label: + en_US: Oriya + zh_Hans: 奥里亚语 + - value: om + label: + en_US: Oromo + zh_Hans: 奥罗莫语 + - value: ps + label: + en_US: Pashto + zh_Hans: 普什图语 + - value: fa + label: + en_US: Persian + zh_Hans: 波斯语 + - value: xx-pirate + label: + en_US: Pirate + zh_Hans: 海盗语 - value: pl label: en_US: Polish zh_Hans: 波兰语 + - value: pt + label: + en_US: Portuguese + zh_Hans: 葡萄牙语 - value: pt-br label: en_US: Portuguese (Brazil) @@ -405,18 +1713,62 @@ parameters: label: en_US: Punjabi zh_Hans: 旁遮普语 + - value: qu + label: + en_US: Quechua + zh_Hans: 克丘亚语 - value: ro label: en_US: Romanian zh_Hans: 罗马尼亚语 + - value: rm + label: + en_US: Romansh + zh_Hans: 罗曼什语 + - value: nyn + label: + en_US: Runyakitara + zh_Hans: 卢尼亚基塔拉语 - value: ru label: 
en_US: Russian zh_Hans: 俄语 + - value: gd + label: + en_US: Scots Gaelic + zh_Hans: 苏格兰盖尔语 - value: sr label: en_US: Serbian zh_Hans: 塞尔维亚语 + - value: sh + label: + en_US: Serbo-Croatian + zh_Hans: 塞尔维亚-克罗地亚语 + - value: st + label: + en_US: Sesotho + zh_Hans: 塞索托语 + - value: tn + label: + en_US: Setswana + zh_Hans: 塞茨瓦纳语 + - value: crs + label: + en_US: Seychellois Creole + zh_Hans: 塞舌尔克里奥尔语 + - value: sn + label: + en_US: Shona + zh_Hans: 绍纳语 + - value: sd + label: + en_US: Sindhi + zh_Hans: 信德语 + - value: si + label: + en_US: Sinhalese + zh_Hans: 僧伽罗语 - value: sk label: en_US: Slovak @@ -425,18 +1777,42 @@ parameters: label: en_US: Slovenian zh_Hans: 斯洛文尼亚语 + - value: so + label: + en_US: Somali + zh_Hans: 索马里语 - value: es label: en_US: Spanish zh_Hans: 西班牙语 + - value: es-419 + label: + en_US: Spanish (Latin American) + zh_Hans: 西班牙语(拉丁美洲) + - value: su + label: + en_US: Sundanese + zh_Hans: 巽他语 + - value: sw + label: + en_US: Swahili + zh_Hans: 斯瓦希里语 - value: sv label: en_US: Swedish zh_Hans: 瑞典语 + - value: tg + label: + en_US: Tajik + zh_Hans: 塔吉克语 - value: ta label: en_US: Tamil zh_Hans: 泰米尔语 + - value: tt + label: + en_US: Tatar + zh_Hans: 鞑靼语 - value: te label: en_US: Telugu @@ -445,18 +1821,82 @@ parameters: label: en_US: Thai zh_Hans: 泰语 + - value: ti + label: + en_US: Tigrinya + zh_Hans: 提格利尼亚语 + - value: to + label: + en_US: Tonga + zh_Hans: 汤加语 + - value: lua + label: + en_US: Tshiluba + zh_Hans: 卢巴语 + - value: tum + label: + en_US: Tumbuka + zh_Hans: 图布卡语 - value: tr label: en_US: Turkish zh_Hans: 土耳其语 + - value: tk + label: + en_US: Turkmen + zh_Hans: 土库曼语 + - value: tw + label: + en_US: Twi + zh_Hans: 契维语 + - value: ug + label: + en_US: Uighur + zh_Hans: 维吾尔语 - value: uk label: en_US: Ukrainian zh_Hans: 乌克兰语 + - value: ur + label: + en_US: Urdu + zh_Hans: 乌尔都语 + - value: uz + label: + en_US: Uzbek + zh_Hans: 乌兹别克语 + - value: vu + label: + en_US: Vanuatu + zh_Hans: 瓦努阿图语 - value: vi label: en_US: Vietnamese zh_Hans: 越南语 + - value: cy + label: + en_US: Welsh + zh_Hans: 威尔士语 + - value: wo + label: + en_US: Wolof + zh_Hans: 沃洛夫语 + - value: xh + label: + en_US: Xhosa + zh_Hans: 科萨语 + - value: yi + label: + en_US: Yiddish + zh_Hans: 意第绪语 + - value: yo + label: + en_US: Yoruba + zh_Hans: 约鲁巴语 + - value: zu + label: + en_US: Zulu + zh_Hans: 祖鲁语 - name: google_domain type: string required: false diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index 81c67c51a9a7ae..c478bc108b47e1 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -7,6 +7,7 @@ SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. 
@@ -37,41 +38,52 @@ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: return { "engine": "google_jobs", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "jobs" in res.keys() and "title" in res["jobs"][0].keys(): + if "jobs" in res and "title" in res["jobs"][0]: for item in res["jobs"]: - toret += "title: " + item["title"] + "\n" + "company_name: " + item["company_name"] + "content: " + item["description"] + "\n" + toret += ( + "title: " + + item["title"] + + "\n" + + "company_name: " + + item["company_name"] + + "content: " + + item["description"] + + "\n" + ) if toret == "": toret = "No good search result found" elif type == "link": - if "jobs" in res.keys() and "apply_link" in res["jobs"][0].keys(): + if "jobs" in res and "apply_link" in res["jobs"][0]: for item in res["jobs"]: toret += f"[{item['title']} - {item['company_name']}]({item['apply_link']})\n" else: toret = "No good search result found" return toret + class GoogleJobsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] is_remote = tool_parameters.get("is_remote") google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") @@ -80,9 +92,11 @@ def _invoke(self, ltype = 1 if is_remote else None - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, google_domain=google_domain, gl=gl, hl=hl, location=location, ltype=ltype + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml index 9033bc0f8784cc..3e00e20fbd6e33 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.yaml @@ -65,36 +65,141 @@ parameters: form: form default: US options: - - value: AR - label: - en_US: Argentina - zh_Hans: 阿根廷 - pt_BR: Argentina - - value: AU - label: - en_US: Australia - zh_Hans: 澳大利亚 - pt_BR: Australia + - value: DZ + label: + en_US: Algeria + zh_Hans: 阿尔及利亚 + pt_BR: Algeria + - value: AS + label: + en_US: American Samoa + zh_Hans: 美属萨摩亚 + pt_BR: American Samoa + - value: AO + label: + en_US: Angola + zh_Hans: 安哥拉 + pt_BR: Angola + - value: AI + label: + en_US: Anguilla + zh_Hans: 安圭拉 + pt_BR: Anguilla + - value: AG + label: + en_US: Antigua and Barbuda + zh_Hans: 安提瓜和巴布达 + pt_BR: Antigua and Barbuda + - value: AW + 
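
The jobs variant above differs from the web-search tool in two ways: `is_remote` is mapped to `ltype = 1` (and `None` is later dropped by `get_params`), and the text branch concatenates title, company, and description per job. Note that the concatenation, both before and after this reformat, places no newline between `company_name` and `content:`; the sketch below inserts one, which is an editorial assumption rather than the tool's actual behavior:

```python
from typing import Optional


def remote_flag(is_remote: Optional[bool]) -> Optional[int]:
    return 1 if is_remote else None  # None is filtered out of the request params


def jobs_to_text(res: dict) -> str:
    if "jobs" not in res or not res["jobs"] or "title" not in res["jobs"][0]:
        return "No good search result found"
    return "".join(
        # Separator between company_name and content added here (assumption).
        f"title: {job['title']}\ncompany_name: {job['company_name']}\ncontent: {job['description']}\n"
        for job in res["jobs"]
    )
```
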
label: + en_US: Aruba + zh_Hans: 阿鲁巴 + pt_BR: Aruba - value: AT label: en_US: Austria zh_Hans: 奥地利 pt_BR: Austria + - value: BS + label: + en_US: Bahamas + zh_Hans: 巴哈马 + pt_BR: Bahamas + - value: BH + label: + en_US: Bahrain + zh_Hans: 巴林 + pt_BR: Bahrain + - value: BD + label: + en_US: Bangladesh + zh_Hans: 孟加拉国 + pt_BR: Bangladesh + - value: BY + label: + en_US: Belarus + zh_Hans: 白俄罗斯 + pt_BR: Belarus - value: BE label: en_US: Belgium zh_Hans: 比利时 pt_BR: Belgium + - value: BZ + label: + en_US: Belize + zh_Hans: 伯利兹 + pt_BR: Belize + - value: BJ + label: + en_US: Benin + zh_Hans: 贝宁 + pt_BR: Benin + - value: BM + label: + en_US: Bermuda + zh_Hans: 百慕大 + pt_BR: Bermuda + - value: BO + label: + en_US: Bolivia + zh_Hans: 玻利维亚 + pt_BR: Bolivia + - value: BW + label: + en_US: Botswana + zh_Hans: 博茨瓦纳 + pt_BR: Botswana - value: BR label: en_US: Brazil zh_Hans: 巴西 pt_BR: Brazil + - value: IO + label: + en_US: British Indian Ocean Territory + zh_Hans: 英属印度洋领地 + pt_BR: British Indian Ocean Territory + - value: BF + label: + en_US: Burkina Faso + zh_Hans: 布基纳法索 + pt_BR: Burkina Faso + - value: BI + label: + en_US: Burundi + zh_Hans: 布隆迪 + pt_BR: Burundi + - value: CM + label: + en_US: Cameroon + zh_Hans: 喀麦隆 + pt_BR: Cameroon - value: CA label: en_US: Canada zh_Hans: 加拿大 pt_BR: Canada + - value: CV + label: + en_US: Cape Verde + zh_Hans: 佛得角 + pt_BR: Cape Verde + - value: KY + label: + en_US: Cayman Islands + zh_Hans: 开曼群岛 + pt_BR: Cayman Islands + - value: CF + label: + en_US: Central African Republic + zh_Hans: 中非共和国 + pt_BR: Central African Republic + - value: TD + label: + en_US: Chad + zh_Hans: 乍得 + pt_BR: Chad - value: CL label: en_US: Chile @@ -105,36 +210,141 @@ parameters: en_US: Colombia zh_Hans: 哥伦比亚 pt_BR: Colombia - - value: CN - label: - en_US: China - zh_Hans: 中国 - pt_BR: China - - value: CZ - label: - en_US: Czech Republic - zh_Hans: 捷克共和国 - pt_BR: Czech Republic + - value: CD + label: + en_US: Congo, the Democratic Republic of the + zh_Hans: 刚果民主共和国 + pt_BR: Congo, the Democratic Republic of the + - value: CR + label: + en_US: Costa Rica + zh_Hans: 哥斯达黎加 + pt_BR: Costa Rica + - value: CI + label: + en_US: Cote D'ivoire + zh_Hans: 科特迪瓦 + pt_BR: Cote D'ivoire + - value: CU + label: + en_US: Cuba + zh_Hans: 古巴 + pt_BR: Cuba - value: DK label: en_US: Denmark zh_Hans: 丹麦 pt_BR: Denmark - - value: FI - label: - en_US: Finland - zh_Hans: 芬兰 - pt_BR: Finland + - value: DJ + label: + en_US: Djibouti + zh_Hans: 吉布提 + pt_BR: Djibouti + - value: DM + label: + en_US: Dominica + zh_Hans: 多米尼克 + pt_BR: Dominica + - value: DO + label: + en_US: Dominican Republic + zh_Hans: 多米尼加共和国 + pt_BR: Dominican Republic + - value: EC + label: + en_US: Ecuador + zh_Hans: 厄瓜多尔 + pt_BR: Ecuador + - value: EG + label: + en_US: Egypt + zh_Hans: 埃及 + pt_BR: Egypt + - value: SV + label: + en_US: El Salvador + zh_Hans: 萨尔瓦多 + pt_BR: El Salvador + - value: ET + label: + en_US: Ethiopia + zh_Hans: 埃塞俄比亚 + pt_BR: Ethiopia + - value: FK + label: + en_US: Falkland Islands (Malvinas) + zh_Hans: 福克兰群岛(马尔维纳斯) + pt_BR: Falkland Islands (Malvinas) - value: FR label: en_US: France zh_Hans: 法国 pt_BR: France + - value: GF + label: + en_US: French Guiana + zh_Hans: 法属圭亚那 + pt_BR: French Guiana + - value: PF + label: + en_US: French Polynesia + zh_Hans: 法属波利尼西亚 + pt_BR: French Polynesia + - value: TF + label: + en_US: French Southern Territories + zh_Hans: 法属南部领地 + pt_BR: French Southern Territories + - value: GA + label: + en_US: Gabon + zh_Hans: 加蓬 + pt_BR: Gabon + - value: GM + label: + en_US: Gambia + zh_Hans: 冈比亚 + pt_BR: 
Gambia - value: DE label: en_US: Germany zh_Hans: 德国 pt_BR: Germany + - value: GH + label: + en_US: Ghana + zh_Hans: 加纳 + pt_BR: Ghana + - value: GR + label: + en_US: Greece + zh_Hans: 希腊 + pt_BR: Greece + - value: GP + label: + en_US: Guadeloupe + zh_Hans: 瓜德罗普 + pt_BR: Guadeloupe + - value: GT + label: + en_US: Guatemala + zh_Hans: 危地马拉 + pt_BR: Guatemala + - value: GY + label: + en_US: Guyana + zh_Hans: 圭亚那 + pt_BR: Guyana + - value: HT + label: + en_US: Haiti + zh_Hans: 海地 + pt_BR: Haiti + - value: HN + label: + en_US: Honduras + zh_Hans: 洪都拉斯 + pt_BR: Honduras - value: HK label: en_US: Hong Kong @@ -150,91 +360,291 @@ parameters: en_US: Indonesia zh_Hans: 印度尼西亚 pt_BR: Indonesia + - value: IQ + label: + en_US: Iraq + zh_Hans: 伊拉克 + pt_BR: Iraq - value: IT label: en_US: Italy zh_Hans: 意大利 pt_BR: Italy + - value: JM + label: + en_US: Jamaica + zh_Hans: 牙买加 + pt_BR: Jamaica - value: JP label: en_US: Japan zh_Hans: 日本 pt_BR: Japan - - value: KR - label: - en_US: Korea - zh_Hans: 韩国 - pt_BR: Korea + - value: JO + label: + en_US: Jordan + zh_Hans: 约旦 + pt_BR: Jordan + - value: KZ + label: + en_US: Kazakhstan + zh_Hans: 哈萨克斯坦 + pt_BR: Kazakhstan + - value: KE + label: + en_US: Kenya + zh_Hans: 肯尼亚 + pt_BR: Kenya + - value: KW + label: + en_US: Kuwait + zh_Hans: 科威特 + pt_BR: Kuwait + - value: KG + label: + en_US: Kyrgyzstan + zh_Hans: 吉尔吉斯斯坦 + pt_BR: Kyrgyzstan + - value: LB + label: + en_US: Lebanon + zh_Hans: 黎巴嫩 + pt_BR: Lebanon + - value: LS + label: + en_US: Lesotho + zh_Hans: 莱索托 + pt_BR: Lesotho + - value: LY + label: + en_US: Libyan Arab Jamahiriya + zh_Hans: 利比亚 + pt_BR: Libyan Arab Jamahiriya + - value: MG + label: + en_US: Madagascar + zh_Hans: 马达加斯加 + pt_BR: Madagascar + - value: MW + label: + en_US: Malawi + zh_Hans: 马拉维 + pt_BR: Malawi - value: MY label: en_US: Malaysia zh_Hans: 马来西亚 pt_BR: Malaysia + - value: ML + label: + en_US: Mali + zh_Hans: 马里 + pt_BR: Mali + - value: MQ + label: + en_US: Martinique + zh_Hans: 马提尼克 + pt_BR: Martinique + - value: MU + label: + en_US: Mauritius + zh_Hans: 毛里求斯 + pt_BR: Mauritius + - value: YT + label: + en_US: Mayotte + zh_Hans: 马约特 + pt_BR: Mayotte - value: MX label: en_US: Mexico zh_Hans: 墨西哥 pt_BR: Mexico + - value: MS + label: + en_US: Montserrat + zh_Hans: 蒙特塞拉特 + pt_BR: Montserrat + - value: MA + label: + en_US: Morocco + zh_Hans: 摩洛哥 + pt_BR: Morocco + - value: MZ + label: + en_US: Mozambique + zh_Hans: 莫桑比克 + pt_BR: Mozambique + - value: NA + label: + en_US: Namibia + zh_Hans: 纳米比亚 + pt_BR: Namibia - value: NL label: en_US: Netherlands zh_Hans: 荷兰 pt_BR: Netherlands - - value: NZ - label: - en_US: New Zealand - zh_Hans: 新西兰 - pt_BR: New Zealand - - value: 'NO' - label: - en_US: Norway - zh_Hans: 挪威 - pt_BR: Norway + - value: NC + label: + en_US: New Caledonia + zh_Hans: 新喀里多尼亚 + pt_BR: New Caledonia + - value: NI + label: + en_US: Nicaragua + zh_Hans: 尼加拉瓜 + pt_BR: Nicaragua + - value: NE + label: + en_US: Niger + zh_Hans: 尼日尔 + pt_BR: Niger + - value: NG + label: + en_US: Nigeria + zh_Hans: 尼日利亚 + pt_BR: Nigeria + - value: OM + label: + en_US: Oman + zh_Hans: 阿曼 + pt_BR: Oman + - value: PK + label: + en_US: Pakistan + zh_Hans: 巴基斯坦 + pt_BR: Pakistan + - value: PS + label: + en_US: Palestinian Territory, Occupied + zh_Hans: 巴勒斯坦领土 + pt_BR: Palestinian Territory, Occupied + - value: PA + label: + en_US: Panama + zh_Hans: 巴拿马 + pt_BR: Panama + - value: PY + label: + en_US: Paraguay + zh_Hans: 巴拉圭 + pt_BR: Paraguay + - value: PE + label: + en_US: Peru + zh_Hans: 秘鲁 + pt_BR: Peru - value: PH label: en_US: Philippines zh_Hans: 菲律宾 pt_BR: 
Philippines - - value: PL - label: - en_US: Poland - zh_Hans: 波兰 - pt_BR: Poland - value: PT label: en_US: Portugal zh_Hans: 葡萄牙 pt_BR: Portugal + - value: PR + label: + en_US: Puerto Rico + zh_Hans: 波多黎各 + pt_BR: Puerto Rico + - value: QA + label: + en_US: Qatar + zh_Hans: 卡塔尔 + pt_BR: Qatar + - value: RE + label: + en_US: Reunion + zh_Hans: 留尼旺 + pt_BR: Reunion - value: RU label: - en_US: Russia - zh_Hans: 俄罗斯 - pt_BR: Russia + en_US: Russian Federation + zh_Hans: 俄罗斯联邦 + pt_BR: Russian Federation + - value: RW + label: + en_US: Rwanda + zh_Hans: 卢旺达 + pt_BR: Rwanda + - value: SH + label: + en_US: Saint Helena + zh_Hans: 圣赫勒拿 + pt_BR: Saint Helena + - value: PM + label: + en_US: Saint Pierre and Miquelon + zh_Hans: 圣皮埃尔和密克隆 + pt_BR: Saint Pierre and Miquelon + - value: VC + label: + en_US: Saint Vincent and the Grenadines + zh_Hans: 圣文森特和格林纳丁斯 + pt_BR: Saint Vincent and the Grenadines + - value: ST + label: + en_US: Sao Tome and Principe + zh_Hans: 圣多美和普林西比 + pt_BR: Sao Tome and Principe - value: SA label: en_US: Saudi Arabia zh_Hans: 沙特阿拉伯 pt_BR: Saudi Arabia + - value: SN + label: + en_US: Senegal + zh_Hans: 塞内加尔 + pt_BR: Senegal + - value: SC + label: + en_US: Seychelles + zh_Hans: 塞舌尔 + pt_BR: Seychelles + - value: SL + label: + en_US: Sierra Leone + zh_Hans: 塞拉利昂 + pt_BR: Sierra Leone - value: SG label: en_US: Singapore zh_Hans: 新加坡 pt_BR: Singapore + - value: SO + label: + en_US: Somalia + zh_Hans: 索马里 + pt_BR: Somalia - value: ZA label: en_US: South Africa zh_Hans: 南非 pt_BR: South Africa + - value: GS + label: + en_US: South Georgia and the South Sandwich Islands + zh_Hans: 南乔治亚和南桑威奇群岛 + pt_BR: South Georgia and the South Sandwich Islands - value: ES label: en_US: Spain zh_Hans: 西班牙 pt_BR: Spain - - value: SE + - value: LK + label: + en_US: Sri Lanka + zh_Hans: 斯里兰卡 + pt_BR: Sri Lanka + - value: SR label: - en_US: Sweden - zh_Hans: 瑞典 - pt_BR: Sweden + en_US: Suriname + zh_Hans: 苏里南 + pt_BR: Suriname - value: CH label: en_US: Switzerland @@ -242,19 +652,54 @@ parameters: pt_BR: Switzerland - value: TW label: - en_US: Taiwan - zh_Hans: 台湾 - pt_BR: Taiwan + en_US: Taiwan, Province of China + zh_Hans: 中国台湾省 + pt_BR: Taiwan, Province of China + - value: TZ + label: + en_US: Tanzania, United Republic of + zh_Hans: 坦桑尼亚联合共和国 + pt_BR: Tanzania, United Republic of - value: TH label: en_US: Thailand zh_Hans: 泰国 pt_BR: Thailand - - value: TR + - value: TG + label: + en_US: Togo + zh_Hans: 多哥 + pt_BR: Togo + - value: TT + label: + en_US: Trinidad and Tobago + zh_Hans: 特立尼达和多巴哥 + pt_BR: Trinidad and Tobago + - value: TN + label: + en_US: Tunisia + zh_Hans: 突尼斯 + pt_BR: Tunisia + - value: TC + label: + en_US: Turks and Caicos Islands + zh_Hans: 特克斯和凯科斯群岛 + pt_BR: Turks and Caicos Islands + - value: UG + label: + en_US: Uganda + zh_Hans: 乌干达 + pt_BR: Uganda + - value: AE + label: + en_US: United Arab Emirates + zh_Hans: 阿联酋 + pt_BR: United Arab Emirates + - value: UK label: - en_US: Turkey - zh_Hans: 土耳其 - pt_BR: Turkey + en_US: United Kingdom + zh_Hans: 英国 + pt_BR: United Kingdom - value: GB label: en_US: United Kingdom @@ -265,6 +710,46 @@ parameters: en_US: United States zh_Hans: 美国 pt_BR: United States + - value: UY + label: + en_US: Uruguay + zh_Hans: 乌拉圭 + pt_BR: Uruguay + - value: UZ + label: + en_US: Uzbekistan + zh_Hans: 乌兹别克斯坦 + pt_BR: Uzbekistan + - value: VE + label: + en_US: Venezuela + zh_Hans: 委内瑞拉 + pt_BR: Venezuela + - value: VN + label: + en_US: Viet Nam + zh_Hans: 越南 + pt_BR: Viet Nam + - value: VG + label: + en_US: Virgin Islands, British + zh_Hans: 英属维尔京群岛 + pt_BR: Virgin 
Islands, British + - value: VI + label: + en_US: Virgin Islands, U.S. + zh_Hans: 美属维尔京群岛 + pt_BR: Virgin Islands, U.S. + - value: ZM + label: + en_US: Zambia + zh_Hans: 赞比亚 + pt_BR: Zambia + - value: ZW + label: + en_US: Zimbabwe + zh_Hans: 津巴布韦 + pt_BR: Zimbabwe - name: hl type: select label: @@ -277,18 +762,94 @@ parameters: default: en form: form options: + - value: af + label: + en_US: Afrikaans + zh_Hans: 南非语 + - value: ak + label: + en_US: Akan + zh_Hans: 阿坎语 + - value: sq + label: + en_US: Albanian + zh_Hans: 阿尔巴尼亚语 + - value: ws + label: + en_US: Samoa + zh_Hans: 萨摩亚语 + - value: am + label: + en_US: Amharic + zh_Hans: 阿姆哈拉语 - value: ar label: en_US: Arabic zh_Hans: 阿拉伯语 + - value: hy + label: + en_US: Armenian + zh_Hans: 亚美尼亚语 + - value: az + label: + en_US: Azerbaijani + zh_Hans: 阿塞拜疆语 + - value: eu + label: + en_US: Basque + zh_Hans: 巴斯克语 + - value: be + label: + en_US: Belarusian + zh_Hans: 白俄罗斯语 + - value: bem + label: + en_US: Bemba + zh_Hans: 班巴语 + - value: bn + label: + en_US: Bengali + zh_Hans: 孟加拉语 + - value: bh + label: + en_US: Bihari + zh_Hans: 比哈尔语 + - value: xx-bork + label: + en_US: Bork, bork, bork! + zh_Hans: 博克语 + - value: bs + label: + en_US: Bosnian + zh_Hans: 波斯尼亚语 + - value: br + label: + en_US: Breton + zh_Hans: 布列塔尼语 - value: bg label: en_US: Bulgarian zh_Hans: 保加利亚语 + - value: bt + label: + en_US: Bhutanese + zh_Hans: 不丹语 + - value: km + label: + en_US: Cambodian + zh_Hans: 高棉语 - value: ca label: en_US: Catalan zh_Hans: 加泰罗尼亚语 + - value: chr + label: + en_US: Cherokee + zh_Hans: 切罗基语 + - value: ny + label: + en_US: Chichewa + zh_Hans: 齐切瓦语 - value: zh-cn label: en_US: Chinese (Simplified) @@ -297,6 +858,14 @@ parameters: label: en_US: Chinese (Traditional) zh_Hans: 中文(繁体) + - value: co + label: + en_US: Corsican + zh_Hans: 科西嘉语 + - value: hr + label: + en_US: Croatian + zh_Hans: 克罗地亚语 - value: cs label: en_US: Czech @@ -309,14 +878,34 @@ parameters: label: en_US: Dutch zh_Hans: 荷兰语 + - value: xx-elmer + label: + en_US: Elmer Fudd + zh_Hans: 艾尔默福德语 - value: en label: en_US: English zh_Hans: 英语 + - value: eo + label: + en_US: Esperanto + zh_Hans: 世界语 - value: et label: en_US: Estonian zh_Hans: 爱沙尼亚语 + - value: ee + label: + en_US: Ewe + zh_Hans: 埃维语 + - value: fo + label: + en_US: Faroese + zh_Hans: 法罗语 + - value: tl + label: + en_US: Filipino + zh_Hans: 菲律宾语 - value: fi label: en_US: Finnish @@ -325,6 +914,22 @@ parameters: label: en_US: French zh_Hans: 法语 + - value: fy + label: + en_US: Frisian + zh_Hans: 弗里西亚语 + - value: gaa + label: + en_US: Ga + zh_Hans: 加语 + - value: gl + label: + en_US: Galician + zh_Hans: 加利西亚语 + - value: ka + label: + en_US: Georgian + zh_Hans: 格鲁吉亚语 - value: de label: en_US: German @@ -333,6 +938,34 @@ parameters: label: en_US: Greek zh_Hans: 希腊语 + - value: kl + label: + en_US: Greenlandic + zh_Hans: 格陵兰语 + - value: gn + label: + en_US: Guarani + zh_Hans: 瓜拉尼语 + - value: gu + label: + en_US: Gujarati + zh_Hans: 古吉拉特语 + - value: xx-hacker + label: + en_US: Hacker + zh_Hans: 黑客语 + - value: ht + label: + en_US: Haitian Creole + zh_Hans: 海地克里奥尔语 + - value: ha + label: + en_US: Hausa + zh_Hans: 豪萨语 + - value: haw + label: + en_US: Hawaiian + zh_Hans: 夏威夷语 - value: iw label: en_US: Hebrew @@ -345,10 +978,26 @@ parameters: label: en_US: Hungarian zh_Hans: 匈牙利语 + - value: is + label: + en_US: Icelandic + zh_Hans: 冰岛语 + - value: ig + label: + en_US: Igbo + zh_Hans: 伊博语 - value: id label: en_US: Indonesian zh_Hans: 印尼语 + - value: ia + label: + en_US: Interlingua + zh_Hans: 国际语 + - value: ga + label: + en_US: Irish + zh_Hans: 爱尔兰语 - value: it 
label: en_US: Italian @@ -357,22 +1006,94 @@ parameters: label: en_US: Japanese zh_Hans: 日语 + - value: jw + label: + en_US: Javanese + zh_Hans: 爪哇语 - value: kn label: en_US: Kannada zh_Hans: 卡纳达语 + - value: kk + label: + en_US: Kazakh + zh_Hans: 哈萨克语 + - value: rw + label: + en_US: Kinyarwanda + zh_Hans: 基尼亚卢旺达语 + - value: rn + label: + en_US: Kirundi + zh_Hans: 基隆迪语 + - value: xx-klingon + label: + en_US: Klingon + zh_Hans: 克林贡语 + - value: kg + label: + en_US: Kongo + zh_Hans: 刚果语 - value: ko label: en_US: Korean zh_Hans: 韩语 + - value: kri + label: + en_US: Krio (Sierra Leone) + zh_Hans: 塞拉利昂克里奥尔语 + - value: ku + label: + en_US: Kurdish + zh_Hans: 库尔德语 + - value: ckb + label: + en_US: Kurdish (Soranî) + zh_Hans: 库尔德语(索拉尼) + - value: ky + label: + en_US: Kyrgyz + zh_Hans: 吉尔吉斯语 + - value: lo + label: + en_US: Laothian + zh_Hans: 老挝语 + - value: la + label: + en_US: Latin + zh_Hans: 拉丁语 - value: lv label: en_US: Latvian zh_Hans: 拉脱维亚语 + - value: ln + label: + en_US: Lingala + zh_Hans: 林加拉语 - value: lt label: en_US: Lithuanian zh_Hans: 立陶宛语 + - value: loz + label: + en_US: Lozi + zh_Hans: 洛齐语 + - value: lg + label: + en_US: Luganda + zh_Hans: 卢干达语 + - value: ach + label: + en_US: Luo + zh_Hans: 卢奥语 + - value: mk + label: + en_US: Macedonian + zh_Hans: 马其顿语 + - value: mg + label: + en_US: Malagasy + zh_Hans: 马尔加什语 - value: my label: en_US: Malay @@ -381,18 +1102,90 @@ parameters: label: en_US: Malayalam zh_Hans: 马拉雅拉姆语 + - value: mt + label: + en_US: Maltese + zh_Hans: 马耳他语 + - value: mv + label: + en_US: Maldives + zh_Hans: 马尔代夫语 + - value: mi + label: + en_US: Maori + zh_Hans: 毛利语 - value: mr label: en_US: Marathi zh_Hans: 马拉地语 + - value: mfe + label: + en_US: Mauritian Creole + zh_Hans: 毛里求斯克里奥尔语 + - value: mo + label: + en_US: Moldavian + zh_Hans: 摩尔达维亚语 + - value: mn + label: + en_US: Mongolian + zh_Hans: 蒙古语 + - value: sr-me + label: + en_US: Montenegrin + zh_Hans: 黑山语 + - value: ne + label: + en_US: Nepali + zh_Hans: 尼泊尔语 + - value: pcm + label: + en_US: Nigerian Pidgin + zh_Hans: 尼日利亚皮钦语 + - value: nso + label: + en_US: Northern Sotho + zh_Hans: 北索托语 - value: "no" label: en_US: Norwegian zh_Hans: 挪威语 + - value: nn + label: + en_US: Norwegian (Nynorsk) + zh_Hans: 挪威语(尼诺斯克语) + - value: oc + label: + en_US: Occitan + zh_Hans: 奥克语 + - value: or + label: + en_US: Oriya + zh_Hans: 奥里亚语 + - value: om + label: + en_US: Oromo + zh_Hans: 奥罗莫语 + - value: ps + label: + en_US: Pashto + zh_Hans: 普什图语 + - value: fa + label: + en_US: Persian + zh_Hans: 波斯语 + - value: xx-pirate + label: + en_US: Pirate + zh_Hans: 海盗语 - value: pl label: en_US: Polish zh_Hans: 波兰语 + - value: pt + label: + en_US: Portuguese + zh_Hans: 葡萄牙语 - value: pt-br label: en_US: Portuguese (Brazil) @@ -405,18 +1198,62 @@ parameters: label: en_US: Punjabi zh_Hans: 旁遮普语 + - value: qu + label: + en_US: Quechua + zh_Hans: 克丘亚语 - value: ro label: en_US: Romanian zh_Hans: 罗马尼亚语 + - value: rm + label: + en_US: Romansh + zh_Hans: 罗曼什语 + - value: nyn + label: + en_US: Runyakitara + zh_Hans: 卢尼亚基塔拉语 - value: ru label: en_US: Russian zh_Hans: 俄语 + - value: gd + label: + en_US: Scots Gaelic + zh_Hans: 苏格兰盖尔语 - value: sr label: en_US: Serbian zh_Hans: 塞尔维亚语 + - value: sh + label: + en_US: Serbo-Croatian + zh_Hans: 塞尔维亚-克罗地亚语 + - value: st + label: + en_US: Sesotho + zh_Hans: 塞索托语 + - value: tn + label: + en_US: Setswana + zh_Hans: 塞茨瓦纳语 + - value: crs + label: + en_US: Seychellois Creole + zh_Hans: 塞舌尔克里奥尔语 + - value: sn + label: + en_US: Shona + zh_Hans: 绍纳语 + - value: sd + label: + en_US: Sindhi + zh_Hans: 信德语 + - value: si + label: + en_US: 
Sinhalese + zh_Hans: 僧伽罗语 - value: sk label: en_US: Slovak @@ -425,18 +1262,42 @@ parameters: label: en_US: Slovenian zh_Hans: 斯洛文尼亚语 + - value: so + label: + en_US: Somali + zh_Hans: 索马里语 - value: es label: en_US: Spanish zh_Hans: 西班牙语 + - value: es-419 + label: + en_US: Spanish (Latin American) + zh_Hans: 西班牙语(拉丁美洲) + - value: su + label: + en_US: Sundanese + zh_Hans: 巽他语 + - value: sw + label: + en_US: Swahili + zh_Hans: 斯瓦希里语 - value: sv label: en_US: Swedish zh_Hans: 瑞典语 + - value: tg + label: + en_US: Tajik + zh_Hans: 塔吉克语 - value: ta label: en_US: Tamil zh_Hans: 泰米尔语 + - value: tt + label: + en_US: Tatar + zh_Hans: 鞑靼语 - value: te label: en_US: Telugu @@ -445,18 +1306,82 @@ parameters: label: en_US: Thai zh_Hans: 泰语 + - value: ti + label: + en_US: Tigrinya + zh_Hans: 提格利尼亚语 + - value: to + label: + en_US: Tonga + zh_Hans: 汤加语 + - value: lua + label: + en_US: Tshiluba + zh_Hans: 卢巴语 + - value: tum + label: + en_US: Tumbuka + zh_Hans: 图布卡语 - value: tr label: en_US: Turkish zh_Hans: 土耳其语 + - value: tk + label: + en_US: Turkmen + zh_Hans: 土库曼语 + - value: tw + label: + en_US: Twi + zh_Hans: 契维语 + - value: ug + label: + en_US: Uighur + zh_Hans: 维吾尔语 - value: uk label: en_US: Ukrainian zh_Hans: 乌克兰语 + - value: ur + label: + en_US: Urdu + zh_Hans: 乌尔都语 + - value: uz + label: + en_US: Uzbek + zh_Hans: 乌兹别克语 + - value: vu + label: + en_US: Vanuatu + zh_Hans: 瓦努阿图语 - value: vi label: en_US: Vietnamese zh_Hans: 越南语 + - value: cy + label: + en_US: Welsh + zh_Hans: 威尔士语 + - value: wo + label: + en_US: Wolof + zh_Hans: 沃洛夫语 + - value: xh + label: + en_US: Xhosa + zh_Hans: 科萨语 + - value: yi + label: + en_US: Yiddish + zh_Hans: 意第绪语 + - value: yo + label: + en_US: Yoruba + zh_Hans: 约鲁巴语 + - value: zu + label: + en_US: Zulu + zh_Hans: 祖鲁语 - name: is_remote type: select label: diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 5d2657dddd1972..562bc01964b4c3 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -7,6 +7,7 @@ SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. 
@@ -37,56 +38,60 @@ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: return { "engine": "google_news", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" if type == "text": - if "organic_results" in res.keys() and "snippet" in res["organic_results"][0].keys(): + if "organic_results" in res and "snippet" in res["organic_results"][0]: for item in res["organic_results"]: toret += "content: " + item["snippet"] + "\n" + "link: " + item["link"] + "\n" - if "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + if "top_stories" in res and "title" in res["top_stories"][0]: for item in res["top_stories"]: toret += "title: " + item["title"] + "\n" + "link: " + item["link"] + "\n" if toret == "": toret = "No good search result found" elif type == "link": - if "organic_results" in res.keys() and "title" in res["organic_results"][0].keys(): + if "organic_results" in res and "title" in res["organic_results"][0]: for item in res["organic_results"]: toret += f"[{item['title']}]({item['link']})\n" - elif "top_stories" in res.keys() and "title" in res["top_stories"][0].keys(): + elif "top_stories" in res and "title" in res["top_stories"][0]: for item in res["top_stories"]: toret += f"[{item['title']}]({item['link']})\n" else: toret = "No good search result found" return toret + class GoogleNewsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. 
""" - query = tool_parameters['query'] - result_type = tool_parameters['result_type'] + query = tool_parameters["query"] + result_type = tool_parameters["result_type"] num = tool_parameters.get("num", 10) google_domain = tool_parameters.get("google_domain", "google.com") gl = tool_parameters.get("gl", "us") hl = tool_parameters.get("hl", "en") location = tool_parameters.get("location") - api_key = self.runtime.credentials['searchapi_api_key'] - result = SearchAPI(api_key).run(query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location) + api_key = self.runtime.credentials["searchapi_api_key"] + result = SearchAPI(api_key).run( + query, result_type=result_type, num=num, google_domain=google_domain, gl=gl, hl=hl, location=location + ) - if result_type == 'text': + if result_type == "text": return self.create_text_message(text=result) return self.create_link_message(link=result) diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml b/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml index cbb0edf9829595..ff34af34cc9f5c 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.yaml @@ -65,206 +65,1206 @@ parameters: form: form default: US options: + - value: AF + label: + en_US: Afghanistan + zh_Hans: 阿富汗 + pt_BR: Afeganistão + - value: AL + label: + en_US: Albania + zh_Hans: 阿尔巴尼亚 + pt_BR: Albânia + - value: DZ + label: + en_US: Algeria + zh_Hans: 阿尔及利亚 + pt_BR: Argélia + - value: AS + label: + en_US: American Samoa + zh_Hans: 美属萨摩亚 + pt_BR: Samoa Americana + - value: AD + label: + en_US: Andorra + zh_Hans: 安道尔 + pt_BR: Andorra + - value: AO + label: + en_US: Angola + zh_Hans: 安哥拉 + pt_BR: Angola + - value: AI + label: + en_US: Anguilla + zh_Hans: 安圭拉 + pt_BR: Anguilla + - value: AQ + label: + en_US: Antarctica + zh_Hans: 南极洲 + pt_BR: Antártica + - value: AG + label: + en_US: Antigua and Barbuda + zh_Hans: 安提瓜和巴布达 + pt_BR: Antígua e Barbuda - value: AR label: en_US: Argentina zh_Hans: 阿根廷 pt_BR: Argentina + - value: AM + label: + en_US: Armenia + zh_Hans: 亚美尼亚 + pt_BR: Armênia + - value: AW + label: + en_US: Aruba + zh_Hans: 阿鲁巴 + pt_BR: Aruba - value: AU label: en_US: Australia zh_Hans: 澳大利亚 - pt_BR: Australia + pt_BR: Austrália - value: AT label: en_US: Austria zh_Hans: 奥地利 - pt_BR: Austria + pt_BR: Áustria + - value: AZ + label: + en_US: Azerbaijan + zh_Hans: 阿塞拜疆 + pt_BR: Azerbaijão + - value: BS + label: + en_US: Bahamas + zh_Hans: 巴哈马 + pt_BR: Bahamas + - value: BH + label: + en_US: Bahrain + zh_Hans: 巴林 + pt_BR: Bahrein + - value: BD + label: + en_US: Bangladesh + zh_Hans: 孟加拉国 + pt_BR: Bangladesh + - value: BB + label: + en_US: Barbados + zh_Hans: 巴巴多斯 + pt_BR: Barbados + - value: BY + label: + en_US: Belarus + zh_Hans: 白俄罗斯 + pt_BR: Bielorrússia - value: BE label: en_US: Belgium zh_Hans: 比利时 - pt_BR: Belgium + pt_BR: Bélgica + - value: BZ + label: + en_US: Belize + zh_Hans: 伯利兹 + pt_BR: Belize + - value: BJ + label: + en_US: Benin + zh_Hans: 贝宁 + pt_BR: Benim + - value: BM + label: + en_US: Bermuda + zh_Hans: 百慕大 + pt_BR: Bermudas + - value: BT + label: + en_US: Bhutan + zh_Hans: 不丹 + pt_BR: Butão + - value: BO + label: + en_US: Bolivia + zh_Hans: 玻利维亚 + pt_BR: Bolívia + - value: BA + label: + en_US: Bosnia and Herzegovina + zh_Hans: 波斯尼亚和黑塞哥维那 + pt_BR: Bósnia e Herzegovina + - value: BW + label: + en_US: Botswana + zh_Hans: 博茨瓦纳 + pt_BR: Botsuana + - value: BV + label: + en_US: Bouvet Island + zh_Hans: 布韦岛 + pt_BR: 
Ilha Bouvet - value: BR label: en_US: Brazil zh_Hans: 巴西 - pt_BR: Brazil + pt_BR: Brasil + - value: IO + label: + en_US: British Indian Ocean Territory + zh_Hans: 英属印度洋领地 + pt_BR: Território Britânico do Oceano Índico + - value: BN + label: + en_US: Brunei Darussalam + zh_Hans: 文莱 + pt_BR: Brunei Darussalam + - value: BG + label: + en_US: Bulgaria + zh_Hans: 保加利亚 + pt_BR: Bulgária + - value: BF + label: + en_US: Burkina Faso + zh_Hans: 布基纳法索 + pt_BR: Burkina Faso + - value: BI + label: + en_US: Burundi + zh_Hans: 布隆迪 + pt_BR: Burundi + - value: KH + label: + en_US: Cambodia + zh_Hans: 柬埔寨 + pt_BR: Camboja + - value: CM + label: + en_US: Cameroon + zh_Hans: 喀麦隆 + pt_BR: Camarões - value: CA label: en_US: Canada zh_Hans: 加拿大 - pt_BR: Canada + pt_BR: Canadá + - value: CV + label: + en_US: Cape Verde + zh_Hans: 佛得角 + pt_BR: Cabo Verde + - value: KY + label: + en_US: Cayman Islands + zh_Hans: 开曼群岛 + pt_BR: Ilhas Cayman + - value: CF + label: + en_US: Central African Republic + zh_Hans: 中非共和国 + pt_BR: República Centro-Africana + - value: TD + label: + en_US: Chad + zh_Hans: 乍得 + pt_BR: Chade - value: CL label: en_US: Chile zh_Hans: 智利 pt_BR: Chile - - value: CO - label: - en_US: Colombia - zh_Hans: 哥伦比亚 - pt_BR: Colombia - value: CN label: en_US: China zh_Hans: 中国 pt_BR: China + - value: CX + label: + en_US: Christmas Island + zh_Hans: 圣诞岛 + pt_BR: Ilha do Natal + - value: CC + label: + en_US: Cocos (Keeling) Islands + zh_Hans: 科科斯(基林)群岛 + pt_BR: Ilhas Cocos (Keeling) + - value: CO + label: + en_US: Colombia + zh_Hans: 哥伦比亚 + pt_BR: Colômbia + - value: KM + label: + en_US: Comoros + zh_Hans: 科摩罗 + pt_BR: Comores + - value: CG + label: + en_US: Congo + zh_Hans: 刚果 + pt_BR: Congo + - value: CD + label: + en_US: Congo, the Democratic Republic of the + zh_Hans: 刚果民主共和国 + pt_BR: Congo, República Democrática do + - value: CK + label: + en_US: Cook Islands + zh_Hans: 库克群岛 + pt_BR: Ilhas Cook + - value: CR + label: + en_US: Costa Rica + zh_Hans: 哥斯达黎加 + pt_BR: Costa Rica + - value: CI + label: + en_US: Cote D'ivoire + zh_Hans: 科特迪瓦 + pt_BR: Costa do Marfim + - value: HR + label: + en_US: Croatia + zh_Hans: 克罗地亚 + pt_BR: Croácia + - value: CU + label: + en_US: Cuba + zh_Hans: 古巴 + pt_BR: Cuba + - value: CY + label: + en_US: Cyprus + zh_Hans: 塞浦路斯 + pt_BR: Chipre - value: CZ label: en_US: Czech Republic zh_Hans: 捷克共和国 - pt_BR: Czech Republic + pt_BR: República Tcheca - value: DK label: en_US: Denmark zh_Hans: 丹麦 - pt_BR: Denmark + pt_BR: Dinamarca + - value: DJ + label: + en_US: Djibouti + zh_Hans: 吉布提 + pt_BR: Djibuti + - value: DM + label: + en_US: Dominica + zh_Hans: 多米尼克 + pt_BR: Dominica + - value: DO + label: + en_US: Dominican Republic + zh_Hans: 多米尼加共和国 + pt_BR: República Dominicana + - value: EC + label: + en_US: Ecuador + zh_Hans: 厄瓜多尔 + pt_BR: Equador + - value: EG + label: + en_US: Egypt + zh_Hans: 埃及 + pt_BR: Egito + - value: SV + label: + en_US: El Salvador + zh_Hans: 萨尔瓦多 + pt_BR: El Salvador + - value: GQ + label: + en_US: Equatorial Guinea + zh_Hans: 赤道几内亚 + pt_BR: Guiné Equatorial + - value: ER + label: + en_US: Eritrea + zh_Hans: 厄立特里亚 + pt_BR: Eritreia + - value: EE + label: + en_US: Estonia + zh_Hans: 爱沙尼亚 + pt_BR: Estônia + - value: ET + label: + en_US: Ethiopia + zh_Hans: 埃塞俄比亚 + pt_BR: Etiópia + - value: FK + label: + en_US: Falkland Islands (Malvinas) + zh_Hans: 福克兰群岛(马尔维纳斯) + pt_BR: Ilhas Falkland (Malvinas) + - value: FO + label: + en_US: Faroe Islands + zh_Hans: 法罗群岛 + pt_BR: Ilhas Faroe + - value: FJ + label: + en_US: Fiji + zh_Hans: 斐济 + pt_BR: Fiji - value: FI label: en_US: 
Finland zh_Hans: 芬兰 - pt_BR: Finland + pt_BR: Finlândia - value: FR label: en_US: France zh_Hans: 法国 - pt_BR: France + pt_BR: França + - value: GF + label: + en_US: French Guiana + zh_Hans: 法属圭亚那 + pt_BR: Guiana Francesa + - value: PF + label: + en_US: French Polynesia + zh_Hans: 法属波利尼西亚 + pt_BR: Polinésia Francesa + - value: TF + label: + en_US: French Southern Territories + zh_Hans: 法属南部领地 + pt_BR: Territórios Franceses do Sul + - value: GA + label: + en_US: Gabon + zh_Hans: 加蓬 + pt_BR: Gabão + - value: GM + label: + en_US: Gambia + zh_Hans: 冈比亚 + pt_BR: Gâmbia + - value: GE + label: + en_US: Georgia + zh_Hans: 格鲁吉亚 + pt_BR: Geórgia - value: DE label: en_US: Germany zh_Hans: 德国 - pt_BR: Germany + pt_BR: Alemanha + - value: GH + label: + en_US: Ghana + zh_Hans: 加纳 + pt_BR: Gana + - value: GI + label: + en_US: Gibraltar + zh_Hans: 直布罗陀 + pt_BR: Gibraltar + - value: GR + label: + en_US: Greece + zh_Hans: 希腊 + pt_BR: Grécia + - value: GL + label: + en_US: Greenland + zh_Hans: 格陵兰 + pt_BR: Groenlândia + - value: GD + label: + en_US: Grenada + zh_Hans: 格林纳达 + pt_BR: Granada + - value: GP + label: + en_US: Guadeloupe + zh_Hans: 瓜德罗普 + pt_BR: Guadalupe + - value: GU + label: + en_US: Guam + zh_Hans: 关岛 + pt_BR: Guam + - value: GT + label: + en_US: Guatemala + zh_Hans: 危地马拉 + pt_BR: Guatemala + - value: GN + label: + en_US: Guinea + zh_Hans: 几内亚 + pt_BR: Guiné + - value: GW + label: + en_US: Guinea-Bissau + zh_Hans: 几内亚比绍 + pt_BR: Guiné-Bissau + - value: GY + label: + en_US: Guyana + zh_Hans: 圭亚那 + pt_BR: Guiana + - value: HT + label: + en_US: Haiti + zh_Hans: 海地 + pt_BR: Haiti + - value: HM + label: + en_US: Heard Island and McDonald Islands + zh_Hans: 赫德岛和麦克唐纳群岛 + pt_BR: Ilha Heard e Ilhas McDonald + - value: VA + label: + en_US: Holy See (Vatican City State) + zh_Hans: 教廷(梵蒂冈城国) + pt_BR: Santa Sé (Estado da Cidade do Vaticano) + - value: HN + label: + en_US: Honduras + zh_Hans: 洪都拉斯 + pt_BR: Honduras - value: HK label: en_US: Hong Kong zh_Hans: 香港 pt_BR: Hong Kong + - value: HU + label: + en_US: Hungary + zh_Hans: 匈牙利 + pt_BR: Hungria + - value: IS + label: + en_US: Iceland + zh_Hans: 冰岛 + pt_BR: Islândia - value: IN label: en_US: India zh_Hans: 印度 - pt_BR: India + pt_BR: Índia - value: ID label: en_US: Indonesia zh_Hans: 印度尼西亚 - pt_BR: Indonesia + pt_BR: Indonésia + - value: IR + label: + en_US: Iran, Islamic Republic of + zh_Hans: 伊朗 + pt_BR: Irã + - value: IQ + label: + en_US: Iraq + zh_Hans: 伊拉克 + pt_BR: Iraque + - value: IE + label: + en_US: Ireland + zh_Hans: 爱尔兰 + pt_BR: Irlanda + - value: IL + label: + en_US: Israel + zh_Hans: 以色列 + pt_BR: Israel - value: IT label: en_US: Italy zh_Hans: 意大利 - pt_BR: Italy + pt_BR: Itália + - value: JM + label: + en_US: Jamaica + zh_Hans: 牙买加 + pt_BR: Jamaica - value: JP label: en_US: Japan zh_Hans: 日本 - pt_BR: Japan + pt_BR: Japão + - value: JO + label: + en_US: Jordan + zh_Hans: 约旦 + pt_BR: Jordânia + - value: KZ + label: + en_US: Kazakhstan + zh_Hans: 哈萨克斯坦 + pt_BR: Cazaquistão + - value: KE + label: + en_US: Kenya + zh_Hans: 肯尼亚 + pt_BR: Quênia + - value: KI + label: + en_US: Kiribati + zh_Hans: 基里巴斯 + pt_BR: Kiribati + - value: KP + label: + en_US: Korea, Democratic People's Republic of + zh_Hans: 朝鲜 + pt_BR: Coreia, República Democrática Popular da - value: KR label: - en_US: Korea + en_US: Korea, Republic of zh_Hans: 韩国 - pt_BR: Korea + pt_BR: Coreia, República da + - value: KW + label: + en_US: Kuwait + zh_Hans: 科威特 + pt_BR: Kuwait + - value: KG + label: + en_US: Kyrgyzstan + zh_Hans: 吉尔吉斯斯坦 + pt_BR: Quirguistão + - value: LA + label: + en_US: 
Lao People's Democratic Republic + zh_Hans: 老挝 + pt_BR: República Democrática Popular do Laos + - value: LV + label: + en_US: Latvia + zh_Hans: 拉脱维亚 + pt_BR: Letônia + - value: LB + label: + en_US: Lebanon + zh_Hans: 黎巴嫩 + pt_BR: Líbano + - value: LS + label: + en_US: Lesotho + zh_Hans: 莱索托 + pt_BR: Lesoto + - value: LR + label: + en_US: Liberia + zh_Hans: 利比里亚 + pt_BR: Libéria + - value: LY + label: + en_US: Libyan Arab Jamahiriya + zh_Hans: 利比亚 + pt_BR: Líbia + - value: LI + label: + en_US: Liechtenstein + zh_Hans: 列支敦士登 + pt_BR: Liechtenstein + - value: LT + label: + en_US: Lithuania + zh_Hans: 立陶宛 + pt_BR: Lituânia + - value: LU + label: + en_US: Luxembourg + zh_Hans: 卢森堡 + pt_BR: Luxemburgo + - value: MO + label: + en_US: Macao + zh_Hans: 澳门 + pt_BR: Macau + - value: MK + label: + en_US: Macedonia, the Former Yugoslav Republic of + zh_Hans: 前南斯拉夫马其顿共和国 + pt_BR: Macedônia, Ex-República Iugoslava da + - value: MG + label: + en_US: Madagascar + zh_Hans: 马达加斯加 + pt_BR: Madagascar + - value: MW + label: + en_US: Malawi + zh_Hans: 马拉维 + pt_BR: Malaui - value: MY label: en_US: Malaysia zh_Hans: 马来西亚 - pt_BR: Malaysia + pt_BR: Malásia + - value: MV + label: + en_US: Maldives + zh_Hans: 马尔代夫 + pt_BR: Maldivas + - value: ML + label: + en_US: Mali + zh_Hans: 马里 + pt_BR: Mali + - value: MT + label: + en_US: Malta + zh_Hans: 马耳他 + pt_BR: Malta + - value: MH + label: + en_US: Marshall Islands + zh_Hans: 马绍尔群岛 + pt_BR: Ilhas Marshall + - value: MQ + label: + en_US: Martinique + zh_Hans: 马提尼克 + pt_BR: Martinica + - value: MR + label: + en_US: Mauritania + zh_Hans: 毛里塔尼亚 + pt_BR: Mauritânia + - value: MU + label: + en_US: Mauritius + zh_Hans: 毛里求斯 + pt_BR: Maurício + - value: YT + label: + en_US: Mayotte + zh_Hans: 马约特 + pt_BR: Mayotte - value: MX label: en_US: Mexico zh_Hans: 墨西哥 - pt_BR: Mexico + pt_BR: México + - value: FM + label: + en_US: Micronesia, Federated States of + zh_Hans: 密克罗尼西亚联邦 + pt_BR: Micronésia, Estados Federados da + - value: MD + label: + en_US: Moldova, Republic of + zh_Hans: 摩尔多瓦共和国 + pt_BR: Moldávia, República da + - value: MC + label: + en_US: Monaco + zh_Hans: 摩纳哥 + pt_BR: Mônaco + - value: MN + label: + en_US: Mongolia + zh_Hans: 蒙古 + pt_BR: Mongólia + - value: MS + label: + en_US: Montserrat + zh_Hans: 蒙特塞拉特 + pt_BR: Montserrat + - value: MA + label: + en_US: Morocco + zh_Hans: 摩洛哥 + pt_BR: Marrocos + - value: MZ + label: + en_US: Mozambique + zh_Hans: 莫桑比克 + pt_BR: Moçambique + - value: MM + label: + en_US: Myanmar + zh_Hans: 缅甸 + pt_BR: Mianmar + - value: NA + label: + en_US: Namibia + zh_Hans: 纳米比亚 + pt_BR: Namíbia + - value: NR + label: + en_US: Nauru + zh_Hans: 瑙鲁 + pt_BR: Nauru + - value: NP + label: + en_US: Nepal + zh_Hans: 尼泊尔 + pt_BR: Nepal - value: NL label: en_US: Netherlands zh_Hans: 荷兰 - pt_BR: Netherlands + pt_BR: Países Baixos + - value: AN + label: + en_US: Netherlands Antilles + zh_Hans: 荷属安的列斯 + pt_BR: Antilhas Holandesas + - value: NC + label: + en_US: New Caledonia + zh_Hans: 新喀里多尼亚 + pt_BR: Nova Caledônia - value: NZ label: en_US: New Zealand zh_Hans: 新西兰 - pt_BR: New Zealand - - value: 'NO' + pt_BR: Nova Zelândia + - value: NI + label: + en_US: Nicaragua + zh_Hans: 尼加拉瓜 + pt_BR: Nicarágua + - value: NE + label: + en_US: Niger + zh_Hans: 尼日尔 + pt_BR: Níger + - value: NG + label: + en_US: Nigeria + zh_Hans: 尼日利亚 + pt_BR: Nigéria + - value: NU + label: + en_US: Niue + zh_Hans: 纽埃 + pt_BR: Niue + - value: NF + label: + en_US: Norfolk Island + zh_Hans: 诺福克岛 + pt_BR: Ilha Norfolk + - value: MP + label: + en_US: Northern Mariana Islands + zh_Hans: 北马里亚纳群岛 +
pt_BR: Ilhas Marianas do Norte + - value: "NO" label: en_US: Norway zh_Hans: 挪威 - pt_BR: Norway + pt_BR: Noruega + - value: OM + label: + en_US: Oman + zh_Hans: 阿曼 + pt_BR: Omã + - value: PK + label: + en_US: Pakistan + zh_Hans: 巴基斯坦 + pt_BR: Paquistão + - value: PW + label: + en_US: Palau + zh_Hans: 帕劳 + pt_BR: Palau + - value: PS + label: + en_US: Palestinian Territory, Occupied + zh_Hans: 巴勒斯坦领土 + pt_BR: Palestina, Território Ocupado + - value: PA + label: + en_US: Panama + zh_Hans: 巴拿马 + pt_BR: Panamá + - value: PG + label: + en_US: Papua New Guinea + zh_Hans: 巴布亚新几内亚 + pt_BR: Papua Nova Guiné + - value: PY + label: + en_US: Paraguay + zh_Hans: 巴拉圭 + pt_BR: Paraguai + - value: PE + label: + en_US: Peru + zh_Hans: 秘鲁 + pt_BR: Peru - value: PH label: en_US: Philippines zh_Hans: 菲律宾 - pt_BR: Philippines + pt_BR: Filipinas + - value: PN + label: + en_US: Pitcairn + zh_Hans: 皮特凯恩岛 + pt_BR: Pitcairn - value: PL label: en_US: Poland zh_Hans: 波兰 - pt_BR: Poland + pt_BR: Polônia - value: PT label: en_US: Portugal zh_Hans: 葡萄牙 pt_BR: Portugal + - value: PR + label: + en_US: Puerto Rico + zh_Hans: 波多黎各 + pt_BR: Porto Rico + - value: QA + label: + en_US: Qatar + zh_Hans: 卡塔尔 + pt_BR: Catar + - value: RE + label: + en_US: Reunion + zh_Hans: 留尼旺 + pt_BR: Reunião + - value: RO + label: + en_US: Romania + zh_Hans: 罗马尼亚 + pt_BR: Romênia - value: RU label: - en_US: Russia - zh_Hans: 俄罗斯 - pt_BR: Russia + en_US: Russian Federation + zh_Hans: 俄罗斯联邦 + pt_BR: Rússia + - value: RW + label: + en_US: Rwanda + zh_Hans: 卢旺达 + pt_BR: Ruanda + - value: SH + label: + en_US: Saint Helena + zh_Hans: 圣赫勒拿 + pt_BR: Santa Helena + - value: KN + label: + en_US: Saint Kitts and Nevis + zh_Hans: 圣基茨和尼维斯 + pt_BR: São Cristóvão e Nevis + - value: LC + label: + en_US: Saint Lucia + zh_Hans: 圣卢西亚 + pt_BR: Santa Lúcia + - value: PM + label: + en_US: Saint Pierre and Miquelon + zh_Hans: 圣皮埃尔和密克隆 + pt_BR: São Pedro e Miquelon + - value: VC + label: + en_US: Saint Vincent and the Grenadines + zh_Hans: 圣文森特和格林纳丁斯 + pt_BR: São Vicente e Granadinas + - value: WS + label: + en_US: Samoa + zh_Hans: 萨摩亚 + pt_BR: Samoa + - value: SM + label: + en_US: San Marino + zh_Hans: 圣马力诺 + pt_BR: San Marino + - value: ST + label: + en_US: Sao Tome and Principe + zh_Hans: 圣多美和普林西比 + pt_BR: São Tomé e Príncipe - value: SA label: en_US: Saudi Arabia zh_Hans: 沙特阿拉伯 - pt_BR: Saudi Arabia + pt_BR: Arábia Saudita + - value: SN + label: + en_US: Senegal + zh_Hans: 塞内加尔 + pt_BR: Senegal + - value: RS + label: + en_US: Serbia and Montenegro + zh_Hans: 塞尔维亚和黑山 + pt_BR: Sérvia e Montenegro + - value: SC + label: + en_US: Seychelles + zh_Hans: 塞舌尔 + pt_BR: Seicheles + - value: SL + label: + en_US: Sierra Leone + zh_Hans: 塞拉利昂 + pt_BR: Serra Leoa - value: SG label: en_US: Singapore zh_Hans: 新加坡 - pt_BR: Singapore + pt_BR: Singapura + - value: SK + label: + en_US: Slovakia + zh_Hans: 斯洛伐克 + pt_BR: Eslováquia + - value: SI + label: + en_US: Slovenia + zh_Hans: 斯洛文尼亚 + pt_BR: Eslovênia + - value: SB + label: + en_US: Solomon Islands + zh_Hans: 所罗门群岛 + pt_BR: Ilhas Salomão + - value: SO + label: + en_US: Somalia + zh_Hans: 索马里 + pt_BR: Somália - value: ZA label: en_US: South Africa zh_Hans: 南非 - pt_BR: South Africa + pt_BR: África do Sul + - value: GS + label: + en_US: South Georgia and the South Sandwich Islands + zh_Hans: 南乔治亚和南桑威奇群岛 + pt_BR: Geórgia do Sul e Ilhas Sandwich do Sul - value: ES label: en_US: Spain zh_Hans: 西班牙 - pt_BR: Spain + pt_BR: Espanha + - value: LK + label: + en_US: Sri Lanka + zh_Hans: 斯里兰卡 + pt_BR: Sri Lanka + - value: SD + label: + en_US: 
Sudan + zh_Hans: 苏丹 + pt_BR: Sudão + - value: SR + label: + en_US: Suriname + zh_Hans: 苏里南 + pt_BR: Suriname + - value: SJ + label: + en_US: Svalbard and Jan Mayen + zh_Hans: 斯瓦尔巴特和扬马延岛 + pt_BR: Svalbard e Jan Mayen + - value: SZ + label: + en_US: Swaziland + zh_Hans: 斯威士兰 + pt_BR: Essuatíni - value: SE label: en_US: Sweden zh_Hans: 瑞典 - pt_BR: Sweden + pt_BR: Suécia - value: CH label: en_US: Switzerland zh_Hans: 瑞士 - pt_BR: Switzerland + pt_BR: Suíça + - value: SY + label: + en_US: Syrian Arab Republic + zh_Hans: 叙利亚 + pt_BR: Síria - value: TW label: - en_US: Taiwan + en_US: Taiwan, Province of China zh_Hans: 台湾 pt_BR: Taiwan + - value: TJ + label: + en_US: Tajikistan + zh_Hans: 塔吉克斯坦 + pt_BR: Tajiquistão + - value: TZ + label: + en_US: Tanzania, United Republic of + zh_Hans: 坦桑尼亚联合共和国 + pt_BR: Tanzânia - value: TH label: en_US: Thailand zh_Hans: 泰国 - pt_BR: Thailand + pt_BR: Tailândia + - value: TL + label: + en_US: Timor-Leste + zh_Hans: 东帝汶 + pt_BR: Timor-Leste + - value: TG + label: + en_US: Togo + zh_Hans: 多哥 + pt_BR: Togo + - value: TK + label: + en_US: Tokelau + zh_Hans: 托克劳 + pt_BR: Toquelau + - value: TO + label: + en_US: Tonga + zh_Hans: 汤加 + pt_BR: Tonga + - value: TT + label: + en_US: Trinidad and Tobago + zh_Hans: 特立尼达和多巴哥 + pt_BR: Trindade e Tobago + - value: TN + label: + en_US: Tunisia + zh_Hans: 突尼斯 + pt_BR: Tunísia - value: TR label: en_US: Turkey zh_Hans: 土耳其 - pt_BR: Turkey + pt_BR: Turquia + - value: TM + label: + en_US: Turkmenistan + zh_Hans: 土库曼斯坦 + pt_BR: Turcomenistão + - value: TC + label: + en_US: Turks and Caicos Islands + zh_Hans: 特克斯和凯科斯群岛 + pt_BR: Ilhas Turks e Caicos + - value: TV + label: + en_US: Tuvalu + zh_Hans: 图瓦卢 + pt_BR: Tuvalu + - value: UG + label: + en_US: Uganda + zh_Hans: 乌干达 + pt_BR: Uganda + - value: UA + label: + en_US: Ukraine + zh_Hans: 乌克兰 + pt_BR: Ucrânia + - value: AE + label: + en_US: United Arab Emirates + zh_Hans: 阿联酋 + pt_BR: Emirados Árabes Unidos + - value: UK + label: + en_US: United Kingdom + zh_Hans: 英国 + pt_BR: Reino Unido - value: GB label: en_US: United Kingdom zh_Hans: 英国 - pt_BR: United Kingdom + pt_BR: Reino Unido - value: US label: en_US: United States zh_Hans: 美国 - pt_BR: United States + pt_BR: Estados Unidos + - value: UM + label: + en_US: United States Minor Outlying Islands + zh_Hans: 美国本土外小岛屿 + pt_BR: Ilhas Menores Distantes dos Estados Unidos + - value: UY + label: + en_US: Uruguay + zh_Hans: 乌拉圭 + pt_BR: Uruguai + - value: UZ + label: + en_US: Uzbekistan + zh_Hans: 乌兹别克斯坦 + pt_BR: Uzbequistão + - value: VU + label: + en_US: Vanuatu + zh_Hans: 瓦努阿图 + pt_BR: Vanuatu + - value: VE + label: + en_US: Venezuela + zh_Hans: 委内瑞拉 + pt_BR: Venezuela + - value: VN + label: + en_US: Viet Nam + zh_Hans: 越南 + pt_BR: Vietnã + - value: VG + label: + en_US: Virgin Islands, British + zh_Hans: 英属维尔京群岛 + pt_BR: Ilhas Virgens Britânicas + - value: VI + label: + en_US: Virgin Islands, U.S. 
+ zh_Hans: 美属维尔京群岛 + pt_BR: Ilhas Virgens dos EUA + - value: WF + label: + en_US: Wallis and Futuna + zh_Hans: 瓦利斯和富图纳群岛 + pt_BR: Wallis e Futuna + - value: EH + label: + en_US: Western Sahara + zh_Hans: 西撒哈拉 + pt_BR: Saara Ocidental + - value: YE + label: + en_US: Yemen + zh_Hans: 也门 + pt_BR: Iémen + - value: ZM + label: + en_US: Zambia + zh_Hans: 赞比亚 + pt_BR: Zâmbia + - value: ZW + label: + en_US: Zimbabwe + zh_Hans: 津巴布韦 + pt_BR: Zimbábue - name: hl type: select label: @@ -277,18 +1277,94 @@ parameters: default: en form: form options: + - value: af + label: + en_US: Afrikaans + zh_Hans: 南非语 + - value: ak + label: + en_US: Akan + zh_Hans: 阿坎语 + - value: sq + label: + en_US: Albanian + zh_Hans: 阿尔巴尼亚语 + - value: ws + label: + en_US: Samoa + zh_Hans: 萨摩亚语 + - value: am + label: + en_US: Amharic + zh_Hans: 阿姆哈拉语 - value: ar label: en_US: Arabic zh_Hans: 阿拉伯语 + - value: hy + label: + en_US: Armenian + zh_Hans: 亚美尼亚语 + - value: az + label: + en_US: Azerbaijani + zh_Hans: 阿塞拜疆语 + - value: eu + label: + en_US: Basque + zh_Hans: 巴斯克语 + - value: be + label: + en_US: Belarusian + zh_Hans: 白俄罗斯语 + - value: bem + label: + en_US: Bemba + zh_Hans: 班巴语 + - value: bn + label: + en_US: Bengali + zh_Hans: 孟加拉语 + - value: bh + label: + en_US: Bihari + zh_Hans: 比哈尔语 + - value: xx-bork + label: + en_US: Bork, bork, bork! + zh_Hans: 博克语 + - value: bs + label: + en_US: Bosnian + zh_Hans: 波斯尼亚语 + - value: br + label: + en_US: Breton + zh_Hans: 布列塔尼语 - value: bg label: en_US: Bulgarian zh_Hans: 保加利亚语 + - value: bt + label: + en_US: Bhutanese + zh_Hans: 不丹语 + - value: km + label: + en_US: Cambodian + zh_Hans: 高棉语 - value: ca label: en_US: Catalan zh_Hans: 加泰罗尼亚语 + - value: chr + label: + en_US: Cherokee + zh_Hans: 切罗基语 + - value: ny + label: + en_US: Chichewa + zh_Hans: 齐切瓦语 - value: zh-cn label: en_US: Chinese (Simplified) @@ -297,6 +1373,14 @@ parameters: label: en_US: Chinese (Traditional) zh_Hans: 中文(繁体) + - value: co + label: + en_US: Corsican + zh_Hans: 科西嘉语 + - value: hr + label: + en_US: Croatian + zh_Hans: 克罗地亚语 - value: cs label: en_US: Czech @@ -309,14 +1393,34 @@ parameters: label: en_US: Dutch zh_Hans: 荷兰语 + - value: xx-elmer + label: + en_US: Elmer Fudd + zh_Hans: 艾尔默福德语 - value: en label: en_US: English zh_Hans: 英语 + - value: eo + label: + en_US: Esperanto + zh_Hans: 世界语 - value: et label: en_US: Estonian zh_Hans: 爱沙尼亚语 + - value: ee + label: + en_US: Ewe + zh_Hans: 埃维语 + - value: fo + label: + en_US: Faroese + zh_Hans: 法罗语 + - value: tl + label: + en_US: Filipino + zh_Hans: 菲律宾语 - value: fi label: en_US: Finnish @@ -325,6 +1429,22 @@ parameters: label: en_US: French zh_Hans: 法语 + - value: fy + label: + en_US: Frisian + zh_Hans: 弗里西亚语 + - value: gaa + label: + en_US: Ga + zh_Hans: 加语 + - value: gl + label: + en_US: Galician + zh_Hans: 加利西亚语 + - value: ka + label: + en_US: Georgian + zh_Hans: 格鲁吉亚语 - value: de label: en_US: German @@ -333,6 +1453,34 @@ parameters: label: en_US: Greek zh_Hans: 希腊语 + - value: kl + label: + en_US: Greenlandic + zh_Hans: 格陵兰语 + - value: gn + label: + en_US: Guarani + zh_Hans: 瓜拉尼语 + - value: gu + label: + en_US: Gujarati + zh_Hans: 古吉拉特语 + - value: xx-hacker + label: + en_US: Hacker + zh_Hans: 黑客语 + - value: ht + label: + en_US: Haitian Creole + zh_Hans: 海地克里奥尔语 + - value: ha + label: + en_US: Hausa + zh_Hans: 豪萨语 + - value: haw + label: + en_US: Hawaiian + zh_Hans: 夏威夷语 - value: iw label: en_US: Hebrew @@ -345,10 +1493,26 @@ parameters: label: en_US: Hungarian zh_Hans: 匈牙利语 + - value: is + label: + en_US: Icelandic + zh_Hans: 冰岛语 + - value: ig + label: + en_US: Igbo + 
zh_Hans: 伊博语 - value: id label: en_US: Indonesian zh_Hans: 印尼语 + - value: ia + label: + en_US: Interlingua + zh_Hans: 国际语 + - value: ga + label: + en_US: Irish + zh_Hans: 爱尔兰语 - value: it label: en_US: Italian @@ -357,22 +1521,94 @@ parameters: label: en_US: Japanese zh_Hans: 日语 + - value: jw + label: + en_US: Javanese + zh_Hans: 爪哇语 - value: kn label: en_US: Kannada zh_Hans: 卡纳达语 + - value: kk + label: + en_US: Kazakh + zh_Hans: 哈萨克语 + - value: rw + label: + en_US: Kinyarwanda + zh_Hans: 基尼亚卢旺达语 + - value: rn + label: + en_US: Kirundi + zh_Hans: 基隆迪语 + - value: xx-klingon + label: + en_US: Klingon + zh_Hans: 克林贡语 + - value: kg + label: + en_US: Kongo + zh_Hans: 刚果语 - value: ko label: en_US: Korean zh_Hans: 韩语 + - value: kri + label: + en_US: Krio (Sierra Leone) + zh_Hans: 塞拉利昂克里奥尔语 + - value: ku + label: + en_US: Kurdish + zh_Hans: 库尔德语 + - value: ckb + label: + en_US: Kurdish (Soranî) + zh_Hans: 库尔德语(索拉尼) + - value: ky + label: + en_US: Kyrgyz + zh_Hans: 吉尔吉斯语 + - value: lo + label: + en_US: Laothian + zh_Hans: 老挝语 + - value: la + label: + en_US: Latin + zh_Hans: 拉丁语 - value: lv label: en_US: Latvian zh_Hans: 拉脱维亚语 + - value: ln + label: + en_US: Lingala + zh_Hans: 林加拉语 - value: lt label: en_US: Lithuanian zh_Hans: 立陶宛语 + - value: loz + label: + en_US: Lozi + zh_Hans: 洛齐语 + - value: lg + label: + en_US: Luganda + zh_Hans: 卢干达语 + - value: ach + label: + en_US: Luo + zh_Hans: 卢奥语 + - value: mk + label: + en_US: Macedonian + zh_Hans: 马其顿语 + - value: mg + label: + en_US: Malagasy + zh_Hans: 马尔加什语 - value: my label: en_US: Malay @@ -381,18 +1617,90 @@ parameters: label: en_US: Malayalam zh_Hans: 马拉雅拉姆语 + - value: mt + label: + en_US: Maltese + zh_Hans: 马耳他语 + - value: mv + label: + en_US: Maldives + zh_Hans: 马尔代夫语 + - value: mi + label: + en_US: Maori + zh_Hans: 毛利语 - value: mr label: en_US: Marathi zh_Hans: 马拉地语 + - value: mfe + label: + en_US: Mauritian Creole + zh_Hans: 毛里求斯克里奥尔语 + - value: mo + label: + en_US: Moldavian + zh_Hans: 摩尔达维亚语 + - value: mn + label: + en_US: Mongolian + zh_Hans: 蒙古语 + - value: sr-me + label: + en_US: Montenegrin + zh_Hans: 黑山语 + - value: ne + label: + en_US: Nepali + zh_Hans: 尼泊尔语 + - value: pcm + label: + en_US: Nigerian Pidgin + zh_Hans: 尼日利亚皮钦语 + - value: nso + label: + en_US: Northern Sotho + zh_Hans: 北索托语 - value: "no" label: en_US: Norwegian zh_Hans: 挪威语 + - value: nn + label: + en_US: Norwegian (Nynorsk) + zh_Hans: 挪威语(尼诺斯克语) + - value: oc + label: + en_US: Occitan + zh_Hans: 奥克语 + - value: or + label: + en_US: Oriya + zh_Hans: 奥里亚语 + - value: om + label: + en_US: Oromo + zh_Hans: 奥罗莫语 + - value: ps + label: + en_US: Pashto + zh_Hans: 普什图语 + - value: fa + label: + en_US: Persian + zh_Hans: 波斯语 + - value: xx-pirate + label: + en_US: Pirate + zh_Hans: 海盗语 - value: pl label: en_US: Polish zh_Hans: 波兰语 + - value: pt + label: + en_US: Portuguese + zh_Hans: 葡萄牙语 - value: pt-br label: en_US: Portuguese (Brazil) @@ -405,18 +1713,62 @@ parameters: label: en_US: Punjabi zh_Hans: 旁遮普语 + - value: qu + label: + en_US: Quechua + zh_Hans: 克丘亚语 - value: ro label: en_US: Romanian zh_Hans: 罗马尼亚语 + - value: rm + label: + en_US: Romansh + zh_Hans: 罗曼什语 + - value: nyn + label: + en_US: Runyakitara + zh_Hans: 卢尼亚基塔拉语 - value: ru label: en_US: Russian zh_Hans: 俄语 + - value: gd + label: + en_US: Scots Gaelic + zh_Hans: 苏格兰盖尔语 - value: sr label: en_US: Serbian zh_Hans: 塞尔维亚语 + - value: sh + label: + en_US: Serbo-Croatian + zh_Hans: 塞尔维亚-克罗地亚语 + - value: st + label: + en_US: Sesotho + zh_Hans: 塞索托语 + - value: tn + label: + en_US: Setswana + zh_Hans: 塞茨瓦纳语 + - value: crs + label: 
+ en_US: Seychellois Creole + zh_Hans: 塞舌尔克里奥尔语 + - value: sn + label: + en_US: Shona + zh_Hans: 绍纳语 + - value: sd + label: + en_US: Sindhi + zh_Hans: 信德语 + - value: si + label: + en_US: Sinhalese + zh_Hans: 僧伽罗语 - value: sk label: en_US: Slovak @@ -425,18 +1777,42 @@ parameters: label: en_US: Slovenian zh_Hans: 斯洛文尼亚语 + - value: so + label: + en_US: Somali + zh_Hans: 索马里语 - value: es label: en_US: Spanish zh_Hans: 西班牙语 + - value: es-419 + label: + en_US: Spanish (Latin American) + zh_Hans: 西班牙语(拉丁美洲) + - value: su + label: + en_US: Sundanese + zh_Hans: 巽他语 + - value: sw + label: + en_US: Swahili + zh_Hans: 斯瓦希里语 - value: sv label: en_US: Swedish zh_Hans: 瑞典语 + - value: tg + label: + en_US: Tajik + zh_Hans: 塔吉克语 - value: ta label: en_US: Tamil zh_Hans: 泰米尔语 + - value: tt + label: + en_US: Tatar + zh_Hans: 鞑靼语 - value: te label: en_US: Telugu @@ -445,18 +1821,82 @@ parameters: label: en_US: Thai zh_Hans: 泰语 + - value: ti + label: + en_US: Tigrinya + zh_Hans: 提格利尼亚语 + - value: to + label: + en_US: Tonga + zh_Hans: 汤加语 + - value: lua + label: + en_US: Tshiluba + zh_Hans: 卢巴语 + - value: tum + label: + en_US: Tumbuka + zh_Hans: 图布卡语 - value: tr label: en_US: Turkish zh_Hans: 土耳其语 + - value: tk + label: + en_US: Turkmen + zh_Hans: 土库曼语 + - value: tw + label: + en_US: Twi + zh_Hans: 契维语 + - value: ug + label: + en_US: Uighur + zh_Hans: 维吾尔语 - value: uk label: en_US: Ukrainian zh_Hans: 乌克兰语 + - value: ur + label: + en_US: Urdu + zh_Hans: 乌尔都语 + - value: uz + label: + en_US: Uzbek + zh_Hans: 乌兹别克语 + - value: vu + label: + en_US: Vanuatu + zh_Hans: 瓦努阿图语 - value: vi label: en_US: Vietnamese zh_Hans: 越南语 + - value: cy + label: + en_US: Welsh + zh_Hans: 威尔士语 + - value: wo + label: + en_US: Wolof + zh_Hans: 沃洛夫语 + - value: xh + label: + en_US: Xhosa + zh_Hans: 科萨语 + - value: yi + label: + en_US: Yiddish + zh_Hans: 意第绪语 + - value: yo + label: + en_US: Yoruba + zh_Hans: 约鲁巴语 + - value: zu + label: + en_US: Zulu + zh_Hans: 祖鲁语 - name: google_domain type: string required: false diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 6345b338011e7a..1867cf7be79be5 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -7,6 +7,7 @@ SEARCH_API_URL = "https://www.searchapi.io/api/v1/search" + class SearchAPI: """ SearchAPI tool provider. 
@@ -36,18 +37,18 @@ def get_params(self, video_id: str, language: str, **kwargs: Any) -> dict[str, s return { "engine": "youtube_transcripts", "video_id": video_id, - "lang": language if language else "en", - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + "lang": language or "en", + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod def _process_response(res: dict) -> str: """Process response from SearchAPI.""" - if "error" in res.keys(): + if "error" in res: raise ValueError(f"Got error from SearchApi: {res['error']}") toret = "" - if "transcripts" in res.keys() and "text" in res["transcripts"][0].keys(): + if "transcripts" in res and "text" in res["transcripts"][0]: for item in res["transcripts"]: toret += item["text"] + " " if toret == "": @@ -55,18 +56,20 @@ def _process_response(res: dict) -> str: return toret + class YoutubeTranscriptsTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the SearchApi tool. """ - video_id = tool_parameters['video_id'] - language = tool_parameters.get('language', "en") + video_id = tool_parameters["video_id"] + language = tool_parameters.get("language", "en") - api_key = self.runtime.credentials['searchapi_api_key'] + api_key = self.runtime.credentials["searchapi_api_key"] result = SearchAPI(api_key).run(video_id, language=language) return self.create_text_message(text=result) diff --git a/api/core/tools/provider/builtin/searxng/searxng.py b/api/core/tools/provider/builtin/searxng/searxng.py index ab354003e6f567..b7bbcc60b1ed26 100644 --- a/api/core/tools/provider/builtin/searxng/searxng.py +++ b/api/core/tools/provider/builtin/searxng/searxng.py @@ -13,12 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "SearXNG", - "limit": 1, - "search_type": "general" - }, + user_id="", + tool_parameters={"query": "SearXNG", "limit": 1, "search_type": "general"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py index dc835a8e8cbd5b..c5e339a108e5b2 100644 --- a/api/core/tools/provider/builtin/searxng/tools/searxng_search.py +++ b/api/core/tools/provider/builtin/searxng/tools/searxng_search.py @@ -23,18 +23,21 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation. 
""" - host = self.runtime.credentials.get('searxng_base_url') + host = self.runtime.credentials.get("searxng_base_url") if not host: - raise Exception('SearXNG api is required') + raise Exception("SearXNG api is required") - response = requests.get(host, params={ - "q": tool_parameters.get('query'), - "format": "json", - "categories": tool_parameters.get('search_type', 'general') - }) + response = requests.get( + host, + params={ + "q": tool_parameters.get("query"), + "format": "json", + "categories": tool_parameters.get("search_type", "general"), + }, + ) if response.status_code != 200: - raise Exception(f'Error {response.status_code}: {response.text}') + raise Exception(f"Error {response.status_code}: {response.text}") res = response.json().get("results", []) if not res: diff --git a/api/core/tools/provider/builtin/serper/serper.py b/api/core/tools/provider/builtin/serper/serper.py index 2a421093731477..cb1d090a9dd4b0 100644 --- a/api/core/tools/provider/builtin/serper/serper.py +++ b/api/core/tools/provider/builtin/serper/serper.py @@ -13,11 +13,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "query": "test", - "result_type": "link" - }, + user_id="", + tool_parameters={"query": "test", "result_type": "link"}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/serper/tools/serper_search.py b/api/core/tools/provider/builtin/serper/tools/serper_search.py index 24facaf4ec3ae9..7baebbf95855e0 100644 --- a/api/core/tools/provider/builtin/serper/tools/serper_search.py +++ b/api/core/tools/provider/builtin/serper/tools/serper_search.py @@ -9,7 +9,6 @@ class SerperSearchTool(BuiltinTool): - def _parse_response(self, response: dict) -> dict: result = {} if "knowledgeGraph" in response: @@ -17,28 +16,19 @@ def _parse_response(self, response: dict) -> dict: result["description"] = response["knowledgeGraph"].get("description", "") if "organic" in response: result["organic"] = [ - { - "title": item.get("title", ""), - "link": item.get("link", ""), - "snippet": item.get("snippet", "") - } + {"title": item.get("title", ""), "link": item.get("link", ""), "snippet": item.get("snippet", "")} for item in response["organic"] ] return result - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - params = { - "q": tool_parameters['query'], - "gl": "us", - "hl": "en" - } - headers = { - 'X-API-KEY': self.runtime.credentials['serperapi_api_key'], - 'Content-Type': 'application/json' - } - response = requests.get(url=SERPER_API_URL, params=params,headers=headers) + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + params = {"q": tool_parameters["query"], "gl": "us", "hl": "en"} + headers = {"X-API-KEY": self.runtime.credentials["serperapi_api_key"], "Content-Type": "application/json"} + response = requests.get(url=SERPER_API_URL, params=params, headers=headers) response.raise_for_status() valuable_res = self._parse_response(response.json()) return self.create_json_message(valuable_res) diff --git a/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg b/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg new file mode 100644 index 00000000000000..ad6b384f7acd21 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/_assets/icon.svg @@ -0,0 +1 @@ + \ No 
newline at end of file diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.py b/api/core/tools/provider/builtin/siliconflow/siliconflow.py new file mode 100644 index 00000000000000..37a0b0755b1d39 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.py @@ -0,0 +1,17 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SiliconflowProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + url = "https://api.siliconflow.cn/v1/models" + headers = { + "accept": "application/json", + "authorization": f"Bearer {credentials.get('siliconFlow_api_key')}", + } + + response = requests.get(url, headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError("SiliconFlow API key is invalid") diff --git a/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml b/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml new file mode 100644 index 00000000000000..46be99f262f211 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/siliconflow.yaml @@ -0,0 +1,21 @@ +identity: + author: hjlarry + name: siliconflow + label: + en_US: SiliconFlow + zh_CN: 硅基流动 + description: + en_US: The image generation API provided by SiliconFlow includes Flux and Stable Diffusion models. + zh_CN: 硅基流动提供的图片生成 API,包含 Flux 和 Stable Diffusion 模型。 + icon: icon.svg + tags: + - image +credentials_for_provider: + siliconFlow_api_key: + type: secret-input + required: true + label: + en_US: SiliconFlow API Key + placeholder: + en_US: Please input your SiliconFlow API key + url: https://cloud.siliconflow.cn/account/ak diff --git a/api/core/tools/provider/builtin/siliconflow/tools/flux.py b/api/core/tools/provider/builtin/siliconflow/tools/flux.py new file mode 100644 index 00000000000000..0d16ff385eb30d --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.py @@ -0,0 +1,43 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +FLUX_URL = { + "schnell": "https://api.siliconflow.cn/v1/black-forest-labs/FLUX.1-schnell/text-to-image", + "dev": "https://api.siliconflow.cn/v1/image/generations", +} + + +class FluxTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}", + } + + payload = { + "prompt": tool_parameters.get("prompt"), + "image_size": tool_parameters.get("image_size", "1024x1024"), + "seed": tool_parameters.get("seed"), + "num_inference_steps": tool_parameters.get("num_inference_steps", 20), + } + model = tool_parameters.get("model", "schnell") + url = FLUX_URL.get(model) + if model == "dev": + payload["model"] = "black-forest-labs/FLUX.1-dev" + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + res = response.json() + result = [self.create_json_message(res)] + for image in res.get("images", []): + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) + return result diff --git 
a/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml new file mode 100644 index 00000000000000..d06b9bf3e1f489 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/flux.yaml @@ -0,0 +1,88 @@ +identity: + name: flux + author: hjlarry + label: + en_US: Flux + icon: icon.svg +description: + human: + en_US: Generate image via SiliconFlow's flux model. + llm: This tool is used to generate image from prompt via SiliconFlow's flux model. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The text prompt used to generate the image. + zh_Hans: 建议用英文的生成图片提示词以获得更好的生成效果。 + llm_description: this prompt text will be used to generate image. + form: llm + - name: model + type: select + required: true + options: + - value: schnell + label: + en_US: Flux.1-schnell + - value: dev + label: + en_US: Flux.1-dev + default: schnell + label: + en_US: Choose Image Model + zh_Hans: 选择生成图片的模型 + form: form + - name: image_size + type: select + required: true + options: + - value: 1024x1024 + label: + en_US: 1024x1024 + - value: 768x1024 + label: + en_US: 768x1024 + - value: 576x1024 + label: + en_US: 576x1024 + - value: 512x1024 + label: + en_US: 512x1024 + - value: 1024x576 + label: + en_US: 1024x576 + - value: 768x512 + label: + en_US: 768x512 + default: 1024x1024 + label: + en_US: Choose Image Size + zh_Hans: 选择生成的图片大小 + form: form + - name: num_inference_steps + type: number + required: true + default: 20 + min: 1 + max: 100 + label: + en_US: Num Inference Steps + zh_Hans: 生成图片的步数 + form: form + human_description: + en_US: The number of inference steps to perform. More steps produce higher quality but take longer. + zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 + - name: seed + type: number + min: 0 + max: 9999999999 + label: + en_US: Seed + zh_Hans: 种子 + human_description: + en_US: The same seed and prompt can produce similar images. 
+ zh_Hans: 相同的种子和提示可以产生相似的图像。 + form: form diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py new file mode 100644 index 00000000000000..db43790c06aaa6 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.py @@ -0,0 +1,49 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/image/generations" + +SD_MODELS = { + "sd_3": "stabilityai/stable-diffusion-3-medium", + "sd_xl": "stabilityai/stable-diffusion-xl-base-1.0", + "sd_3.5_large": "stabilityai/stable-diffusion-3-5-large", +} + + +class StableDiffusionTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "accept": "application/json", + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['siliconFlow_api_key']}", + } + + model = tool_parameters.get("model", "sd_3") + sd_model = SD_MODELS.get(model) + + payload = { + "model": sd_model, + "prompt": tool_parameters.get("prompt"), + "negative_prompt": tool_parameters.get("negative_prompt", ""), + "image_size": tool_parameters.get("image_size", "1024x1024"), + "batch_size": tool_parameters.get("batch_size", 1), + "seed": tool_parameters.get("seed"), + "guidance_scale": tool_parameters.get("guidance_scale", 7.5), + "num_inference_steps": tool_parameters.get("num_inference_steps", 20), + } + + response = requests.post(SILICONFLOW_API_URL, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + res = response.json() + result = [self.create_json_message(res)] + for image in res.get("images", []): + result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value)) + return result diff --git a/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml new file mode 100644 index 00000000000000..b330c92e163a38 --- /dev/null +++ b/api/core/tools/provider/builtin/siliconflow/tools/stable_diffusion.yaml @@ -0,0 +1,124 @@ +identity: + name: stable_diffusion + author: hjlarry + label: + en_US: Stable Diffusion + icon: icon.svg +description: + human: + en_US: Generate image via SiliconFlow's stable diffusion model. + llm: This tool is used to generate image from prompt via SiliconFlow's stable diffusion model. +parameters: + - name: prompt + type: string + required: true + label: + en_US: prompt + zh_Hans: 提示词 + human_description: + en_US: The text prompt used to generate the image. + zh_Hans: 用于生成图片的文字提示词 + llm_description: this prompt text will be used to generate image. + form: llm + - name: negative_prompt + type: string + label: + en_US: negative prompt + zh_Hans: 负面提示词 + human_description: + en_US: Describe what you don't want included in the image. + zh_Hans: 描述您不希望包含在图片中的内容。 + llm_description: Describe what you don't want included in the image. 
+ form: llm + - name: model + type: select + required: true + options: + - value: sd_3 + label: + en_US: Stable Diffusion 3 + - value: sd_xl + label: + en_US: Stable Diffusion XL + - value: sd_3.5_large + label: + en_US: Stable Diffusion 3.5 Large + default: sd_3 + label: + en_US: Choose Image Model + zh_Hans: 选择生成图片的模型 + form: form + - name: image_size + type: select + required: true + options: + - value: 1024x1024 + label: + en_US: 1024x1024 + - value: 1024x2048 + label: + en_US: 1024x2048 + - value: 1152x2048 + label: + en_US: 1152x2048 + - value: 1536x1024 + label: + en_US: 1536x1024 + - value: 1536x2048 + label: + en_US: 1536x2048 + - value: 2048x1152 + label: + en_US: 2048x1152 + default: 1024x1024 + label: + en_US: Choose Image Size + zh_Hans: 选择生成图片的大小 + form: form + - name: batch_size + type: number + required: true + default: 1 + min: 1 + max: 4 + label: + en_US: Number Images + zh_Hans: 生成图片的数量 + form: form + - name: guidance_scale + type: number + required: true + default: 7.5 + min: 0 + max: 100 + label: + en_US: Guidance Scale + zh_Hans: 与提示词紧密性 + human_description: + en_US: Classifier Free Guidance. How close you want the model to stick to your prompt when looking for a related image to show you. + zh_Hans: 无分类器引导。您希望模型在寻找相关图片向您展示时,与您的提示保持多紧密的关联度。 + form: form + - name: num_inference_steps + type: number + required: true + default: 20 + min: 1 + max: 100 + label: + en_US: Num Inference Steps + zh_Hans: 生成图片的步数 + human_description: + en_US: The number of inference steps to perform. More steps produce higher quality but take longer. + zh_Hans: 执行的推理步骤数量。更多的步骤可以产生更高质量的结果,但需要更长的时间。 + form: form + - name: seed + type: number + min: 0 + max: 9999999999 + label: + en_US: Seed + zh_Hans: 种子 + human_description: + en_US: The same seed and prompt can produce similar images. 
+ zh_Hans: 相同的种子和提示可以产生相似的图像。 + form: form diff --git a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py index f47557f2ef5852..85e0de76755898 100644 --- a/api/core/tools/provider/builtin/slack/tools/slack_webhook.py +++ b/api/core/tools/provider/builtin/slack/tools/slack_webhook.py @@ -7,25 +7,27 @@ class SlackWebhookTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Incoming Webhooks - API Document: https://api.slack.com/messaging/webhooks + Incoming Webhooks + API Document: https://api.slack.com/messaging/webhooks """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - webhook_url = tool_parameters.get('webhook_url', '') + webhook_url = tool_parameters.get("webhook_url", "") - if not webhook_url.startswith('https://hooks.slack.com/'): + if not webhook_url.startswith("https://hooks.slack.com/"): return self.create_text_message( - f'Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL') + f"Invalid parameter webhook_url ${webhook_url}, not a valid Slack webhook URL" + ) headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = {} payload = { @@ -38,6 +40,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message("Text message was sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: - return self.create_text_message("Failed to send message through webhook. {}".format(e)) \ No newline at end of file + return self.create_text_message("Failed to send message through webhook. {}".format(e)) diff --git a/api/core/tools/provider/builtin/spark/spark.py b/api/core/tools/provider/builtin/spark/spark.py index cb8e69a59f8e3e..e0b1a58a3f679a 100644 --- a/api/core/tools/provider/builtin/spark/spark.py +++ b/api/core/tools/provider/builtin/spark/spark.py @@ -29,12 +29,8 @@ def _validate_credentials(self, credentials: dict) -> None: # 0 success, pass else: - raise ToolProviderCredentialValidationError( - "image generate error, code:{}".format(code) - ) + raise ToolProviderCredentialValidationError("image generate error, code:{}".format(code)) except Exception as e: - raise ToolProviderCredentialValidationError( - "APPID APISecret APIKey is invalid. {}".format(e) - ) + raise ToolProviderCredentialValidationError("APPID APISecret APIKey is invalid. 
{}".format(e)) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py index a977af2b765067..81d9e8d94185f7 100644 --- a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -15,16 +15,16 @@ from core.tools.tool.builtin_tool import BuiltinTool -class AssembleHeaderException(Exception): +class AssembleHeaderError(Exception): def __init__(self, msg): self.message = msg class Url: - def __init__(this, host, path, schema): - this.host = host - this.path = path - this.schema = schema + def __init__(self, host, path, schema): + self.host = host + self.path = path + self.schema = schema # calculate sha256 and encode to base64 @@ -35,49 +35,46 @@ def sha256base64(data): return digest -def parse_url(requset_url): - stidx = requset_url.index("://") - host = requset_url[stidx + 3 :] - schema = requset_url[: stidx + 3] +def parse_url(request_url): + stidx = request_url.index("://") + host = request_url[stidx + 3 :] + schema = request_url[: stidx + 3] edidx = host.index("/") if edidx <= 0: - raise AssembleHeaderException("invalid request url:" + requset_url) + raise AssembleHeaderError("invalid request url:" + request_url) path = host[edidx:] host = host[:edidx] u = Url(host, path, schema) return u -def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""): - u = parse_url(requset_url) + +def assemble_ws_auth_url(request_url, method="GET", api_key="", api_secret=""): + u = parse_url(request_url) host = u.host path = u.path now = datetime.now() date = format_date_time(mktime(now.timetuple())) - signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format( - host, date, method, path - ) + signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(host, date, method, path) signature_sha = hmac.new( api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256, ).digest() signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8") - authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' - - authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( - encoding="utf-8" + authorization_origin = ( + f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"' ) + + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8") values = {"host": host, "date": date, "authorization": authorization} - return requset_url + "?" + urlencode(values) + return request_url + "?" 
+ urlencode(values) def get_body(appid, text): body = { "header": {"app_id": appid, "uid": "123456789"}, - "parameter": { - "chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096} - }, + "parameter": {"chat": {"domain": "general", "temperature": 0.5, "max_tokens": 4096}}, "payload": {"message": {"text": [{"role": "user", "content": text}]}}, } return body @@ -85,13 +82,9 @@ def get_body(appid, text): def spark_response(text, appid, apikey, apisecret): host = "http://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti" - url = assemble_ws_auth_url( - host, method="POST", api_key=apikey, api_secret=apisecret - ) + url = assemble_ws_auth_url(host, method="POST", api_key=apikey, api_secret=apisecret) content = get_body(appid, text) - response = requests.post( - url, json=content, headers={"content-type": "application/json"} - ).text + response = requests.post(url, json=content, headers={"content-type": "application/json"}).text return response @@ -105,19 +98,11 @@ def _invoke( invoke tools """ - if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get( - "APPID" - ): + if "APPID" not in self.runtime.credentials or not self.runtime.credentials.get("APPID"): return self.create_text_message("APPID is required.") - if ( - "APISecret" not in self.runtime.credentials - or not self.runtime.credentials.get("APISecret") - ): + if "APISecret" not in self.runtime.credentials or not self.runtime.credentials.get("APISecret"): return self.create_text_message("APISecret is required.") - if ( - "APIKey" not in self.runtime.credentials - or not self.runtime.credentials.get("APIKey") - ): + if "APIKey" not in self.runtime.credentials or not self.runtime.credentials.get("APIKey"): return self.create_text_message("APIKey is required.") prompt = tool_parameters.get("prompt", "") @@ -130,7 +115,7 @@ def _invoke( self.create_blob_message( blob=b64decode(image["base64_image"]), meta={"mime_type": "image/png"}, - save_as=self.VARIABLE_KEY.IMAGE.value, + save_as=self.VariableKey.IMAGE.value, ) ) return result diff --git a/api/core/tools/provider/builtin/spider/spider.py b/api/core/tools/provider/builtin/spider/spider.py index 5bcc56a7248c1d..5959555318722e 100644 --- a/api/core/tools/provider/builtin/spider/spider.py +++ b/api/core/tools/provider/builtin/spider/spider.py @@ -8,13 +8,13 @@ class SpiderProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - app = Spider(api_key=credentials['spider_api_key']) - app.scrape_url(url='https://spider.cloud') + app = Spider(api_key=credentials["spider_api_key"]) + app.scrape_url(url="https://spider.cloud") except AttributeError as e: # Handle cases where NoneType is not iterable, which might indicate API issues - if 'NoneType' in str(e) and 'not iterable' in str(e): - raise ToolProviderCredentialValidationError('API is currently down, try again in 15 minutes', str(e)) + if "NoneType" in str(e) and "not iterable" in str(e): + raise ToolProviderCredentialValidationError("API is currently down, try again in 15 minutes", str(e)) else: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) except Exception as e: - raise ToolProviderCredentialValidationError('An unexpected error occurred.', str(e)) + raise ToolProviderCredentialValidationError("An unexpected error occurred.", str(e)) diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py 
b/api/core/tools/provider/builtin/spider/spiderApp.py index f0ed64867a18a1..4bc446a1a092a3 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -65,9 +65,7 @@ def api_post( :return: The JSON response or the raw response stream if stream is True. """ headers = self._prepare_headers(content_type) - response = self._post_request( - f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream - ) + response = self._post_request(f"https://api.spider.cloud/v1/{endpoint}", data, headers, stream) if stream: return response @@ -76,9 +74,7 @@ def api_post( else: self._handle_error(response, f"post to {endpoint}") - def api_get( - self, endpoint: str, stream: bool, content_type: str = "application/json" - ): + def api_get(self, endpoint: str, stream: bool, content_type: str = "application/json"): """ Send a GET request to the specified endpoint. @@ -86,9 +82,7 @@ def api_get( :return: The JSON decoded response. """ headers = self._prepare_headers(content_type) - response = self._get_request( - f"https://api.spider.cloud/v1/{endpoint}", headers, stream - ) + response = self._get_request(f"https://api.spider.cloud/v1/{endpoint}", headers, stream) if response.status_code == 200: return response.json() else: @@ -120,14 +114,12 @@ def scrape_url( # Add { "return_format": "markdown" } to the params if not already present if "return_format" not in params: - params["return_format"] = "markdown" + params["return_format"] = "markdown" # Set limit to 1 params["limit"] = 1 - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def crawl_url( self, @@ -150,9 +142,7 @@ def crawl_url( if "return_format" not in params: params["return_format"] = "markdown" - return self.api_post( - "crawl", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("crawl", {"url": url, **(params or {})}, stream, content_type) def links( self, @@ -168,9 +158,7 @@ def links( :param params: Optional parameters for the link retrieval request. :return: JSON response containing the links. """ - return self.api_post( - "links", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("links", {"url": url, **(params or {})}, stream, content_type) def extract_contacts( self, @@ -207,9 +195,7 @@ def label( :param params: Optional parameters to guide the labeling process. :return: JSON response with labeled data. """ - return self.api_post( - "pipeline/label", {"url": url, **(params or {})}, stream, content_type - ) + return self.api_post("pipeline/label", {"url": url, **(params or {})}, stream, content_type) def _prepare_headers(self, content_type: str = "application/json"): return { @@ -228,12 +214,8 @@ def _delete_request(self, url: str, headers, stream=False): return requests.delete(url, headers=headers, stream=stream) def _handle_error(self, response, action): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") - raise Exception( - f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}" - ) + raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") else: - raise Exception( - f"Unexpected error occurred while trying to {action}. 
Status code: {response.status_code}" - ) + raise Exception(f"Unexpected error occurred while trying to {action}. Status code: {response.status_code}") diff --git a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py index 64bbcc10ccc13e..20d2daef550de1 100644 --- a/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py +++ b/api/core/tools/provider/builtin/spider/tools/scraper_crawler.py @@ -6,42 +6,44 @@ class ScrapeTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: # initialize the app object with the api key - app = Spider(api_key=self.runtime.credentials['spider_api_key']) + app = Spider(api_key=self.runtime.credentials["spider_api_key"]) + + url = tool_parameters["url"] + mode = tool_parameters["mode"] - url = tool_parameters['url'] - mode = tool_parameters['mode'] - options = { - 'limit': tool_parameters.get('limit', 0), - 'depth': tool_parameters.get('depth', 0), - 'blacklist': tool_parameters.get('blacklist', '').split(',') if tool_parameters.get('blacklist') else [], - 'whitelist': tool_parameters.get('whitelist', '').split(',') if tool_parameters.get('whitelist') else [], - 'readability': tool_parameters.get('readability', False), + "limit": tool_parameters.get("limit", 0), + "depth": tool_parameters.get("depth", 0), + "blacklist": tool_parameters.get("blacklist", "").split(",") if tool_parameters.get("blacklist") else [], + "whitelist": tool_parameters.get("whitelist", "").split(",") if tool_parameters.get("whitelist") else [], + "readability": tool_parameters.get("readability", False), } result = "" try: - if mode == 'scrape': + if mode == "scrape": scrape_result = app.scrape_url( - url=url, + url=url, params=options, ) for i in scrape_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" - elif mode == 'crawl': + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" + elif mode == "crawl": crawl_result = app.crawl_url( - url=tool_parameters['url'], + url=tool_parameters["url"], params=options, ) for i in crawl_result: - result += "URL: " + i.get('url', '') + "\n" - result += "CONTENT: " + i.get('content', '') + "\n\n" + result += "URL: " + i.get("url", "") + "\n" + result += "CONTENT: " + i.get("content", "") + "\n\n" except Exception as e: - return self.create_text_message("An error occured", str(e)) + return self.create_text_message("An error occurred", str(e)) return self.create_text_message(result) diff --git a/api/core/tools/provider/builtin/stability/stability.py b/api/core/tools/provider/builtin/stability/stability.py index b31d786178dd63..f09d81ac270288 100644 --- a/api/core/tools/provider/builtin/stability/stability.py +++ b/api/core/tools/provider/builtin/stability/stability.py @@ -8,6 +8,7 @@ class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthoriz """ This class is responsible for providing the stability tool. """ + def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ This method is responsible for validating the credentials. 
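For context on the Spider client cleanup above, the call pattern that scraper_crawler.py relies on is roughly the following. This is a minimal usage sketch, not part of the patch: the Spider constructor and the scrape_url signature are taken from the diff, the import path is assumed from the file layout, and the API key and parameter values are placeholders.

from core.tools.provider.builtin.spider.spiderApp import Spider  # assumed in-tree import path

app = Spider(api_key="sk-placeholder")  # hypothetical key

# Per the diff, scrape_url defaults return_format to "markdown" and forces
# limit=1, and the result is iterated as dicts with "url"/"content" keys.
pages = app.scrape_url(url="https://spider.cloud", params={"readability": False})
for page in pages:
    print(page.get("url", ""), len(page.get("content", "")))

Behaviorally, the only semantic-looking change in spiderApp.py is _handle_error testing response.status_code in {402, 409, 500}: membership in a set literal is equivalent to the old list check, just the idiomatic (and O(1)) form.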
diff --git a/api/core/tools/provider/builtin/stability/tools/base.py b/api/core/tools/provider/builtin/stability/tools/base.py index a4788fd869ce1b..c3b7edbefa2447 100644 --- a/api/core/tools/provider/builtin/stability/tools/base.py +++ b/api/core/tools/provider/builtin/stability/tools/base.py @@ -9,26 +9,23 @@ def sd_validate_credentials(self, credentials: dict): """ This method is responsible for validating the credentials. """ - api_key = credentials.get('api_key', '') + api_key = credentials.get("api_key", "") if not api_key: - raise ToolProviderCredentialValidationError('API key is required.') - + raise ToolProviderCredentialValidationError("API key is required.") + response = requests.get( - URL('https://api.stability.ai') / 'v1' / 'user' / 'account', + URL("https://api.stability.ai") / "v1" / "user" / "account", headers=self.generate_authorization_headers(credentials), - timeout=(5, 30) + timeout=(5, 30), ) if not response.ok: - raise ToolProviderCredentialValidationError('Invalid API key.') + raise ToolProviderCredentialValidationError("Invalid API key.") return True - + def generate_authorization_headers(self, credentials: dict) -> dict[str, str]: """ This method is responsible for generating the authorization headers. """ - return { - 'Authorization': f'Bearer {credentials.get("api_key", "")}' - } - \ No newline at end of file + return {"Authorization": f'Bearer {credentials.get("api_key", "")}'} diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 41236f7b433cef..6bcf315484ad50 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -11,10 +11,11 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization): """ This class is responsible for providing the stable diffusion tool. """ + model_endpoint_map: dict[str, str] = { - 'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3', - 'core': 'https://api.stability.ai/v2beta/stable-image/generate/core', + "sd3": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "sd3-turbo": "https://api.stability.ai/v2beta/stable-image/generate/sd3", + "core": "https://api.stability.ai/v2beta/stable-image/generate/core", } def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: @@ -22,39 +23,34 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe Invoke the tool. 
""" payload = { - 'prompt': tool_parameters.get('prompt', ''), - 'aspect_ratio': tool_parameters.get('aspect_ratio', '16:9') or tool_parameters.get('aspect_radio', '16:9'), - 'mode': 'text-to-image', - 'seed': tool_parameters.get('seed', 0), - 'output_format': 'png', + "prompt": tool_parameters.get("prompt", ""), + "aspect_ratio": tool_parameters.get("aspect_ratio", "16:9") or tool_parameters.get("aspect_radio", "16:9"), + "mode": "text-to-image", + "seed": tool_parameters.get("seed", 0), + "output_format": "png", } - model = tool_parameters.get('model', 'core') + model = tool_parameters.get("model", "core") - if model in ['sd3', 'sd3-turbo']: - payload['model'] = tool_parameters.get('model') + if model in {"sd3", "sd3-turbo"}: + payload["model"] = tool_parameters.get("model") - if not model == 'sd3-turbo': - payload['negative_prompt'] = tool_parameters.get('negative_prompt', '') + if model != "sd3-turbo": + payload["negative_prompt"] = tool_parameters.get("negative_prompt", "") response = post( - self.model_endpoint_map[tool_parameters.get('model', 'core')], + self.model_endpoint_map[tool_parameters.get("model", "core")], headers={ - 'accept': 'image/*', + "accept": "image/*", **self.generate_authorization_headers(self.runtime.credentials), }, - files={ - key: (None, str(value)) for key, value in payload.items() - }, - timeout=(5, 30) + files={key: (None, str(value)) for key, value in payload.items()}, + timeout=(5, 30), ) if not response.status_code == 200: raise Exception(response.text) - + return self.create_blob_message( - blob=response.content, meta={ - 'mime_type': 'image/png' - }, - save_as=self.VARIABLE_KEY.IMAGE.value + blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value ) diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 317d705f7c2c7c..abaa297cf36eb1 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -15,4 +15,3 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: ).validate_models() except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 0c5ebc23ac5c51..64fdc961b4c5db 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -18,19 +18,17 @@ # Prompts "prompt": "", "negative_prompt": "", - # "styles": [], - # Seeds + # "styles": [], + # Seeds "seed": -1, "subseed": -1, "subseed_strength": 0, "seed_resize_from_h": -1, "seed_resize_from_w": -1, - # Samplers - # "sampler_name": "DPM++ 2M", + "sampler_name": "DPM++ 2M", # "scheduler": "", # "sampler_index": "Automatic", - # Latent Space Options "batch_size": 1, "n_iter": 1, @@ -42,9 +40,9 @@ # "tiling": True, "do_not_save_samples": False, "do_not_save_grid": False, - # "eta": 0, - # "denoising_strength": 0.75, - # "s_min_uncond": 0, + # "eta": 0, + # "denoising_strength": 0.75, + # "s_min_uncond": 0, # "s_churn": 0, # "s_tmax": 0, # "s_tmin": 0, @@ -73,7 +71,6 @@ "hr_negative_prompt": "", # Task Options # "force_task_id": "", - # Script Options # "script_name": "", "script_args": [], @@ -82,135 +79,150 @@ "save_images": False, 
"alwayson_scripts": {}, # "infotext": "", - } class StableDiffusionTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # base url - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - return self.create_text_message('Please input base_url') + return self.create_text_message("Please input base_url") - if tool_parameters.get('model'): - self.runtime.credentials['model'] = tool_parameters['model'] + if tool_parameters.get("model"): + self.runtime.credentials["model"] = tool_parameters["model"] - model = self.runtime.credentials.get('model', None) + model = self.runtime.credentials.get("model", None) if not model: - return self.create_text_message('Please input model') - + return self.create_text_message("Please input model") + # set model try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'options') - response = post(url, data=json.dumps({ - 'sd_model_checkpoint': model - })) + url = str(URL(base_url) / "sdapi" / "v1" / "options") + response = post(url, data=json.dumps({"sd_model_checkpoint": model})) if response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") except Exception as e: - raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") # get image id and image variable - image_id = tool_parameters.get('image_id', '') + image_id = tool_parameters.get("image_id", "") image_variable = self.get_default_image_variable() # Return text2img if there's no image ID or no image variable if not image_id or not image_variable: - return self.text2img(base_url=base_url,tool_parameters=tool_parameters) + return self.text2img(base_url=base_url, tool_parameters=tool_parameters) # Proceed with image-to-image generation - return self.img2img(base_url=base_url,tool_parameters=tool_parameters) + return self.img2img(base_url=base_url, tool_parameters=tool_parameters) def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - validate models + validate models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: - raise ToolProviderCredentialValidationError('Please input base_url') - model = self.runtime.credentials.get('model', None) + raise ToolProviderCredentialValidationError("Please input base_url") + model = self.runtime.credentials.get("model", None) if not model: - raise ToolProviderCredentialValidationError('Please input model') + raise ToolProviderCredentialValidationError("Please input model") - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=10) if response.status_code == 404: # try draw a picture self._invoke( - user_id='test', + user_id="test", tool_parameters={ - 'prompt': 'a cat', - 'width': 1024, - 'height': 1024, - 'steps': 1, - 'lora': '', - } + "prompt": "a cat", + "width": 1024, + "height": 1024, + "steps": 1, + "lora": "", 
+ }, ) elif response.status_code != 200: - raise ToolProviderCredentialValidationError('Failed to get models') + raise ToolProviderCredentialValidationError("Failed to get models") else: - models = [d['model_name'] for d in response.json()] + models = [d["model_name"] for d in response.json()] if len([d for d in models if d == model]) > 0: return self.create_text_message(json.dumps(models)) else: - raise ToolProviderCredentialValidationError(f'model {model} does not exist') + raise ToolProviderCredentialValidationError(f"model {model} does not exist") except Exception as e: - raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') + raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") def get_sd_models(self) -> list[str]: """ - get sd models + get sd models """ try: - base_url = self.runtime.credentials.get('base_url', None) + base_url = self.runtime.credentials.get("base_url", None) if not base_url: return [] - api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") response = get(url=api_url, timeout=(2, 10)) if response.status_code != 200: return [] else: - return [d['model_name'] for d in response.json()] + return [d["model_name"] for d in response.json()] except Exception as e: return [] - def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def get_sample_methods(self) -> list[str]: """ - generate image + get sample method + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "sdapi" / "v1" / "samplers") + response = get(url=api_url, timeout=(2, 10)) + if response.status_code != 200: + return [] + else: + return [d["name"] for d in response.json()] + except Exception as e: + return [] + + def img2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + generate image """ # Fetch the binary data of the image image_variable = self.get_default_image_variable() image_binary = self.get_variable_file(image_variable.name) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") # Convert image to RGB and save as PNG try: - with Image.open(io.BytesIO(image_binary)) as image: - with io.BytesIO() as buffer: - image.convert("RGB").save(buffer, format="PNG") - image_binary = buffer.getvalue() + with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer: + image.convert("RGB").save(buffer, format="PNG") + image_binary = buffer.getvalue() except Exception as e: return self.create_text_message(f"Failed to process the image: {str(e)}") # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) # set image options - model = tool_parameters.get('model', '') + model = tool_parameters.get("model", "") draw_options_image = { - "init_images": [b64encode(image_binary).decode('utf-8')], + "init_images": [b64encode(image_binary).decode("utf-8")], "denoising_strength": 0.9, "restore_faces": False, "script_args": [], "override_settings": {"sd_model_checkpoint": model}, - "resize_mode":0, + "resize_mode": 0, "image_cfg_scale": 0, # "mask": None, "mask_blur_x": 4, @@ -230,116 +242,149 @@ def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \ draw_options.update(tool_parameters) # get prompt 
lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt + draw_options["prompt"] = prompt try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img') + url = str(URL(base_url) / "sdapi" / "v1" / "img2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] - - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) + except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") - def text2img(self, base_url: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def text2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - generate image + generate image """ # copy draw options draw_options = deepcopy(DRAW_TEXT_OPTIONS) draw_options.update(tool_parameters) # get prompt lora model - prompt = tool_parameters.get('prompt', '') - lora = tool_parameters.get('lora', '') - model = tool_parameters.get('model', '') + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") if lora: - draw_options['prompt'] = f'{lora},{prompt}' + draw_options["prompt"] = f"{lora},{prompt}" else: - draw_options['prompt'] = prompt - draw_options['override_settings']['sd_model_checkpoint'] = model + draw_options["prompt"] = prompt + draw_options["override_settings"]["sd_model_checkpoint"] = model - try: - url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img') + url = str(URL(base_url) / "sdapi" / "v1" / "txt2img") response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: - return self.create_text_message('Failed to generate image') - - image = response.json()['images'][0] - - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, - save_as=self.VARIABLE_KEY.IMAGE.value) - + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) + except Exception as e: - return self.create_text_message('Failed to generate image') + return self.create_text_message("Failed to generate image") def get_runtime_parameters(self) -> list[ToolParameter]: parameters = [ - ToolParameter(name='prompt', - label=I18nObject(en_US='Prompt', zh_Hans='Prompt'), - human_description=I18nObject( - en_US='Image prompt, you can check the official documentation of Stable Diffusion', - zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - 
llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.', - required=True), + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), + human_description=I18nObject( + en_US="Image prompt, you can check the official documentation of Stable Diffusion", + zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image prompt of Stable Diffusion, you should describe the image you want to generate" + " as a list of words as possible as detailed, the prompt must be written in English.", + required=True, + ), ] if len(self.list_default_image_variables()) != 0: parameters.append( - ToolParameter(name='image_id', - label=I18nObject(en_US='image_id', zh_Hans='image_id'), - human_description=I18nObject( - en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.', - zh_Hans='您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。', - ), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.', - required=True, - options=[ToolParameterOption( - value=i.name, - label=I18nObject(en_US=i.name, zh_Hans=i.name) - ) for i in self.list_default_image_variables()]) + ToolParameter( + name="image_id", + label=I18nObject(en_US="image_id", zh_Hans="image_id"), + human_description=I18nObject( + en_US="Image id of the image you want to generate based on, if you want to generate image based" + " on the default image, you can leave this field empty.", + zh_Hans="您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image id of the original image, you can leave this field empty if you want to" + " generate a new image.", + required=True, + options=[ + ToolParameterOption(value=i.name, label=I18nObject(en_US=i.name, zh_Hans=i.name)) + for i in self.list_default_image_variables() + ], + ) ) - + if self.runtime.credentials: try: models = self.get_sd_models() if len(models) != 0: parameters.append( - ToolParameter(name='model', - label=I18nObject(en_US='Model', zh_Hans='Model'), - human_description=I18nObject( - en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档', - ), - type=ToolParameter.ToolParameterType.SELECT, - form=ToolParameter.ToolParameterForm.FORM, - llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', - required=True, - default=models[0], - options=[ToolParameterOption( - value=i, - label=I18nObject(en_US=i, zh_Hans=i) - ) for i in models]) + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="Model of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + zh_Hans="Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + required=True, + default=models[0], 
+ options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models + ], + ) ) + except: pass + sample_methods = self.get_sample_methods() + if len(sample_methods) != 0: + parameters.append( + ToolParameter( + name="sampler_name", + label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), + human_description=I18nObject( + en_US="Sampling method of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Sampling method of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + required=True, + default=sample_methods[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in sample_methods + ], + ) + ) return parameters diff --git a/api/core/tools/provider/builtin/stackexchange/stackexchange.py b/api/core/tools/provider/builtin/stackexchange/stackexchange.py index de64c84997f7ca..9680c633cc701c 100644 --- a/api/core/tools/provider/builtin/stackexchange/stackexchange.py +++ b/api/core/tools/provider/builtin/stackexchange/stackexchange.py @@ -11,16 +11,15 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "intitle": "Test", - "sort": "relevance", + "sort": "relevance", "order": "desc", "site": "stackoverflow", "accepted": True, - "pagesize": 1 + "pagesize": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py index f8e17108444084..534532009501f5 100644 --- a/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/fetchAnsByStackExQuesID.py @@ -17,7 +17,9 @@ class FetchAnsByStackExQuesIDInput(BaseModel): class FetchAnsByStackExQuesIDTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = FetchAnsByStackExQuesIDInput(**tool_parameters) params = { @@ -26,7 +28,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn "order": input.order, "sort": input.sort, "pagesize": input.pagesize, - "page": input.page + "page": input.page, } response = requests.get(f"https://api.stackexchange.com/2.3/questions/{input.id}/answers", params=params) @@ -34,4 +36,4 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py index 8436433c323cd1..4a25a808adf26a 100644 --- 
a/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py +++ b/api/core/tools/provider/builtin/stackexchange/tools/searchStackExQuestions.py @@ -9,26 +9,28 @@ class SearchStackExQuestionsInput(BaseModel): intitle: str = Field(..., description="The search query.") - sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") + sort: str = Field(..., description="The sort order - relevance, activity, votes, creation.") order: str = Field(..., description="asc or desc") site: str = Field(..., description="The Stack Exchange site.") tagged: str = Field(None, description="Semicolon-separated tags to include.") nottagged: str = Field(None, description="Semicolon-separated tags to exclude.") - accepted: bool = Field(..., description="true for only accepted answers, false otherwise") + accepted: bool = Field(..., description="true for only accepted answers, false otherwise") pagesize: int = Field(..., description="Number of results per page") class SearchStackExQuestionsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: input = SearchStackExQuestionsInput(**tool_parameters) params = { "intitle": input.intitle, "sort": input.sort, - "order": input.order, + "order": input.order, "site": input.site, "accepted": input.accepted, - "pagesize": input.pagesize + "pagesize": input.pagesize, } if input.tagged: params["tagged"] = input.tagged @@ -40,4 +42,4 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolIn if response.status_code == 200: return self.create_text_message(self.summary(user_id=user_id, content=response.text)) else: - return self.create_text_message(f"API request failed with status code {response.status_code}") \ No newline at end of file + return self.create_text_message(f"API request failed with status code {response.status_code}") diff --git a/api/core/tools/provider/builtin/stepfun/stepfun.py b/api/core/tools/provider/builtin/stepfun/stepfun.py index e809b04546aef5..239db85b1118b0 100644 --- a/api/core/tools/provider/builtin/stepfun/stepfun.py +++ b/api/core/tools/provider/builtin/stepfun/stepfun.py @@ -13,13 +13,12 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "prompt": "cute girl, blue eyes, white hair, anime style", - "size": "1024x1024", - "n": 1 + "size": "256x256", + "n": 1, }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/stepfun/stepfun.yaml b/api/core/tools/provider/builtin/stepfun/stepfun.yaml index 1f841ec369b5c3..e8139a4d7d6cfd 100644 --- a/api/core/tools/provider/builtin/stepfun/stepfun.yaml +++ b/api/core/tools/provider/builtin/stepfun/stepfun.yaml @@ -4,11 +4,9 @@ identity: label: en_US: Image-1X zh_Hans: 阶跃星辰绘画 - pt_BR: Image-1X description: en_US: Image-1X zh_Hans: 阶跃星辰绘画 - pt_BR: Image-1X icon: icon.png tags: - image @@ -20,27 +18,16 @@ credentials_for_provider: label: en_US: Stepfun API key zh_Hans: 阶跃星辰API key - pt_BR: Stepfun API key - help: - en_US: Please input your stepfun API key - zh_Hans: 请输入你的阶跃星辰 API key - pt_BR: Please input your stepfun API key placeholder: - en_US: Please input your stepfun API key + en_US: Please input your Stepfun API key 
zh_Hans: 请输入你的阶跃星辰 API key - pt_BR: Please input your stepfun API key + url: https://platform.stepfun.com/interface-key stepfun_base_url: type: text-input required: false label: en_US: Stepfun base URL zh_Hans: 阶跃星辰 base URL - pt_BR: Stepfun base URL - help: - en_US: Please input your Stepfun base URL - zh_Hans: 请输入你的阶跃星辰 base URL - pt_BR: Please input your Stepfun base URL placeholder: en_US: Please input your Stepfun base URL zh_Hans: 请输入你的阶跃星辰 base URL - pt_BR: Please input your Stepfun base URL diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.py b/api/core/tools/provider/builtin/stepfun/tools/image.py index 5e544aada63b40..61cc14fac6ca93 100644 --- a/api/core/tools/provider/builtin/stepfun/tools/image.py +++ b/api/core/tools/provider/builtin/stepfun/tools/image.py @@ -1,4 +1,3 @@ -import random from typing import Any, Union from openai import OpenAI @@ -9,64 +8,59 @@ class StepfunTool(BuiltinTool): - """ Stepfun Image Generation Tool """ - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """Stepfun Image Generation Tool""" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = self.runtime.credentials.get('stepfun_base_url', None) - if not base_url: - base_url = None - else: - base_url = str(URL(base_url) / 'v1') + base_url = self.runtime.credentials.get("stepfun_base_url") or "https://api.stepfun.com" + base_url = str(URL(base_url) / "v1") client = OpenAI( - api_key=self.runtime.credentials['stepfun_api_key'], + api_key=self.runtime.credentials["stepfun_api_key"], base_url=base_url, ) extra_body = {} - model = tool_parameters.get('model', 'step-1x-medium') - if not model: - return self.create_text_message('Please input model name') + model = "step-1x-medium" # prompt - prompt = tool_parameters.get('prompt', '') + prompt = tool_parameters.get("prompt", "") if not prompt: - return self.create_text_message('Please input prompt') - - seed = tool_parameters.get('seed', 0) + return self.create_text_message("Please input prompt") + if len(prompt) > 1024: + return self.create_text_message("The prompt length should less than 1024") + seed = tool_parameters.get("seed", 0) if seed > 0: - extra_body['seed'] = seed - steps = tool_parameters.get('steps', 0) + extra_body["seed"] = seed + steps = tool_parameters.get("steps", 50) if steps > 0: - extra_body['steps'] = steps - negative_prompt = tool_parameters.get('negative_prompt', '') - if negative_prompt: - extra_body['negative_prompt'] = negative_prompt + extra_body["steps"] = steps + cfg_scale = tool_parameters.get("cfg_scale", 7.5) + if cfg_scale > 0: + extra_body["cfg_scale"] = cfg_scale # call openapi stepfun model response = client.images.generate( prompt=prompt, model=model, - size=tool_parameters.get('size', '1024x1024'), - n=tool_parameters.get('n', 1), - extra_body= extra_body + size=tool_parameters.get("size", "1024x1024"), + n=tool_parameters.get("n", 1), + extra_body=extra_body, ) - print(response) result = [] for image in response.data: result.append(self.create_image_message(image=image.url)) - result.append(self.create_json_message({ - "url": image.url, - })) + result.append( + self.create_json_message( + { + "url": image.url, + } + ) + ) return result - - @staticmethod - def _generate_random_id(length=8): - characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - random_id = 
''.join(random.choices(characters, k=length)) - return random_id diff --git a/api/core/tools/provider/builtin/stepfun/tools/image.yaml b/api/core/tools/provider/builtin/stepfun/tools/image.yaml index 1e20b157aa131f..dfda6ed1914848 100644 --- a/api/core/tools/provider/builtin/stepfun/tools/image.yaml +++ b/api/core/tools/provider/builtin/stepfun/tools/image.yaml @@ -25,46 +25,17 @@ parameters: pt_BR: Prompt human_description: en_US: Image prompt, you can check the official documentation of step-1x - zh_Hans: 图像提示词,您可以查看step-1x 的官方文档 + zh_Hans: 图像提示词,您可以查看 step-1x 的官方文档 pt_BR: Image prompt, you can check the official documentation of step-1x llm_description: Image prompt of step-1x you should describe the image you want to generate as a list of words as possible as detailed form: llm - - name: model - type: select - required: false - human_description: - en_US: used for selecting the model name - zh_Hans: 用于选择模型的名字 - pt_BR: used for selecting the model name - label: - en_US: Model Name - zh_Hans: 模型名字 - pt_BR: Model Name - form: form - options: - - value: step-1x-turbo - label: - en_US: turbo - zh_Hans: turbo - pt_BR: turbo - - value: step-1x-medium - label: - en_US: medium - zh_Hans: medium - pt_BR: medium - - value: step-1x-large - label: - en_US: large - zh_Hans: large - pt_BR: large - default: step-1x-medium - name: size type: select required: false human_description: - en_US: used for selecting the image size - zh_Hans: 用于选择图像大小 - pt_BR: used for selecting the image size + en_US: The size of the generated image + zh_Hans: 生成的图片大小 + pt_BR: The size of the generated image label: en_US: Image size zh_Hans: 图像大小 @@ -106,17 +77,17 @@ parameters: type: number required: true human_description: - en_US: used for selecting the number of images - zh_Hans: 用于选择图像数量 - pt_BR: used for selecting the number of images + en_US: Number of generated images, now only one image can be generated at a time + zh_Hans: 生成的图像数量,当前仅支持每次生成一张图片 + pt_BR: Number of generated images, now only one image can be generated at a time label: - en_US: Number of images - zh_Hans: 图像数量 - pt_BR: Number of images + en_US: Number of generated images + zh_Hans: 生成的图像数量 + pt_BR: Number of generated images form: form default: 1 min: 1 - max: 10 + max: 1 - name: seed type: number required: false @@ -138,21 +109,25 @@ parameters: zh_Hans: Steps pt_BR: Steps human_description: - en_US: Steps - zh_Hans: Steps - pt_BR: Steps + en_US: Steps, now support integers between 1 and 100 + zh_Hans: Steps, 当前支持 1~100 之间整数 + pt_BR: Steps, now support integers between 1 and 100 form: form - default: 10 - - name: negative_prompt - type: string + default: 50 + min: 1 + max: 100 + - name: cfg_scale + type: number required: false label: - en_US: Negative prompt - zh_Hans: Negative prompt - pt_BR: Negative prompt + en_US: classifier-free guidance scale + zh_Hans: classifier-free guidance scale + pt_BR: classifier-free guidance scale human_description: - en_US: Negative prompt - zh_Hans: Negative prompt - pt_BR: Negative prompt + en_US: classifier-free guidance scale + zh_Hans: classifier-free guidance scale + pt_BR: classifier-free guidance scale form: form - default: (worst quality:1.3), (nsfw), low quality + default: 7.5 + min: 1 + max: 10 diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index e376d99d6bb951..a702b0a74e6131 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -13,7 +13,7 @@ def _validate_credentials(self, 
credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "Sachin Tendulkar", "search_depth": "basic", @@ -22,9 +22,8 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "include_raw_content": False, "max_results": 5, "include_domains": "", - "exclude_domains": "" + "exclude_domains": "", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/tavily/tavily.yaml b/api/core/tools/provider/builtin/tavily/tavily.yaml index 7b25a8184857ca..95820f4d18b051 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.yaml +++ b/api/core/tools/provider/builtin/tavily/tavily.yaml @@ -28,4 +28,4 @@ credentials_for_provider: en_US: Get your Tavily API key from Tavily zh_Hans: 从 TavilyApi 获取您的 Tavily API key pt_BR: Get your Tavily API key from Tavily - url: https://docs.tavily.com/docs/tavily-api/introduction + url: https://docs.tavily.com/docs/welcome diff --git a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py index 0200df3c8a4c31..ca6d8633e4b0af 100644 --- a/api/core/tools/provider/builtin/tavily/tools/tavily_search.py +++ b/api/core/tools/provider/builtin/tavily/tools/tavily_search.py @@ -36,15 +36,23 @@ def raw_results(self, params: dict[str, Any]) -> dict: """ params["api_key"] = self.api_key - if 'exclude_domains' in params and isinstance(params['exclude_domains'], str) and params['exclude_domains'] != 'None': - params['exclude_domains'] = params['exclude_domains'].split() + if ( + "exclude_domains" in params + and isinstance(params["exclude_domains"], str) + and params["exclude_domains"] != "None" + ): + params["exclude_domains"] = params["exclude_domains"].split() else: - params['exclude_domains'] = [] - if 'include_domains' in params and isinstance(params['include_domains'], str) and params['include_domains'] != 'None': - params['include_domains'] = params['include_domains'].split() + params["exclude_domains"] = [] + if ( + "include_domains" in params + and isinstance(params["include_domains"], str) + and params["include_domains"] != "None" + ): + params["include_domains"] = params["include_domains"].split() else: - params['include_domains'] = [] - + params["include_domains"] = [] + response = requests.post(f"{TAVILY_API_URL}/search", json=params) response.raise_for_status() return response.json() @@ -91,9 +99,7 @@ class TavilySearchTool(BuiltinTool): A tool for searching Tavily using a given query. """ - def _invoke( - self, user_id: str, tool_parameters: dict[str, Any] - ) -> ToolInvokeMessage | list[ToolInvokeMessage]: + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ Invokes the Tavily search tool with the given user ID and tool parameters. 
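The raw_results hunk above only reflows the domain guards, but it is worth spelling out what they do: include_domains and exclude_domains arrive as whitespace-separated strings (or the sentinel string "None") and are normalized to lists before the POST to f"{TAVILY_API_URL}/search". Below is a standalone sketch of that normalization; the parameter names and the "None" sentinel come from the diff, while the sample domain values are invented.

def normalize_domains(value):
    # Mirrors the guard in raw_results: only split real, non-"None"
    # strings; a missing key or the sentinel becomes an empty list.
    if isinstance(value, str) and value != "None":
        return value.split()
    return []

params = {
    "query": "Sachin Tendulkar",  # same sample query the provider validation uses
    "include_domains": "espncricinfo.com wikipedia.org",  # invented example values
    "exclude_domains": "None",  # sentinel that the guard filters out
}
for key in ("include_domains", "exclude_domains"):
    params[key] = normalize_domains(params.get(key))

assert params["include_domains"] == ["espncricinfo.com", "wikipedia.org"]
assert params["exclude_domains"] == []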
@@ -115,4 +121,4 @@ def _invoke( if not results: return self.create_text_message(f"No results found for '{query}' in Tavily") else: - return self.create_text_message(text=results) \ No newline at end of file + return self.create_text_message(text=results) diff --git a/api/core/tools/provider/builtin/tianditu/tianditu.py b/api/core/tools/provider/builtin/tianditu/tianditu.py index 1f96be06b0200d..cb7d7bd8bb2c41 100644 --- a/api/core/tools/provider/builtin/tianditu/tianditu.py +++ b/api/core/tools/provider/builtin/tianditu/tianditu.py @@ -12,10 +12,12 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: runtime={ "credentials": credentials, } - ).invoke(user_id='', - tool_parameters={ - 'content': '北京', - 'specify': '156110000', - }) + ).invoke( + user_id="", + tool_parameters={ + "content": "北京", + "specify": "156110000", + }, + ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py index 484a3768c851df..690a0aed6f5aff 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/geocoder.py +++ b/api/core/tools/provider/builtin/tianditu/tools/geocoder.py @@ -8,26 +8,26 @@ class GeocoderTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - base_url = 'http://api.tianditu.gov.cn/geocoder' - - keyword = tool_parameters.get('keyword', '') + base_url = "http://api.tianditu.gov.cn/geocoder" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + params = { - 'keyWord': keyword, + "keyWord": keyword, } - - result = requests.get(base_url + '?ds=' + json.dumps(params, ensure_ascii=False) + '&tk=' + tk).json() + + result = requests.get(base_url + "?ds=" + json.dumps(params, ensure_ascii=False) + "&tk=" + tk).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py index 08a5b8ef42a8c8..798dd94d335654 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/poisearch.py +++ b/api/core/tools/provider/builtin/tianditu/tools/poisearch.py @@ -8,38 +8,51 @@ class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/v2/search' - - keyword = tool_parameters.get('keyword', '') + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/v2/search" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - baseAddress = tool_parameters.get('baseAddress', '') + return self.create_text_message("Invalid parameter keyword") + + 
baseAddress = tool_parameters.get("baseAddress", "") if not baseAddress: - return self.create_text_message('Invalid parameter baseAddress') - - tk = self.runtime.credentials['tianditu_api_key'] - - base_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': baseAddress,}, ensure_ascii=False) + '&tk=' + tk).json() - + return self.create_text_message("Invalid parameter baseAddress") + + tk = self.runtime.credentials["tianditu_api_key"] + + base_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": baseAddress, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + params = { - 'keyWord': keyword, - 'queryRadius': 5000, - 'queryType': 3, - 'pointLonlat': base_coords['location']['lon'] + ',' + base_coords['location']['lat'], - 'start': 0, - 'count': 100, + "keyWord": keyword, + "queryRadius": 5000, + "queryType": 3, + "pointLonlat": base_coords["location"]["lon"] + "," + base_coords["location"]["lat"], + "start": 0, + "count": 100, } - - result = requests.get(base_url + '?postStr=' + json.dumps(params, ensure_ascii=False) + '&type=query&tk=' + tk).json() + + result = requests.get( + base_url + "?postStr=" + json.dumps(params, ensure_ascii=False) + "&type=query&tk=" + tk + ).json() return self.create_json_message(result) diff --git a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py index ecac4404ca28b0..aeaef088057686 100644 --- a/api/core/tools/provider/builtin/tianditu/tools/staticmap.py +++ b/api/core/tools/provider/builtin/tianditu/tools/staticmap.py @@ -8,29 +8,42 @@ class PoiSearchTool(BuiltinTool): - - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - - geocoder_base_url = 'http://api.tianditu.gov.cn/geocoder' - base_url = 'http://api.tianditu.gov.cn/staticimage' - - keyword = tool_parameters.get('keyword', '') + + geocoder_base_url = "http://api.tianditu.gov.cn/geocoder" + base_url = "http://api.tianditu.gov.cn/staticimage" + + keyword = tool_parameters.get("keyword", "") if not keyword: - return self.create_text_message('Invalid parameter keyword') - - tk = self.runtime.credentials['tianditu_api_key'] - - keyword_coords = requests.get(geocoder_base_url + '?ds=' + json.dumps({'keyWord': keyword,}, ensure_ascii=False) + '&tk=' + tk).json() - coords = keyword_coords['location']['lon'] + ',' + keyword_coords['location']['lat'] - - result = requests.get(base_url + '?center=' + coords + '&markers=' + coords + '&width=400&height=300&zoom=14&tk=' + tk).content - - return self.create_blob_message(blob=result, - meta={'mime_type': 'image/png'}, - save_as=self.VARIABLE_KEY.IMAGE.value) + return self.create_text_message("Invalid parameter keyword") + + tk = self.runtime.credentials["tianditu_api_key"] + + keyword_coords = requests.get( + geocoder_base_url + + "?ds=" + + json.dumps( + { + "keyWord": keyword, + }, + ensure_ascii=False, + ) + + "&tk=" + + tk + ).json() + coords = keyword_coords["location"]["lon"] + "," + keyword_coords["location"]["lat"] + + result = requests.get( + base_url + "?center=" + coords + "&markers=" + coords + "&width=400&height=300&zoom=14&tk=" + tk + ).content + + return self.create_blob_message( + blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) diff --git 
a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py index 833ae194ef840c..e4df8d616cba38 100644 --- a/api/core/tools/provider/builtin/time/time.py +++ b/api/core/tools/provider/builtin/time/time.py @@ -9,9 +9,8 @@ class WikiPediaProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: CurrentTimeTool().invoke( - user_id='', + user_id="", tool_parameters={}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index 90c01665e6e6a1..cc38739c16f04b 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -8,21 +8,22 @@ class CurrentTimeTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ # get timezone - tz = tool_parameters.get('timezone', 'UTC') - fm = tool_parameters.get('format') or '%Y-%m-%d %H:%M:%S %Z' - if tz == 'UTC': - return self.create_text_message(f'{datetime.now(timezone.utc).strftime(fm)}') - + tz = tool_parameters.get("timezone", "UTC") + fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" + if tz == "UTC": + return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") + try: tz = pytz_timezone(tz) except: - return self.create_text_message(f'Invalid timezone: {tz}') - return self.create_text_message(f'{datetime.now(tz).strftime(fm)}') \ No newline at end of file + return self.create_text_message(f"Invalid timezone: {tz}") + return self.create_text_message(f"{datetime.now(tz).strftime(fm)}") diff --git a/api/core/tools/provider/builtin/time/tools/localtime_to_timestamp.py b/api/core/tools/provider/builtin/time/tools/localtime_to_timestamp.py new file mode 100644 index 00000000000000..e16b732d0242db --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/localtime_to_timestamp.py @@ -0,0 +1,44 @@ +from datetime import datetime +from typing import Any, Union + +import pytz + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolInvokeError +from core.tools.tool.builtin_tool import BuiltinTool + + +class LocaltimeToTimestampTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Convert localtime to timestamp + """ + localtime = tool_parameters.get("localtime") + timezone = tool_parameters.get("timezone", "Asia/Shanghai") + if not timezone: + timezone = None + time_format = "%Y-%m-%d %H:%M:%S" + + timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) + if not timestamp: + return self.create_text_message(f"Invalid localtime: {localtime}") + + return self.create_text_message(f"{timestamp}") + + @staticmethod + def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None: + try: + if local_tz is None: + local_tz = datetime.now().astimezone().tzinfo + if isinstance(local_tz, str): + local_tz = pytz.timezone(local_tz) + local_time = datetime.strptime(localtime, time_format) + localtime = local_tz.localize(local_time) + timestamp = 
int(localtime.timestamp()) + return timestamp + except Exception as e: + raise ToolInvokeError(str(e)) diff --git a/api/core/tools/provider/builtin/time/tools/localtime_to_timestamp.yaml b/api/core/tools/provider/builtin/time/tools/localtime_to_timestamp.yaml new file mode 100644 index 00000000000000..6a3b90595fd3fd --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/localtime_to_timestamp.yaml @@ -0,0 +1,33 @@ +identity: + name: localtime_to_timestamp + author: zhuhao + label: + en_US: localtime to timestamp + zh_Hans: 获取时间戳 +description: + human: + en_US: A tool for localtime convert to timestamp + zh_Hans: 获取时间戳 + llm: A tool for localtime convert to timestamp +parameters: + - name: localtime + type: string + required: true + form: llm + label: + en_US: localtime + zh_Hans: 本地时间 + human_description: + en_US: localtime, such as 2024-1-1 0:0:0 + zh_Hans: 本地时间, 比如2024-1-1 0:0:0 + - name: timezone + type: string + required: false + form: llm + label: + en_US: Timezone + zh_Hans: 时区 + human_description: + en_US: Timezone, such as Asia/Shanghai + zh_Hans: 时区, 比如Asia/Shanghai + default: Asia/Shanghai diff --git a/api/core/tools/provider/builtin/time/tools/timestamp_to_localtime.py b/api/core/tools/provider/builtin/time/tools/timestamp_to_localtime.py new file mode 100644 index 00000000000000..bcdd34fd4ec54d --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/timestamp_to_localtime.py @@ -0,0 +1,44 @@ +from datetime import datetime +from typing import Any, Union + +import pytz + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolInvokeError +from core.tools.tool.builtin_tool import BuiltinTool + + +class TimestampToLocaltimeTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Convert timestamp to localtime + """ + timestamp = tool_parameters.get("timestamp") + timezone = tool_parameters.get("timezone", "Asia/Shanghai") + if not timezone: + timezone = None + time_format = "%Y-%m-%d %H:%M:%S" + + locatime = self.timestamp_to_localtime(timestamp, timezone) + if not locatime: + return self.create_text_message(f"Invalid timestamp: {timestamp}") + + localtime_format = locatime.strftime(time_format) + + return self.create_text_message(f"{localtime_format}") + + @staticmethod + def timestamp_to_localtime(timestamp: int, local_tz=None) -> datetime | None: + try: + if local_tz is None: + local_tz = datetime.now().astimezone().tzinfo + if isinstance(local_tz, str): + local_tz = pytz.timezone(local_tz) + local_time = datetime.fromtimestamp(timestamp, local_tz) + return local_time + except Exception as e: + raise ToolInvokeError(str(e)) diff --git a/api/core/tools/provider/builtin/time/tools/timestamp_to_localtime.yaml b/api/core/tools/provider/builtin/time/tools/timestamp_to_localtime.yaml new file mode 100644 index 00000000000000..3794e717b4dc85 --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/timestamp_to_localtime.yaml @@ -0,0 +1,33 @@ +identity: + name: timestamp_to_localtime + author: zhuhao + label: + en_US: Timestamp to localtime + zh_Hans: 时间戳转换 +description: + human: + en_US: A tool for timestamp convert to localtime + zh_Hans: 时间戳转换 + llm: A tool for timestamp convert to localtime +parameters: + - name: timestamp + type: number + required: true + form: llm + label: + en_US: Timestamp + zh_Hans: 时间戳 + human_description: + en_US: Timestamp + zh_Hans: 时间戳 + - name: timezone + type: string + required: false + 
form: llm + label: + en_US: Timezone + zh_Hans: 时区 + human_description: + en_US: Timezone, such as Asia/Shanghai + zh_Hans: 时区, 比如Asia/Shanghai + default: Asia/Shanghai diff --git a/api/core/tools/provider/builtin/time/tools/timezone_conversion.py b/api/core/tools/provider/builtin/time/tools/timezone_conversion.py new file mode 100644 index 00000000000000..28e70db5328527 --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/timezone_conversion.py @@ -0,0 +1,48 @@ +from datetime import datetime +from typing import Any, Union + +import pytz + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolInvokeError +from core.tools.tool.builtin_tool import BuiltinTool + + +class TimezoneConversionTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Convert time to equivalent time zone + """ + current_time = tool_parameters.get("current_time") + current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai") + target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo") + target_time = self.timezone_convert(current_time, current_timezone, target_timezone) + if not target_time: + return self.create_text_message( + f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}" + ) + + return self.create_text_message(f"{target_time}") + + @staticmethod + def timezone_convert(current_time: str, source_timezone: str, target_timezone: str) -> str: + """ + Convert a time string from source timezone to target timezone. + """ + time_format = "%Y-%m-%d %H:%M:%S" + try: + # get source timezone + input_timezone = pytz.timezone(source_timezone) + # get target timezone + output_timezone = pytz.timezone(target_timezone) + local_time = datetime.strptime(current_time, time_format) + datetime_with_tz = input_timezone.localize(local_time) + # timezone convert + converted_datetime = datetime_with_tz.astimezone(output_timezone) + return converted_datetime.strftime(format=time_format) + except Exception as e: + raise ToolInvokeError(str(e)) diff --git a/api/core/tools/provider/builtin/time/tools/timezone_conversion.yaml b/api/core/tools/provider/builtin/time/tools/timezone_conversion.yaml new file mode 100644 index 00000000000000..4c221c2e512208 --- /dev/null +++ b/api/core/tools/provider/builtin/time/tools/timezone_conversion.yaml @@ -0,0 +1,44 @@ +identity: + name: timezone_conversion + author: zhuhao + label: + en_US: convert time to equivalent time zone + zh_Hans: 时区转换 +description: + human: + en_US: A tool to convert time to equivalent time zone + zh_Hans: 时区转换 + llm: A tool to convert time to equivalent time zone +parameters: + - name: current_time + type: string + required: true + form: llm + label: + en_US: current time + zh_Hans: 当前时间 + human_description: + en_US: current time, such as 2024-1-1 0:0:0 + zh_Hans: 当前时间, 比如2024-1-1 0:0:0 + - name: current_timezone + type: string + required: true + form: llm + label: + en_US: Current Timezone + zh_Hans: 当前时区 + human_description: + en_US: Current Timezone, such as Asia/Shanghai + zh_Hans: 当前时区, 比如Asia/Shanghai + default: Asia/Shanghai + - name: target_timezone + type: string + required: true + form: llm + label: + en_US: Target Timezone + zh_Hans: 目标时区 + human_description: + en_US: Target Timezone, such as Asia/Tokyo + zh_Hans: 目标时区, 比如Asia/Tokyo + default: Asia/Tokyo diff --git a/api/core/tools/provider/builtin/time/tools/weekday.py 
b/api/core/tools/provider/builtin/time/tools/weekday.py index 4461cb5a32d14e..b327e54e171048 100644 --- a/api/core/tools/provider/builtin/time/tools/weekday.py +++ b/api/core/tools/provider/builtin/time/tools/weekday.py @@ -7,25 +7,26 @@ class WeekdayTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - Calculate the day of the week for a given date + Calculate the day of the week for a given date """ - year = tool_parameters.get('year') - month = tool_parameters.get('month') - day = tool_parameters.get('day') + year = tool_parameters.get("year") + month = tool_parameters.get("month") + day = tool_parameters.get("day") date_obj = self.convert_datetime(year, month, day) if not date_obj: - return self.create_text_message(f'Invalid date: Year {year}, Month {month}, Day {day}.') + return self.create_text_message(f"Invalid date: Year {year}, Month {month}, Day {day}.") weekday_name = calendar.day_name[date_obj.weekday()] month_name = calendar.month_name[month] readable_date = f"{month_name} {date_obj.day}, {date_obj.year}" - return self.create_text_message(f'{readable_date} is {weekday_name}.') + return self.create_text_message(f"{readable_date} is {weekday_name}.") @staticmethod def convert_datetime(year, month, day) -> datetime | None: diff --git a/api/core/tools/provider/builtin/trello/tools/create_board.py b/api/core/tools/provider/builtin/trello/tools/create_board.py index 2655602afa82d3..5a61d221578995 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_board.py @@ -22,19 +22,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_name = tool_parameters.get("name") if not (api_key and token and board_name): return self.create_text_message("Missing required parameters: API key, token, or board name.") url = "https://api.trello.com/1/boards/" - query_params = { - 'name': board_name, - 'key': api_key, - 'token': token - } + query_params = {"name": board_name, "key": api_key, "token": token} try: response = requests.post(url, params=query_params) @@ -43,5 +39,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to create board") board = response.json() - return self.create_text_message(text=f"Board created successfully! Board name: {board['name']}, ID: {board['id']}") - + return self.create_text_message( + text=f"Board created successfully! 
Board name: {board['name']}, ID: {board['id']}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py index f5b156cb44c2ee..b32b0124dd31da 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_list_on_board.py @@ -17,25 +17,22 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID and list name. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID and list name. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('id') - list_name = tool_parameters.get('name') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("id") + list_name = tool_parameters.get("name") if not (api_key and token and board_id and list_name): return self.create_text_message("Missing required parameters: API key, token, board ID, or list name.") url = f"https://api.trello.com/1/boards/{board_id}/lists" - params = { - 'name': list_name, - 'key': api_key, - 'token': token - } + params = {"name": list_name, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -44,5 +41,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to create list") new_list = response.json() - return self.create_text_message(text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}.") - + return self.create_text_message( + text=f"List '{new_list['name']}' created successfully with Id {new_list['id']} on board {board_id}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py index 74b73b40e54f5d..e98efb81ca673e 100644 --- a/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/create_new_card_on_board.py @@ -17,20 +17,21 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, including details for the new card. + tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, + including details for the new card. Returns: ToolInvokeMessage: The result of the tool invocation. 
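Every Trello tool touched in this series authenticates the same way: trello_api_key and trello_api_token come out of the runtime credentials and ride along as key/token query parameters on each request. A condensed sketch of that shared pattern, using the create-board endpoint from the hunk above (credential values are placeholders):

import requests

TRELLO_API_KEY = "your-trello-api-key"      # placeholder credential
TRELLO_API_TOKEN = "your-trello-api-token"  # placeholder credential


def create_trello_board(name: str) -> dict:
    # These tools pass key/token as query parameters on every Trello call.
    response = requests.post(
        "https://api.trello.com/1/boards/",
        params={"name": name, "key": TRELLO_API_KEY, "token": TRELLO_API_TOKEN},
    )
    response.raise_for_status()
    return response.json()


board = create_trello_board("Release checklist")  # returns the board JSON, including its id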
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") # Ensure required parameters are present - if 'name' not in tool_parameters or 'idList' not in tool_parameters: + if "name" not in tool_parameters or "idList" not in tool_parameters: return self.create_text_message("Missing required parameters: name or idList.") url = "https://api.trello.com/1/cards" - params = {**tool_parameters, 'key': api_key, 'token': token} + params = {**tool_parameters, "key": api_key, "token": token} try: response = requests.post(url, params=params) @@ -39,5 +40,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, except requests.exceptions.RequestException as e: return self.create_text_message("Failed to create card") - return self.create_text_message(text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}.") - + return self.create_text_message( + text=f"New card '{new_card['name']}' created successfully with ID {new_card['id']}." + ) diff --git a/api/core/tools/provider/builtin/trello/tools/delete_board.py b/api/core/tools/provider/builtin/trello/tools/delete_board.py index 29df3fda2d23ec..7fc9d1f13c2664 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_board.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_board.py @@ -17,14 +17,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,4 +39,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to delete board") return self.create_text_message(text=f"Board with ID {board_id} deleted successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/delete_card.py b/api/core/tools/provider/builtin/trello/tools/delete_card.py index 2ced5f6c14f9f7..1de98d639ebb7d 100644 --- a/api/core/tools/provider/builtin/trello/tools/delete_card.py +++ b/api/core/tools/provider/builtin/trello/tools/delete_card.py @@ -17,14 +17,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the card ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the card ID. Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") @@ -38,4 +39,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to delete card") return self.create_text_message(text=f"Card with ID {card_id} has been successfully deleted.") - diff --git a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py index f9d554c6fb0478..0c5ed9ea8533ff 100644 --- a/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py +++ b/api/core/tools/provider/builtin/trello/tools/fetch_all_boards.py @@ -28,9 +28,7 @@ def _invoke( token = self.runtime.credentials.get("trello_api_token") if not (api_key and token): - return self.create_text_message( - "Missing Trello API key or token in credentials." - ) + return self.create_text_message("Missing Trello API key or token in credentials.") # Including board filter in the request if provided board_filter = tool_parameters.get("boards", "open") @@ -48,7 +46,5 @@ def _invoke( return self.create_text_message("No boards found in Trello.") # Creating a string with both board names and IDs - boards_info = ", ".join( - [f"{board['name']} (ID: {board['id']})" for board in boards] - ) + boards_info = ", ".join([f"{board['name']} (ID: {board['id']})" for board in boards]) return self.create_text_message(text=f"Boards: {boards_info}") diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py index 5678d8f8d76d7c..cabc7ce09359d5 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_actions.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_actions.py @@ -17,14 +17,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -38,6 +39,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] except requests.exceptions.RequestException as e: return self.create_text_message("Failed to retrieve board actions") - actions_summary = "\n".join([f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions]) + actions_summary = "\n".join( + [f"{action['type']}: {action.get('data', {}).get('text', 'No details available')}" for action in actions] + ) return self.create_text_message(text=f"Actions for Board ID {board_id}:\n{actions_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py index ee6cb065e5a9fa..fe42cd9c5cbf86 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_by_id.py @@ -17,14 +17,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -63,4 +64,3 @@ def format_board_details(self, board: dict) -> str: f"Background Color: {board['prefs']['backgroundColor']}" ) return details - diff --git a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py index 1abb688750af53..ff2b1221e767de 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_board_cards.py @@ -17,14 +17,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +41,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] cards_summary = "\n".join([f"{card['name']} (ID: {card['id']})" for card in cards]) return self.create_text_message(text=f"Cards for Board ID {board_id}:\n{cards_summary}") - diff --git a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py index 375ead5b1d232a..3d7f9f4ad1c996 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py +++ b/api/core/tools/provider/builtin/trello/tools/get_filterd_board_cards.py @@ -17,15 +17,16 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID and filter. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID and filter. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') - filter = tool_parameters.get('filter') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") + filter = tool_parameters.get("filter") if not (api_key and token and board_id and filter): return self.create_text_message("Missing required parameters: API key, token, board ID, or filter.") @@ -40,5 +41,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] return self.create_text_message("Failed to retrieve filtered cards") card_details = "\n".join([f"{card['name']} (ID: {card['id']})" for card in filtered_cards]) - return self.create_text_message(text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}") - + return self.create_text_message( + text=f"Filtered Cards for Board ID {board_id} with Filter '{filter}':\n{card_details}" + ) diff --git a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py index 7b9b9cf24b7543..ccf404068f225e 100644 --- a/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py +++ b/api/core/tools/provider/builtin/trello/tools/get_lists_on_board.py @@ -17,14 +17,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, including the board ID. + tool_parameters (dict[str, Union[str, int, bool]]): The parameters for the tool invocation, + including the board ID. Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.get('boardId') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.get("boardId") if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -40,4 +41,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool] lists_info = "\n".join([f"{list['name']} (ID: {list['id']})" for list in lists]) return self.create_text_message(text=f"Lists on Board ID {board_id}:\n{lists_info}") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_board.py b/api/core/tools/provider/builtin/trello/tools/update_board.py index 7ad6ac2e64ef46..1e358b00f49add 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_board.py +++ b/api/core/tools/provider/builtin/trello/tools/update_board.py @@ -17,14 +17,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, including board ID and updates. + tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, + including board ID and updates. Returns: ToolInvokeMessage: The result of the tool invocation. """ - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - board_id = tool_parameters.pop('boardId', None) + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + board_id = tool_parameters.pop("boardId", None) if not (api_key and token and board_id): return self.create_text_message("Missing required parameters: API key, token, or board ID.") @@ -33,8 +34,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, # Removing parameters not intended for update action or with None value params = {k: v for k, v in tool_parameters.items() if v is not None} - params['key'] = api_key - params['token'] = token + params["key"] = api_key + params["token"] = token try: response = requests.put(url, params=params) @@ -44,4 +45,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, updated_board = response.json() return self.create_text_message(text=f"Board '{updated_board['name']}' updated successfully.") - diff --git a/api/core/tools/provider/builtin/trello/tools/update_card.py b/api/core/tools/provider/builtin/trello/tools/update_card.py index 417344350cbc18..d25fcbafaa6326 100644 --- a/api/core/tools/provider/builtin/trello/tools/update_card.py +++ b/api/core/tools/provider/builtin/trello/tools/update_card.py @@ -17,22 +17,23 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Union[str, int, bool, Args: user_id (str): The ID of the user invoking the tool. - tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, including the card ID and updates. + tool_parameters (dict[str, Union[str, int, bool, None]]): The parameters for the tool invocation, + including the card ID and updates. Returns: ToolInvokeMessage: The result of the tool invocation. 
""" - api_key = self.runtime.credentials.get('trello_api_key') - token = self.runtime.credentials.get('trello_api_token') - card_id = tool_parameters.get('id') + api_key = self.runtime.credentials.get("trello_api_key") + token = self.runtime.credentials.get("trello_api_token") + card_id = tool_parameters.get("id") if not (api_key and token and card_id): return self.create_text_message("Missing required parameters: API key, token, or card ID.") # Constructing the URL and the payload for the PUT request url = f"https://api.trello.com/1/cards/{card_id}" - params = {k: v for k, v in tool_parameters.items() if v is not None and k != 'id'} - params.update({'key': api_key, 'token': token}) + params = {k: v for k, v in tool_parameters.items() if v is not None and k != "id"} + params.update({"key": api_key, "token": token}) try: response = requests.put(url, params=params) diff --git a/api/core/tools/provider/builtin/trello/trello.py b/api/core/tools/provider/builtin/trello/trello.py index 84ecd208037037..e0dca50ec99aee 100644 --- a/api/core/tools/provider/builtin/trello/trello.py +++ b/api/core/tools/provider/builtin/trello/trello.py @@ -9,17 +9,17 @@ class TrelloProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: """Validate Trello API credentials by making a test API call. - + Args: credentials (dict[str, Any]): The Trello API credentials to validate. - + Raises: ToolProviderCredentialValidationError: If the credentials are invalid. """ api_key = credentials.get("trello_api_key") token = credentials.get("trello_api_token") url = f"https://api.trello.com/1/members/me?key={api_key}&token={token}" - + try: response = requests.get(url) response.raise_for_status() # Raises an HTTPError for bad responses @@ -32,4 +32,3 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: except requests.exceptions.RequestException as e: # Handle other exceptions, such as connection errors raise ToolProviderCredentialValidationError("Error validating Trello credentials") - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py index 1c52589956c708..5ee839baa56f02 100644 --- a/api/core/tools/provider/builtin/twilio/tools/send_message.py +++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py @@ -32,17 +32,14 @@ class TwilioAPIWrapper(BaseModel): must be empty. """ - @field_validator('client', mode='before') + @field_validator("client", mode="before") @classmethod def set_validator(cls, values: dict) -> dict: """Validate that api key and python package exists in environment.""" try: from twilio.rest import Client except ImportError: - raise ImportError( - "Could not import twilio python package. " - "Please install it with `pip install twilio`." - ) + raise ImportError("Could not import twilio python package. Please install it with `pip install twilio`.") account_sid = values.get("account_sid") auth_token = values.get("auth_token") values["from_number"] = values.get("from_number") @@ -75,7 +72,8 @@ class SendMessageTool(BuiltinTool): tool_parameters (Dict[str, Any]): The parameters required for sending the message. Returns: - Union[ToolInvokeMessage, List[ToolInvokeMessage]]: The result of invoking the tool, which includes the status of the message sending operation. + Union[ToolInvokeMessage, List[ToolInvokeMessage]]: The result of invoking the tool, + which includes the status of the message sending operation. 
""" def _invoke( @@ -91,9 +89,7 @@ def _invoke( if to_number.startswith("whatsapp:"): from_number = f"whatsapp: {from_number}" - twilio = TwilioAPIWrapper( - account_sid=account_sid, auth_token=auth_token, from_number=from_number - ) + twilio = TwilioAPIWrapper(account_sid=account_sid, auth_token=auth_token, from_number=from_number) # Sending the message through Twilio result = twilio.run(message, to_number) diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index 06f276053a9c63..b1d100aad93dba 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -14,7 +14,7 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: account_sid = credentials["account_sid"] auth_token = credentials["auth_token"] from_number = credentials["from_number"] - + # Initialize twilio client client = Client(account_sid, auth_token) @@ -27,4 +27,3 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index a6efb0f79a5d93..1c7cb39c92b40b 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -35,10 +35,11 @@ def _invoke( password = tool_parameters.get("password", "") port = tool_parameters.get("port", 0) - vn = VannaDefault(model=model, api_key=api_key) + base_url = self.runtime.credentials.get("base_url", None) + vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url}) db_type = tool_parameters.get("db_type", "") - if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: + if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: if not db_name: return self.create_text_message("Please input database name") if not username: @@ -111,9 +112,10 @@ def _invoke( # with "visualize" set to True (default behavior) leads to remote code execution. 
# Affected versions: <= 0.5.5 ######################################################################################### - generate_chart = False - # generate_chart = tool_parameters.get("generate_chart", True) - res = vn.ask(prompt, False, True, generate_chart) + allow_llm_to_see_data = tool_parameters.get("allow_llm_to_see_data", False) + res = vn.ask( + prompt, print_results=False, auto_train=True, visualize=False, allow_llm_to_see_data=allow_llm_to_see_data + ) result = [] diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml index ae2eae94c4dbc4..12ca8a862e966f 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml @@ -200,14 +200,14 @@ parameters: en_US: If enabled, it will attempt to train on the metadata of that database zh_Hans: 是否自动从数据库获取元数据来训练 form: form - - name: generate_chart + - name: allow_llm_to_see_data type: boolean required: false - default: True + default: false label: - en_US: Generate Charts - zh_Hans: 生成图表 + en_US: Whether to allow the LLM to see the data + zh_Hans: 是否允许LLM查看数据 human_description: - en_US: Generate Charts - zh_Hans: 是否生成图表 + en_US: Whether to allow the LLM to see the data + zh_Hans: 是否允许LLM查看数据 form: form diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py index ab1fd71df5e191..1d71414bf33252 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.py +++ b/api/core/tools/provider/builtin/vanna/vanna.py @@ -1,4 +1,6 @@ +import re from typing import Any +from urllib.parse import urlparse from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vanna.tools.vanna import VannaTool @@ -6,20 +8,39 @@ class VannaProvider(BuiltinToolProviderController): + def _get_protocol_and_main_domain(self, url): + parsed_url = urlparse(url) + protocol = parsed_url.scheme + hostname = parsed_url.hostname + port = f":{parsed_url.port}" if parsed_url.port else "" + + # Check if the hostname is an IP address + is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None + + # Return the full hostname (with port if present) for IP addresses, otherwise return the main domain + main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port + return f"{protocol}://{main_domain}" + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + base_url = credentials.get("base_url") + if not base_url: + base_url = "https://ask.vanna.ai/rpc" + else: + base_url = base_url.removesuffix("/") + credentials["base_url"] = base_url try: VannaTool().fork_tool_runtime( runtime={ "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "model": "chinook", "db_type": "SQLite", - "url": "https://vanna.ai/Chinook.sqlite", - "query": "What are the top 10 customers by sales?" 
+ "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite', + "query": "What are the top 10 customers by sales?", }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/vanna/vanna.yaml b/api/core/tools/provider/builtin/vanna/vanna.yaml index b29fa103e1d8c9..cf3fdca562c0b3 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.yaml +++ b/api/core/tools/provider/builtin/vanna/vanna.yaml @@ -8,6 +8,9 @@ identity: en_US: The fastest way to get actionable insights from your database just by asking questions. zh_Hans: 一个基于大模型和RAG的Text2SQL工具。 icon: icon.png + tags: + - utilities + - productivity credentials_for_provider: api_key: type: secret-input @@ -23,3 +26,10 @@ credentials_for_provider: en_US: Get your API key from Vanna.AI zh_Hans: 从 Vanna.AI 获取你的 API key url: https://vanna.ai/account/profile + base_url: + type: text-input + required: false + label: + en_US: Vanna.AI Endpoint Base URL + placeholder: + en_US: https://ask.vanna.ai/rpc diff --git a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py deleted file mode 100644 index 1506ac0c9ded93..00000000000000 --- a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py +++ /dev/null @@ -1 +0,0 @@ -VECTORIZER_ICON_PNG = 'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj
4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC' \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index c6ec1980342d75..c722cd36c84e15 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -1,74 +1,82 @@ -from base64 import b64decode from typing import Any, Union from httpx import post +from core.file.enums import FileType +from core.file.file_manager import download +from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG +from core.tools.errors import ToolParameterValidationError from core.tools.tool.builtin_tool import BuiltinTool class VectorizerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - api_key_name = self.runtime.credentials.get('api_key_name', None) - api_key_value = self.runtime.credentials.get('api_key_value', None) - mode = tool_parameters.get('mode', 'test') - if mode == 'production': - mode = 'preview' + api_key_name = self.runtime.credentials.get("api_key_name") + api_key_value = self.runtime.credentials.get("api_key_value") + mode = tool_parameters.get("mode", "test") - if not api_key_name or not api_key_value: - raise ToolProviderCredentialValidationError('Please input api key name and value') + # image file for workflow mode + image = tool_parameters.get("image") + if image and image.type != FileType.IMAGE: + raise ToolParameterValidationError("Not a valid image") + # image_id for agent mode + image_id = tool_parameters.get("image_id", "") - image_id = tool_parameters.get('image_id', '') - if not image_id: - return self.create_text_message('Please input image id') - - if image_id.startswith('__test_'): - image_binary = b64decode(VECTORIZER_ICON_PNG) - else: - image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) + if image_id: + image_binary = self.get_variable_file(self.VariableKey.IMAGE) if not image_binary: - return self.create_text_message('Image not found, please request user to generate image firstly.') + return self.create_text_message("Image not found, please request user to generate image firstly.") + elif image: + image_binary = download(image) + else: + raise ToolParameterValidationError("Please provide either image or image_id") response = post( - 'https://vectorizer.ai/api/v1/vectorize', - files={ - 'image': image_binary - }, - data={ - 'mode': mode - } if mode == 'test' else {}, - auth=(api_key_name, api_key_value), - timeout=30 + "https://vectorizer.ai/api/v1/vectorize", + data={"mode": mode}, + files={"image": image_binary}, + auth=(api_key_name, api_key_value), + timeout=30, ) if 
response.status_code != 200: raise Exception(response.text) - + return [ - self.create_text_message('the vectorized svg is saved as an image.'), - self.create_blob_message(blob=response.content, - meta={'mime_type': 'image/svg+xml'}) + self.create_text_message("the vectorized svg is saved as an image."), + self.create_blob_message(blob=response.content, meta={"mime_type": "image/svg+xml"}), ] - + def get_runtime_parameters(self) -> list[ToolParameter]: """ override the runtime parameters """ return [ ToolParameter.get_simple_instance( - name='image_id', - llm_description=f'the image id that you want to vectorize, \ - and the image id should be specified in \ - {[i.name for i in self.list_default_image_variables()]}', + name="image_id", + llm_description=f"the image_id that you want to vectorize, \ + and the image_id should be specified in \ + {[i.name for i in self.list_default_image_variables()]}", type=ToolParameter.ToolParameterType.SELECT, - required=True, - options=[i.name for i in self.list_default_image_variables()] - ) + required=False, + options=[i.name for i in self.list_default_image_variables()], + ), + ToolParameter( + name="image", + label=I18nObject(en_US="image", zh_Hans="image"), + human_description=I18nObject( + en_US="The image to be converted.", + zh_Hans="要转换的图片。", + ), + type=ToolParameter.ToolParameterType.FILE, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="you should not input this parameter. just input the image_id.", + required=False, + ), ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml index 4b4fb9e2452c3c..0afd1c201f9126 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.yaml @@ -4,14 +4,21 @@ identity: label: en_US: Vectorizer.AI zh_Hans: Vectorizer.AI - pt_BR: Vectorizer.AI description: human: en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 - pt_BR: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. llm: A tool for converting images to SVG vectors. you should input the image id as the input of this tool. the image id can be got from parameters. parameters: + - name: image + type: file + label: + en_US: image + human_description: + en_US: The image to be converted. + zh_Hans: 要转换的图片。 + llm_description: you should not input this parameter. just input the image_id. + form: llm - name: mode type: select required: true @@ -20,19 +27,15 @@ parameters: label: en_US: production zh_Hans: 生产模式 - pt_BR: production - value: test label: en_US: test zh_Hans: 测试模式 - pt_BR: test default: test label: en_US: Mode zh_Hans: 模式 - pt_BR: Mode human_description: en_US: It is free to integrate with and test out the API in test mode, no subscription required. zh_Hans: 在测试模式下,可以免费测试API。 - pt_BR: It is free to integrate with and test out the API in test mode, no subscription required. 
form: form diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 3f89a83500da9f..211ec78f4d6a58 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -1,24 +1,32 @@ from typing import Any +from core.file import FileTransferMethod, FileType from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from factories import file_factory class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: + mapping = { + "transfer_method": FileTransferMethod.TOOL_FILE, + "type": FileType.IMAGE, + "id": "test_id", + "url": "https://cloud.dify.ai/logo/logo-site.png", + } + test_img = file_factory.build_from_mapping( + mapping=mapping, + tenant_id="__test_123", + ) try: VectorizerTool().fork_tool_runtime( runtime={ "credentials": credentials, } ).invoke( - user_id='', - tool_parameters={ - "mode": "test", - "image_id": "__test_123" - }, + user_id="", + tool_parameters={"mode": "test", "image": test_img}, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml index 1257f8d285c986..94dae2087609d4 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.yaml @@ -4,11 +4,9 @@ identity: label: en_US: Vectorizer.AI zh_Hans: Vectorizer.AI - pt_BR: Vectorizer.AI description: en_US: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. zh_Hans: 一个将 PNG 和 JPG 图像快速轻松地转换为 SVG 矢量图的工具。 - pt_BR: Convert your PNG and JPG images to SVG vectors quickly and easily. Fully automatically. Using AI. icon: icon.png tags: - productivity @@ -20,15 +18,12 @@ credentials_for_provider: label: en_US: Vectorizer.AI API Key name zh_Hans: Vectorizer.AI API Key name - pt_BR: Vectorizer.AI API Key name placeholder: en_US: Please input your Vectorizer.AI ApiKey name zh_Hans: 请输入你的 Vectorizer.AI ApiKey name - pt_BR: Please input your Vectorizer.AI ApiKey name help: en_US: Get your Vectorizer.AI API Key from Vectorizer.AI. zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。 - pt_BR: Get your Vectorizer.AI API Key from Vectorizer.AI. url: https://vectorizer.ai/api api_key_value: type: secret-input @@ -36,12 +31,9 @@ credentials_for_provider: label: en_US: Vectorizer.AI API Key zh_Hans: Vectorizer.AI API Key - pt_BR: Vectorizer.AI API Key placeholder: en_US: Please input your Vectorizer.AI ApiKey zh_Hans: 请输入你的 Vectorizer.AI ApiKey - pt_BR: Please input your Vectorizer.AI ApiKey help: en_US: Get your Vectorizer.AI API Key from Vectorizer.AI. zh_Hans: 从 Vectorizer.AI 获取您的 Vectorizer.AI API Key。 - pt_BR: Get your Vectorizer.AI API Key from Vectorizer.AI. 
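Taken together, the vectorizer changes above rework how the tool obtains its input image: workflow apps now pass a File object through the new image parameter, agent apps keep passing an image_id that names a default image variable, and the provider validates credentials with a synthetic file built by file_factory.build_from_mapping instead of the old bundled test PNG. A minimal sketch of the resolution order implemented in VectorizerTool._invoke (resolve_image_binary is a hypothetical helper name, and the real code returns a text message rather than raising when the named variable comes back empty):

    from typing import Any

    from core.file.enums import FileType
    from core.file.file_manager import download
    from core.tools.errors import ToolParameterValidationError


    def resolve_image_binary(tool, tool_parameters: dict[str, Any]) -> bytes:
        # Workflow mode: a File object arrives in "image"; reject non-images early.
        image = tool_parameters.get("image")
        if image and image.type != FileType.IMAGE:
            raise ToolParameterValidationError("Not a valid image")
        # Agent mode: an image_id names a previously generated image variable,
        # and it takes precedence over a directly supplied file.
        if tool_parameters.get("image_id", ""):
            return tool.get_variable_file(tool.VariableKey.IMAGE)
        if image:
            return download(image)
        raise ToolParameterValidationError("Please provide either image or image_id")

Note that image_id wins when both inputs are present, which keeps the existing agent flow unchanged while the file-based path acts as the fallback.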
diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py index 3d098e6768a8e7..12670b4b8b9289 100644 --- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py @@ -6,23 +6,24 @@ class WebscraperTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ try: - url = tool_parameters.get('url', '') - user_agent = tool_parameters.get('user_agent', '') + url = tool_parameters.get("url", "") + user_agent = tool_parameters.get("user_agent", "") if not url: - return self.create_text_message('Please input url') + return self.create_text_message("Please input url") # get webpage result = self.get_url(url, user_agent=user_agent) - if tool_parameters.get('generate_summary'): + if tool_parameters.get("generate_summary"): # summarize and return return self.create_text_message(self.summary(user_id=user_id, content=result)) else: diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py index 1e60fdb2939d93..3c51393ac64cc4 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -13,12 +13,11 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - 'url': 'https://www.google.com', - 'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + "url": "https://www.google.com", + "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/websearch/tools/job_search.py b/api/core/tools/provider/builtin/websearch/tools/job_search.py index 91283059229645..293f4f63297120 100644 --- a/api/core/tools/provider/builtin/websearch/tools/job_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/job_search.py @@ -50,14 +50,16 @@ def parse_results(res: dict) -> str: for job in jobs[:10]: try: string.append( - "\n".join([ - f"Position: {job['position']}", - f"Employer: {job['employer']}", - f"Location: {job['location']}", - f"Link: {job['link']}", - f"""Highest: {", ".join(list(job["highlights"]))}""", - "---", - ]) + "\n".join( + [ + f"Position: {job['position']}", + f"Employer: {job['employer']}", + f"Location: {job['location']}", + f"Link: {job['link']}", + f"""Highlights: {", ".join(list(job["highlights"]))}""", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/news_search.py b/api/core/tools/provider/builtin/websearch/tools/news_search.py index e9c0744f054aa3..9b5482fe183e18 100644 --- a/api/core/tools/provider/builtin/websearch/tools/news_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/news_search.py @@ -53,13 +53,15 @@ def parse_results(res: dict) -> str: r = requests.get(entry["link"]) final_link = r.history[-1].headers["Location"] string.append( - "\n".join([ - f"Title: {entry['title']}", - f"Link: {final_link}", - f"Source: {entry['source']['title']}", - f"Published:
{entry['published']}", - "---", - ]) + "\n".join( + [ + f"Title: {entry['title']}", + f"Link: {final_link}", + f"Source: {entry['source']['title']}", + f"Published: {entry['published']}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py index 0030a03c06a5d8..798d059b512edf 100644 --- a/api/core/tools/provider/builtin/websearch/tools/scholar_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/scholar_search.py @@ -55,14 +55,16 @@ def parse_results(res: dict) -> str: link = article["link"] authors = [author["name"] for author in article["author"]["authors"]] string.append( - "\n".join([ - f"Title: {article['title']}", - f"Link: {link}", - f"Description: {article['description']}", - f"Cite: {article['cite']}", - f"Authors: {', '.join(authors)}", - "---", - ]) + "\n".join( + [ + f"Title: {article['title']}", + f"Link: {link}", + f"Description: {article['description']}", + f"Cite: {article['cite']}", + f"Authors: {', '.join(authors)}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/websearch/tools/web_search.py b/api/core/tools/provider/builtin/websearch/tools/web_search.py index 4f57c27caf5257..fe363ac7a4d5d0 100644 --- a/api/core/tools/provider/builtin/websearch/tools/web_search.py +++ b/api/core/tools/provider/builtin/websearch/tools/web_search.py @@ -49,12 +49,14 @@ def parse_results(res: dict) -> str: for result in results: try: string.append( - "\n".join([ - f"Title: {result['title']}", - f"Link: {result['link']}", - f"Description: {result['description'].strip()}", - "---", - ]) + "\n".join( + [ + f"Title: {result['title']}", + f"Link: {result['link']}", + f"Description: {result['description'].strip()}", + "---", + ] + ) ) except KeyError: continue diff --git a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py index fb44b70f4ec6b6..545d9f4f8d6335 100644 --- a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py +++ b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py @@ -8,41 +8,41 @@ class WecomGroupBotTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any] - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - content = tool_parameters.get('content', '') + content = tool_parameters.get("content", "") if not content: - return self.create_text_message('Invalid parameter content') + return self.create_text_message("Invalid parameter content") - hook_key = tool_parameters.get('hook_key', '') + hook_key = tool_parameters.get("hook_key", "") if not is_valid_uuid(hook_key): - return self.create_text_message( - f'Invalid parameter hook_key ${hook_key}, not a valid UUID') + return self.create_text_message(f"Invalid parameter hook_key ${hook_key}, not a valid UUID") - message_type = tool_parameters.get('message_type', 'text') - if message_type == 'markdown': + message_type = tool_parameters.get("message_type", "text") + if message_type == "markdown": payload = { - "msgtype": 'markdown', + "msgtype": "markdown", "markdown": { "content": content, - } + }, } else: payload = { - "msgtype": 'text', + "msgtype": "text", "text": { "content": content, - } + }, } - api_url = 
'https://qyapi.weixin.qq.com/cgi-bin/webhook/send' + api_url = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send" headers = { - 'Content-Type': 'application/json', + "Content-Type": "application/json", } params = { - 'key': hook_key, + "key": hook_key, } try: @@ -51,6 +51,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any] return self.create_text_message("Text message sent successfully") else: return self.create_text_message( - f"Failed to send the text message, status code: {res.status_code}, response: {res.text}") + f"Failed to send the text message, status code: {res.status_code}, response: {res.text}" + ) except Exception as e: return self.create_text_message("Failed to send message to group chat bot. {}".format(e)) diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 0796cd2392a68b..cb88e9519a4346 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -27,7 +27,7 @@ def __init__(self, doc_content_chars_max: int = 4000): self.doc_content_chars_max = doc_content_chars_max def run(self, query: str, lang: str = "") -> str: - if lang in wikipedia.languages().keys(): + if lang in wikipedia.languages(): self.lang = lang wikipedia.set_lang(self.lang) @@ -83,7 +83,6 @@ def _run( class WikiPediaSearchTool(BuiltinTool): - def _invoke( self, user_id: str, diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py index f8038714a5f524..178bf7b0ceb2e9 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "misaka mikoto", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py index 8cb9c10ddf499a..9dc5bed824d715 100644 --- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -8,29 +8,24 @@ class WolframAlphaTool(BuiltinTool): - _base_url = 'https://api.wolframalpha.com/v2/query' + _base_url = "https://api.wolframalpha.com/v2/query" - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], - ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('query', '') + query = tool_parameters.get("query", "") if not query: - return self.create_text_message('Please input query') - appid = self.runtime.credentials.get('appid', '') + return self.create_text_message("Please input query") + appid = self.runtime.credentials.get("appid", "") if not appid: - raise ToolProviderCredentialValidationError('Please input appid') - - params = { - 'appid': appid, - 'input': query, - 'includepodid': 'Result', - 'format': 'plaintext', - 'output': 'json' - } + raise ToolProviderCredentialValidationError("Please input appid") + + params = {"appid": appid, "input": query, "includepodid": "Result", 
"format": "plaintext", "output": "json"} finished = False result = None @@ -45,34 +40,33 @@ def _invoke(self, response_data = response.json() except Exception as e: raise ToolInvokeError(str(e)) - - if 'success' not in response_data['queryresult'] or response_data['queryresult']['success'] != True: - query_result = response_data.get('queryresult', {}) - if query_result.get('error'): - if 'msg' in query_result['error']: - if query_result['error']['msg'] == 'Invalid appid': - raise ToolProviderCredentialValidationError('Invalid appid') - raise ToolInvokeError('Failed to invoke tool') - - if 'didyoumeans' in response_data['queryresult']: + + if "success" not in response_data["queryresult"] or response_data["queryresult"]["success"] != True: + query_result = response_data.get("queryresult", {}) + if query_result.get("error"): + if "msg" in query_result["error"]: + if query_result["error"]["msg"] == "Invalid appid": + raise ToolProviderCredentialValidationError("Invalid appid") + raise ToolInvokeError("Failed to invoke tool") + + if "didyoumeans" in response_data["queryresult"]: # get the most likely interpretation - query = '' + query = "" max_score = 0 - for didyoumean in response_data['queryresult']['didyoumeans']: - if float(didyoumean['score']) > max_score: - query = didyoumean['val'] - max_score = float(didyoumean['score']) + for didyoumean in response_data["queryresult"]["didyoumeans"]: + if float(didyoumean["score"]) > max_score: + query = didyoumean["val"] + max_score = float(didyoumean["score"]) - params['input'] = query + params["input"] = query else: finished = True - if 'souces' in response_data['queryresult']: - return self.create_link_message(response_data['queryresult']['sources']['url']) - elif 'pods' in response_data['queryresult']: - result = response_data['queryresult']['pods'][0]['subpods'][0]['plaintext'] + if "souces" in response_data["queryresult"]: + return self.create_link_message(response_data["queryresult"]["sources"]["url"]) + elif "pods" in response_data["queryresult"]: + result = response_data["queryresult"]["pods"][0]["subpods"][0]["plaintext"] if not finished or not result: - return self.create_text_message('No result found') + return self.create_text_message("No result found") return self.create_text_message(result) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py index ef1aac7ff272c5..7be288b5387f34 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -13,11 +13,10 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "query": "1+2+....+111", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/xinference/_assets/icon.png b/api/core/tools/provider/builtin/xinference/_assets/icon.png new file mode 100644 index 00000000000000..e58cacbd123b58 Binary files /dev/null and b/api/core/tools/provider/builtin/xinference/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py new file mode 100644 index 00000000000000..a44d3b730a84f9 --- /dev/null +++ b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.py @@ -0,0 +1,415 @@ 
+import io +import json +from base64 import b64decode, b64encode +from copy import deepcopy +from typing import Any, Union + +from httpx import get, post +from PIL import Image +from yarl import URL + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolInvokeMessage, + ToolParameter, + ToolParameterOption, +) +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + +# All commented out parameters default to null +DRAW_TEXT_OPTIONS = { + # Prompts + "prompt": "", + "negative_prompt": "", + # "styles": [], + # Seeds + "seed": -1, + "subseed": -1, + "subseed_strength": 0, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + # Samplers + "sampler_name": "DPM++ 2M", + # "scheduler": "", + # "sampler_index": "Automatic", + # Latent Space Options + "batch_size": 1, + "n_iter": 1, + "steps": 10, + "cfg_scale": 7, + "width": 512, + "height": 512, + # "restore_faces": True, + # "tiling": True, + "do_not_save_samples": False, + "do_not_save_grid": False, + # "eta": 0, + # "denoising_strength": 0.75, + # "s_min_uncond": 0, + # "s_churn": 0, + # "s_tmax": 0, + # "s_tmin": 0, + # "s_noise": 0, + "override_settings": {}, + "override_settings_restore_afterwards": True, + # Refinement Options + "refiner_checkpoint": "", + "refiner_switch_at": 0, + "disable_extra_networks": False, + # "firstpass_image": "", + # "comments": "", + # High-Resolution Options + "enable_hr": False, + "firstphase_width": 0, + "firstphase_height": 0, + "hr_scale": 2, + # "hr_upscaler": "", + "hr_second_pass_steps": 0, + "hr_resize_x": 0, + "hr_resize_y": 0, + # "hr_checkpoint_name": "", + # "hr_sampler_name": "", + # "hr_scheduler": "", + "hr_prompt": "", + "hr_negative_prompt": "", + # Task Options + # "force_task_id": "", + # Script Options + # "script_name": "", + "script_args": [], + # Output Options + "send_images": True, + "save_images": False, + "alwayson_scripts": {}, + # "infotext": "", +} + + +class StableDiffusionTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + # base url + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return self.create_text_message("Please input base_url") + + if tool_parameters.get("model"): + self.runtime.credentials["model"] = tool_parameters["model"] + + model = self.runtime.credentials.get("model", None) + if not model: + return self.create_text_message("Please input model") + api_key = self.runtime.credentials.get("api_key") or "abc" + headers = {"Authorization": f"Bearer {api_key}"} + # set model + try: + url = str(URL(base_url) / "sdapi" / "v1" / "options") + response = post( + url, + json={"sd_model_checkpoint": model}, + headers=headers, + ) + if response.status_code != 200: + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") + except Exception as e: + raise ToolProviderCredentialValidationError("Failed to set model, please tell user to set model") + + # get image id and image variable + image_id = tool_parameters.get("image_id", "") + image_variable = self.get_default_image_variable() + # Return text2img if there's no image ID or no image variable + if not image_id or not image_variable: + return self.text2img(base_url=base_url, tool_parameters=tool_parameters) + + # Proceed with image-to-image generation + return self.img2img(base_url=base_url, 
tool_parameters=tool_parameters) + + def validate_models(self): + """ + validate models + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + raise ToolProviderCredentialValidationError("Please input base_url") + model = self.runtime.credentials.get("model", None) + if not model: + raise ToolProviderCredentialValidationError("Please input model") + + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") + response = get(url=api_url, timeout=10) + if response.status_code == 404: + # try draw a picture + self._invoke( + user_id="test", + tool_parameters={ + "prompt": "a cat", + "width": 1024, + "height": 1024, + "steps": 1, + "lora": "", + }, + ) + elif response.status_code != 200: + raise ToolProviderCredentialValidationError("Failed to get models") + else: + models = [d["model_name"] for d in response.json()] + if len([d for d in models if d == model]) > 0: + return self.create_text_message(json.dumps(models)) + else: + raise ToolProviderCredentialValidationError(f"model {model} does not exist") + except Exception as e: + raise ToolProviderCredentialValidationError(f"Failed to get models, {e}") + + def get_sd_models(self) -> list[str]: + """ + get sd models + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "sdapi" / "v1" / "sd-models") + response = get(url=api_url, timeout=120) + if response.status_code != 200: + return [] + else: + return [d["model_name"] for d in response.json()] + except Exception as e: + return [] + + def get_sample_methods(self) -> list[str]: + """ + get sample method + """ + try: + base_url = self.runtime.credentials.get("base_url", None) + if not base_url: + return [] + api_url = str(URL(base_url) / "sdapi" / "v1" / "samplers") + response = get(url=api_url, timeout=120) + if response.status_code != 200: + return [] + else: + return [d["name"] for d in response.json()] + except Exception as e: + return [] + + def img2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + generate image + """ + + # Fetch the binary data of the image + image_variable = self.get_default_image_variable() + image_binary = self.get_variable_file(image_variable.name) + if not image_binary: + return self.create_text_message("Image not found, please request user to generate image firstly.") + + # Convert image to RGB and save as PNG + try: + with Image.open(io.BytesIO(image_binary)) as image, io.BytesIO() as buffer: + image.convert("RGB").save(buffer, format="PNG") + image_binary = buffer.getvalue() + except Exception as e: + return self.create_text_message(f"Failed to process the image: {str(e)}") + + # copy draw options + draw_options = deepcopy(DRAW_TEXT_OPTIONS) + # set image options + model = tool_parameters.get("model", "") + draw_options_image = { + "init_images": [b64encode(image_binary).decode("utf-8")], + "denoising_strength": 0.9, + "restore_faces": False, + "script_args": [], + "override_settings": {"sd_model_checkpoint": model}, + "resize_mode": 0, + "image_cfg_scale": 0, + # "mask": None, + "mask_blur_x": 4, + "mask_blur_y": 4, + "mask_blur": 0, + "mask_round": True, + "inpainting_fill": 0, + "inpaint_full_res": True, + "inpaint_full_res_padding": 0, + "inpainting_mask_invert": 0, + "initial_noise_multiplier": 0, + # "latent_mask": None, + "include_init_images": True, + } + # update key and values + draw_options.update(draw_options_image) + draw_options.update(tool_parameters) + 
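+        # Merge precedence of the two update() calls above (a descriptive note, not
+        # behavior from the original patch): the img2img-specific options override
+        # the DRAW_TEXT_OPTIONS defaults, and the caller's tool_parameters override
+        # both, so a user-supplied denoising_strength replaces the 0.9 default set
+        # in draw_options_image.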
+ # get prompt lora model + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") + if lora: + draw_options["prompt"] = f"{lora},{prompt}" + else: + draw_options["prompt"] = prompt + api_key = self.runtime.credentials.get("api_key") or "abc" + headers = {"Authorization": f"Bearer {api_key}"} + try: + url = str(URL(base_url) / "sdapi" / "v1" / "img2img") + response = post( + url, + json=draw_options, + timeout=120, + headers=headers, + ) + if response.status_code != 200: + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ) + + except Exception as e: + return self.create_text_message("Failed to generate image") + + def text2img( + self, base_url: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + generate image + """ + # copy draw options + draw_options = deepcopy(DRAW_TEXT_OPTIONS) + draw_options.update(tool_parameters) + # get prompt lora model + prompt = tool_parameters.get("prompt", "") + lora = tool_parameters.get("lora", "") + model = tool_parameters.get("model", "") + if lora: + draw_options["prompt"] = f"{lora},{prompt}" + else: + draw_options["prompt"] = prompt + draw_options["override_settings"]["sd_model_checkpoint"] = model + api_key = self.runtime.credentials.get("api_key") or "abc" + headers = {"Authorization": f"Bearer {api_key}"} + try: + url = str(URL(base_url) / "sdapi" / "v1" / "txt2img") + response = post( + url, + json=draw_options, + timeout=120, + headers=headers, + ) + if response.status_code != 200: + return self.create_text_message("Failed to generate image") + + image = response.json()["images"][0] + + return self.create_blob_message( + blob=b64decode(image), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ) + + except Exception as e: + return self.create_text_message("Failed to generate image") + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="Prompt"), + human_description=I18nObject( + en_US="Image prompt, you can check the official documentation of Stable Diffusion", + zh_Hans="图像提示词,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image prompt of Stable Diffusion, you should describe the image you want to generate" + " as a list of words, as detailed as possible; the prompt must be written in English.", + required=True, + ), + ] + if len(self.list_default_image_variables()) != 0: + parameters.append( + ToolParameter( + name="image_id", + label=I18nObject(en_US="image_id", zh_Hans="image_id"), + human_description=I18nObject( + en_US="Image id of the image you want to generate based on, if you want to generate image based" + " on the default image, you can leave this field empty.", + zh_Hans="您想要生成的图像的图像 ID,如果您想要基于默认图像生成图像,则可以将此字段留空。", + ), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Image id of the original image, you can leave this field empty if you want to" + " generate a new image.", + required=True, + options=[ + ToolParameterOption(value=i.name, label=I18nObject(en_US=i.name, zh_Hans=i.name)) + for i in self.list_default_image_variables() + ], + ) + ) + +
if self.runtime.credentials: + try: + models = self.get_sd_models() + if len(models) != 0: + parameters.append( + ToolParameter( + name="model", + label=I18nObject(en_US="Model", zh_Hans="Model"), + human_description=I18nObject( + en_US="Model of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + zh_Hans="Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Model of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + required=True, + default=models[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in models + ], + ) + ) + + except Exception: + pass + + sample_methods = self.get_sample_methods() + if len(sample_methods) != 0: + parameters.append( + ToolParameter( + name="sampler_name", + label=I18nObject(en_US="Sampling method", zh_Hans="Sampling method"), + human_description=I18nObject( + en_US="Sampling method of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + zh_Hans="Stable Diffusion 的Sampling method,您可以查看 Stable Diffusion 的官方文档", + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description="Sampling method of Stable Diffusion, you can check the official documentation" + " of Stable Diffusion", + required=True, + default=sample_methods[0], + options=[ + ToolParameterOption(value=i, label=I18nObject(en_US=i, zh_Hans=i)) for i in sample_methods + ], + ) + ) + return parameters diff --git a/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.yaml b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.yaml new file mode 100644 index 00000000000000..4f1d17f175c567 --- /dev/null +++ b/api/core/tools/provider/builtin/xinference/tools/stable_diffusion.yaml @@ -0,0 +1,87 @@ +identity: + name: stable_diffusion + author: xinference + label: + en_US: Stable Diffusion + zh_Hans: Stable Diffusion +description: + human: + en_US: Generate images using Stable Diffusion models. + zh_Hans: 使用 Stable Diffusion 模型生成图片。 + llm: draw the image you want based on your prompt. +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Image prompt + zh_Hans: 图像提示词 + llm_description: Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words, as detailed as possible; the prompt must be written in English.
+ form: llm + - name: model + type: string + required: false + label: + en_US: Model Name + zh_Hans: 模型名称 + human_description: + en_US: Model Name + zh_Hans: 模型名称 + form: form + - name: lora + type: string + required: false + label: + en_US: Lora + zh_Hans: Lora + human_description: + en_US: Lora + zh_Hans: Lora + form: form + - name: steps + type: number + required: false + label: + en_US: Steps + zh_Hans: Steps + human_description: + en_US: Steps + zh_Hans: Steps + form: form + default: 10 + - name: width + type: number + required: false + label: + en_US: Width + zh_Hans: Width + human_description: + en_US: Width + zh_Hans: Width + form: form + default: 1024 + - name: height + type: number + required: false + label: + en_US: Height + zh_Hans: Height + human_description: + en_US: Height + zh_Hans: Height + form: form + default: 1024 + - name: negative_prompt + type: string + required: false + label: + en_US: Negative prompt + zh_Hans: Negative prompt + human_description: + en_US: Negative prompt + zh_Hans: Negative prompt + form: form + default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines diff --git a/api/core/tools/provider/builtin/xinference/xinference.py b/api/core/tools/provider/builtin/xinference/xinference.py new file mode 100644 index 00000000000000..9692e4060e8a87 --- /dev/null +++ b/api/core/tools/provider/builtin/xinference/xinference.py @@ -0,0 +1,24 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class XinferenceProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + base_url = credentials.get("base_url", "").removesuffix("/") + api_key = credentials.get("api_key", "") + if not api_key: + api_key = "abc" + credentials["api_key"] = api_key + model = credentials.get("model", "") + if not base_url or not model: + raise ToolProviderCredentialValidationError("Xinference base_url and model are required") + headers = {"Authorization": f"Bearer {api_key}"} + res = requests.post( + f"{base_url}/sdapi/v1/options", + headers=headers, + json={"sd_model_checkpoint": model}, + ) + if res.status_code != 200: + raise ToolProviderCredentialValidationError("Xinference API key is invalid") diff --git a/api/core/tools/provider/builtin/xinference/xinference.yaml b/api/core/tools/provider/builtin/xinference/xinference.yaml new file mode 100644 index 00000000000000..b0c02b9cbcb01a --- /dev/null +++ b/api/core/tools/provider/builtin/xinference/xinference.yaml @@ -0,0 +1,40 @@ +identity: + author: xinference + name: xinference + label: + en_US: Xinference + zh_Hans: Xinference + description: + zh_Hans: Xinference 提供的兼容 Stable Diffusion web ui 的图片生成 API。 + en_US: Stable Diffusion web ui compatible API provided by Xinference.
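+  # Both the provider check in xinference.py and the stable_diffusion tool call
+  # Stable Diffusion web ui style endpoints (/sdapi/v1/options, sd-models,
+  # samplers, txt2img, img2img), so the configured base_url is assumed to expose
+  # that API surface.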
+ icon: icon.png + tags: + - image +credentials_for_provider: + base_url: + type: secret-input + required: true + label: + en_US: Base URL + zh_Hans: Xinference 服务器的 Base URL + placeholder: + en_US: Please input Xinference server's Base URL + zh_Hans: 请输入 Xinference 服务器的 Base URL + model: + type: text-input + required: true + label: + en_US: Model + zh_Hans: 模型 + placeholder: + en_US: Please input your model name + zh_Hans: 请输入你的模型名称 + api_key: + type: secret-input + required: false + label: + en_US: API Key + zh_Hans: Xinference 服务器的 API Key + placeholder: + en_US: Please input Xinference server's API Key + zh_Hans: 请输入 Xinference 服务器的 API Key diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index cf511ea8940082..f044fbe5404b0a 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -10,27 +10,28 @@ class YahooFinanceAnalyticsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - symbol = tool_parameters.get('symbol', '') + symbol = tool_parameters.get("symbol", "") if not symbol: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + time_range = [None, None] - start_date = tool_parameters.get('start_date', '') + start_date = tool_parameters.get("start_date", "") if start_date: time_range[0] = start_date else: - time_range[0] = '1800-01-01' + time_range[0] = "1800-01-01" - end_date = tool_parameters.get('end_date', '') + end_date = tool_parameters.get("end_date", "") if end_date: time_range[1] = end_date else: - time_range[1] = datetime.now().strftime('%Y-%m-%d') + time_range[1] = datetime.now().strftime("%Y-%m-%d") stock_data = download(symbol, start=time_range[0], end=time_range[1]) max_segments = min(15, len(stock_data)) @@ -41,30 +42,29 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ end_idx = (i + 1) * rows_per_segment if i < max_segments - 1 else len(stock_data) segment_data = stock_data.iloc[start_idx:end_idx] segment_summary = { - 'Start Date': segment_data.index[0], - 'End Date': segment_data.index[-1], - 'Average Close': segment_data['Close'].mean(), - 'Average Volume': segment_data['Volume'].mean(), - 'Average Open': segment_data['Open'].mean(), - 'Average High': segment_data['High'].mean(), - 'Average Low': segment_data['Low'].mean(), - 'Average Adj Close': segment_data['Adj Close'].mean(), - 'Max Close': segment_data['Close'].max(), - 'Min Close': segment_data['Close'].min(), - 'Max Volume': segment_data['Volume'].max(), - 'Min Volume': segment_data['Volume'].min(), - 'Max Open': segment_data['Open'].max(), - 'Min Open': segment_data['Open'].min(), - 'Max High': segment_data['High'].max(), - 'Min High': segment_data['High'].min(), + "Start Date": segment_data.index[0], + "End Date": segment_data.index[-1], + "Average Close": segment_data["Close"].mean(), + "Average Volume": segment_data["Volume"].mean(), + "Average Open": segment_data["Open"].mean(), + "Average High": segment_data["High"].mean(), + "Average Low": segment_data["Low"].mean(), + "Average Adj Close": segment_data["Adj Close"].mean(), + "Max Close": segment_data["Close"].max(), + "Min Close": segment_data["Close"].min(), 
+ "Max Volume": segment_data["Volume"].max(), + "Min Volume": segment_data["Volume"].min(), + "Max Open": segment_data["Open"].max(), + "Min Open": segment_data["Open"].min(), + "Max High": segment_data["High"].max(), + "Min High": segment_data["High"].min(), } - + summary_data.append(segment_summary) summary_df = pd.DataFrame(summary_data) - + try: return self.create_text_message(str(summary_df.to_dict())) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') - \ No newline at end of file + return self.create_text_message("There is a internet connection problem. Please try again later.") diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index 4f2922ef3ec1de..ff820430f9f366 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -8,40 +8,39 @@ class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self,user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: - ''' - invoke tools - ''' - - query = tool_parameters.get('symbol', '') + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + + query = tool_parameters.get("symbol", "") if not query: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + try: return self.run(ticker=query, user_id=user_id) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') + return self.create_text_message("There is a internet connection problem. 
Please try again later.") def run(self, ticker: str, user_id: str) -> ToolInvokeMessage: company = yfinance.Ticker(ticker) try: if company.isin is None: - return self.create_text_message(f'Company ticker {ticker} not found.') + return self.create_text_message(f"Company ticker {ticker} not found.") except (HTTPError, ReadTimeout, ConnectionError): - return self.create_text_message(f'Company ticker {ticker} not found.') + return self.create_text_message(f"Company ticker {ticker} not found.") links = [] try: - links = [n['link'] for n in company.news if n['type'] == 'STORY'] + links = [n["link"] for n in company.news if n["type"] == "STORY"] except (HTTPError, ReadTimeout, ConnectionError): if not links: - return self.create_text_message(f'There is nothing about {ticker} ticker') + return self.create_text_message(f"There is nothing about {ticker} ticker") if not links: - return self.create_text_message(f'No news found for company that searched with {ticker} ticker.') - - result = '\n\n'.join([ - self.get_url(link) for link in links - ]) + return self.create_text_message(f"No news found for company that searched with {ticker} ticker.") + + result = "\n\n".join([self.get_url(link) for link in links]) return self.create_text_message(self.summary(user_id=user_id, content=result)) diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index 262fff3b25ba93..dfc7e460473c33 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -8,19 +8,20 @@ class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - query = tool_parameters.get('symbol', '') + query = tool_parameters.get("symbol", "") if not query: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + try: return self.create_text_message(self.run(ticker=query)) except (HTTPError, ReadTimeout): - return self.create_text_message('There is a internet connection problem. Please try again later.') - + return self.create_text_message("There is a internet connection problem. 
Please try again later.") + def run(self, ticker: str) -> str: - return str(Ticker(ticker).info) \ No newline at end of file + return str(Ticker(ticker).info) diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py index 96dbc6c3d0d8e9..8d82084e769703 100644 --- a/api/core/tools/provider/builtin/yahoo/yahoo.py +++ b/api/core/tools/provider/builtin/yahoo/yahoo.py @@ -11,11 +11,10 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ "ticker": "MSFT", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 7a9b9fce4a921f..95dec2eac9a752 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -8,60 +8,67 @@ class YoutubeVideosAnalyticsTool(BuiltinTool): - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ - -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke tools + invoke tools """ - channel = tool_parameters.get('channel', '') + channel = tool_parameters.get("channel", "") if not channel: - return self.create_text_message('Please input symbol') - + return self.create_text_message("Please input symbol") + time_range = [None, None] - start_date = tool_parameters.get('start_date', '') + start_date = tool_parameters.get("start_date", "") if start_date: time_range[0] = start_date else: - time_range[0] = '1800-01-01' + time_range[0] = "1800-01-01" - end_date = tool_parameters.get('end_date', '') + end_date = tool_parameters.get("end_date", "") if end_date: time_range[1] = end_date else: - time_range[1] = datetime.now().strftime('%Y-%m-%d') + time_range[1] = datetime.now().strftime("%Y-%m-%d") - if 'google_api_key' not in self.runtime.credentials or not self.runtime.credentials['google_api_key']: - return self.create_text_message('Please input api key') + if "google_api_key" not in self.runtime.credentials or not self.runtime.credentials["google_api_key"]: + return self.create_text_message("Please input api key") - youtube = build('youtube', 'v3', developerKey=self.runtime.credentials['google_api_key']) + youtube = build("youtube", "v3", developerKey=self.runtime.credentials["google_api_key"]) # try to get channel id - search_results = youtube.search().list(q=channel, type='channel', order='relevance', part='id').execute() - channel_id = search_results['items'][0]['id']['channelId'] + search_results = youtube.search().list(q=channel, type="channel", order="relevance", part="id").execute() + channel_id = search_results["items"][0]["id"]["channelId"] start_date, end_date = time_range - start_date = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') - end_date = datetime.strptime(end_date, '%Y-%m-%d').strftime('%Y-%m-%dT%H:%M:%SZ') + start_date = datetime.strptime(start_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") + end_date = datetime.strptime(end_date, "%Y-%m-%d").strftime("%Y-%m-%dT%H:%M:%SZ") # get videos - time_range_videos = youtube.search().list( - part='snippet', channelId=channel_id, order='date', type='video', - publishedAfter=start_date, - publishedBefore=end_date - ).execute() + time_range_videos = ( + 
youtube.search() + .list( + part="snippet", + channelId=channel_id, + order="date", + type="video", + publishedAfter=start_date, + publishedBefore=end_date, + ) + .execute() + ) def extract_video_data(video_list): data = [] - for video in video_list['items']: - video_id = video['id']['videoId'] - video_info = youtube.videos().list(part='snippet,statistics', id=video_id).execute() - title = video_info['items'][0]['snippet']['title'] - views = video_info['items'][0]['statistics']['viewCount'] - data.append({'Title': title, 'Views': views}) + for video in video_list["items"]: + video_id = video["id"]["videoId"] + video_info = youtube.videos().list(part="snippet,statistics", id=video_id).execute() + title = video_info["items"][0]["snippet"]["title"] + views = video_info["items"][0]["statistics"]["viewCount"] + data.append({"Title": title, "Views": views}) return data summary = extract_video_data(time_range_videos) - + return self.create_text_message(str(summary)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py index 83a4fccb3247d0..07e430bcbf27e1 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.py +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -11,13 +11,12 @@ def _validate_credentials(self, credentials: dict) -> None: "credentials": credentials, } ).invoke( - user_id='', + user_id="", tool_parameters={ - "channel": "TOKYO GIRLS COLLECTION", + "channel": "UC2JZCsZSOudXA08cMMRCL9g", "start_date": "2020-01-01", "end_date": "2024-12-31", }, ) except Exception as e: raise ToolProviderCredentialValidationError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index bcf41c90edbfcd..955a0add3b4513 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -13,43 +13,44 @@ from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool -from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.yaml_utils import load_yaml_file class BuiltinToolProviderController(ToolProviderController): def __init__(self, **data: Any) -> None: - if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: + if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}: super().__init__(**data) return - + # load provider yaml - provider = self.__class__.__module__.split('.')[-1] - yaml_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.yaml') + provider = self.__class__.__module__.split(".")[-1] + yaml_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, f"{provider}.yaml") try: provider_yaml = load_yaml_file(yaml_path, ignore_error=False) except Exception as e: - raise ToolProviderNotFoundError(f'can not load provider yaml for {provider}: {e}') + raise ToolProviderNotFoundError(f"can not load provider yaml for {provider}: {e}") - if 'credentials_for_provider' in provider_yaml and provider_yaml['credentials_for_provider'] is not None: + if "credentials_for_provider" in provider_yaml and provider_yaml["credentials_for_provider"] is not None: # set credentials name - for credential_name in provider_yaml['credentials_for_provider']: - provider_yaml['credentials_for_provider'][credential_name]['name'] = 
credential_name + for credential_name in provider_yaml["credentials_for_provider"]: + provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name - super().__init__(**{ - 'identity': provider_yaml['identity'], - 'credentials_schema': provider_yaml.get('credentials_for_provider', None), - }) + super().__init__( + **{ + "identity": provider_yaml["identity"], + "credentials_schema": provider_yaml.get("credentials_for_provider", None), + } + ) def _get_builtin_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ if self.tools: return self.tools - + provider = self.identity.name tool_path = path.join(path.dirname(path.realpath(__file__)), "builtin", provider, "tools") # get all the yaml files in the tool path @@ -62,155 +63,159 @@ def _get_builtin_tools(self) -> list[Tool]: # get tool class, import the module assistant_tool_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'builtin', provider, 'tools', f'{tool_name}.py'), - parent_type=BuiltinTool) + module_name=f"core.tools.provider.builtin.{provider}.tools.{tool_name}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "builtin", provider, "tools", f"{tool_name}.py" + ), + parent_type=BuiltinTool, + ) tool["identity"]["provider"] = provider tools.append(assistant_tool_class(**tool)) self.tools = tools return tools - + def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ if not self.credentials_schema: return {} - + return self.credentials_schema.copy() def get_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ return self._get_builtin_tools() - + def get_tool(self, tool_name: str) -> Tool: """ - returns the tool that the provider can provide + returns the tool that the provider can provide """ return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ - returns the parameters of the tool + returns the parameters of the tool - :param tool_name: the name of the tool, defined in `get_tools` - :return: list of parameters + :param tool_name: the name of the tool, defined in `get_tools` + :return: list of parameters """ tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) if tool is None: - raise ToolNotFoundError(f'tool {tool_name} not found') + raise ToolNotFoundError(f"tool {tool_name} not found") return tool.parameters @property def need_credentials(self) -> bool: """ - returns whether the provider needs credentials + returns whether the provider needs credentials - :return: whether the provider needs credentials + :return: whether the provider needs credentials """ - return self.credentials_schema is not None and \ - len(self.credentials_schema) != 0 + return self.credentials_schema is not None and len(self.credentials_schema) != 0 @property def provider_type(self) -> ToolProviderType: """ - returns the type of the provider + returns the type of the provider - :return: type of 
the provider + :return: type of the provider """ return ToolProviderType.BUILT_IN @property def tool_labels(self) -> list[str]: """ - returns the labels of the provider + returns the labels of the provider - :return: labels of the provider + :return: labels of the provider """ label_enums = self._get_tool_labels() return [default_tool_label_dict[label].name for label in label_enums] def _get_tool_labels(self) -> list[ToolLabelEnum]: """ - returns the labels of the provider + returns the labels of the provider """ return self.identity.tags or [] def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None: """ - validate the parameters of the tool and set the default value if needed + validate the parameters of the tool and set the default value if needed - :param tool_name: the name of the tool, defined in `get_tools` - :param tool_parameters: the parameters of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param tool_parameters: the parameters of the tool """ tool_parameters_schema = self.get_parameters(tool_name) - + tool_parameters_need_to_validate: dict[str, ToolParameter] = {} for parameter in tool_parameters_schema: tool_parameters_need_to_validate[parameter.name] = parameter for parameter in tool_parameters: if parameter not in tool_parameters_need_to_validate: - raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}') - + raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}") + # check type parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.type == ToolParameter.ToolParameterType.STRING: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: if not isinstance(tool_parameters[parameter], int | float): - raise ToolParameterValidationError(f'parameter {parameter} should be number') - + raise ToolParameterValidationError(f"parameter {parameter} should be number") + if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min: - raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be greater than {parameter_schema.min}" + ) + if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max: - raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}') - + raise ToolParameterValidationError( + f"parameter {parameter} should be less than {parameter_schema.max}" + ) + elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: if not isinstance(tool_parameters[parameter], bool): - raise ToolParameterValidationError(f'parameter {parameter} should be boolean') - + raise ToolParameterValidationError(f"parameter {parameter} should be boolean") + elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT: if not isinstance(tool_parameters[parameter], str): - raise ToolParameterValidationError(f'parameter {parameter} should be string') - + raise ToolParameterValidationError(f"parameter {parameter} should be string") + options = parameter_schema.options if not isinstance(options, list): - raise ToolParameterValidationError(f'parameter {parameter} options 
should be list') - + raise ToolParameterValidationError(f"parameter {parameter} options should be list") + if tool_parameters[parameter] not in [x.value for x in options]: - raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}') - + raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}") + tool_parameters_need_to_validate.pop(parameter) for parameter in tool_parameters_need_to_validate: parameter_schema = tool_parameters_need_to_validate[parameter] if parameter_schema.required: - raise ToolParameterValidationError(f'parameter {parameter} is required') - + raise ToolParameterValidationError(f"parameter {parameter} is required") + # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, - parameter_schema.type) + default_value = parameter_schema.type.cast_value(parameter_schema.default) tool_parameters[parameter] = default_value - + def validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ # validate credentials format self.validate_credentials_format(credentials) @@ -221,9 +226,9 @@ def validate_credentials(self, credentials: dict[str, Any]) -> None: @abstractmethod def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ - validate the credentials of the provider + validate the credentials of the provider - :param tool_name: the name of the tool, defined in `get_tools` - :param credentials: the credentials of the tool + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool """ pass diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index ef1ace9c7c31e7..bc05a11562b717 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -11,7 +11,6 @@ ) from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool -from core.tools.utils.tool_parameter_converter import ToolParameterConverter class ToolProviderController(BaseModel, ABC): @@ -21,162 +20,175 @@ class ToolProviderController(BaseModel, ABC): def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ - returns the credentials schema of the provider + returns the credentials schema of the provider - :return: the credentials schema + :return: the credentials schema """ return self.credentials_schema.copy() - + @abstractmethod def get_tools(self) -> list[Tool]: """ - returns a list of tools that the provider can provide + returns a list of tools that the provider can provide - :return: list of tools + :return: list of tools """ pass @abstractmethod def get_tool(self, tool_name: str) -> Tool: """ - returns a tool that the provider can provide + returns a tool that the provider can provide - :return: tool + :return: tool """ pass def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ - returns the parameters of the tool + returns the parameters of the tool - :param tool_name: the name of the tool, defined in `get_tools` - 
diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py
index ef1ace9c7c31e7..bc05a11562b717 100644
--- a/api/core/tools/provider/tool_provider.py
+++ b/api/core/tools/provider/tool_provider.py
@@ -11,7 +11,6 @@
 )
 from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
 from core.tools.tool.tool import Tool
-from core.tools.utils.tool_parameter_converter import ToolParameterConverter


 class ToolProviderController(BaseModel, ABC):
@@ -21,162 +20,175 @@ class ToolProviderController(BaseModel, ABC):

     def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
         """
-            returns the credentials schema of the provider
+        returns the credentials schema of the provider

-            :return: the credentials schema
+        :return: the credentials schema
         """
         return self.credentials_schema.copy()
-    
+
     @abstractmethod
     def get_tools(self) -> list[Tool]:
         """
-            returns a list of tools that the provider can provide
+        returns a list of tools that the provider can provide

-            :return: list of tools
+        :return: list of tools
         """
         pass

     @abstractmethod
     def get_tool(self, tool_name: str) -> Tool:
         """
-            returns a tool that the provider can provide
+        returns a tool that the provider can provide

-            :return: tool
+        :return: tool
         """
         pass

     def get_parameters(self, tool_name: str) -> list[ToolParameter]:
         """
-            returns the parameters of the tool
+        returns the parameters of the tool

-            :param tool_name: the name of the tool, defined in `get_tools`
-            :return: list of parameters
+        :param tool_name: the name of the tool, defined in `get_tools`
+        :return: list of parameters
         """
         tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
         if tool is None:
-            raise ToolNotFoundError(f'tool {tool_name} not found')
+            raise ToolNotFoundError(f"tool {tool_name} not found")
         return tool.parameters

     @property
     def provider_type(self) -> ToolProviderType:
         """
-            returns the type of the provider
+        returns the type of the provider

-            :return: type of the provider
+        :return: type of the provider
         """
         return ToolProviderType.BUILT_IN

     def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
         """
-            validate the parameters of the tool and set the default value if needed
+        validate the parameters of the tool and set the default value if needed

-            :param tool_name: the name of the tool, defined in `get_tools`
-            :param tool_parameters: the parameters of the tool
+        :param tool_name: the name of the tool, defined in `get_tools`
+        :param tool_parameters: the parameters of the tool
         """
         tool_parameters_schema = self.get_parameters(tool_name)
-        
+
         tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
         for parameter in tool_parameters_schema:
             tool_parameters_need_to_validate[parameter.name] = parameter

         for parameter in tool_parameters:
             if parameter not in tool_parameters_need_to_validate:
-                raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
-            
+                raise ToolParameterValidationError(f"parameter {parameter} not found in tool {tool_name}")
+
             # check type
             parameter_schema = tool_parameters_need_to_validate[parameter]
             if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
                 if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be string')
-                
+                    raise ToolParameterValidationError(f"parameter {parameter} should be string")
+
             elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
                 if not isinstance(tool_parameters[parameter], int | float):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be number')
-                
+                    raise ToolParameterValidationError(f"parameter {parameter} should be number")
+
                 if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
-                
+                    raise ToolParameterValidationError(
+                        f"parameter {parameter} should be greater than {parameter_schema.min}"
+                    )
+
                 if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
-                
+                    raise ToolParameterValidationError(
+                        f"parameter {parameter} should be less than {parameter_schema.max}"
+                    )
+
             elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
                 if not isinstance(tool_parameters[parameter], bool):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
-                
+                    raise ToolParameterValidationError(f"parameter {parameter} should be boolean")
+
             elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
                 if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be string')
-                
+                    raise ToolParameterValidationError(f"parameter {parameter} should be string")
+
                 options = parameter_schema.options
                 if not isinstance(options, list):
-                    raise ToolParameterValidationError(f'parameter {parameter} options should be list')
-                
+                    raise ToolParameterValidationError(f"parameter {parameter} options should be list")
+
                 if tool_parameters[parameter] not in [x.value for x in options]:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
-                
+                    raise ToolParameterValidationError(f"parameter {parameter} should be one of {options}")
+
             tool_parameters_need_to_validate.pop(parameter)

         for parameter in tool_parameters_need_to_validate:
             parameter_schema = tool_parameters_need_to_validate[parameter]
             if parameter_schema.required:
-                raise ToolParameterValidationError(f'parameter {parameter} is required')
-            
+                raise ToolParameterValidationError(f"parameter {parameter} is required")
+
             # the parameter is not set currently, set the default value if needed
             if parameter_schema.default is not None:
-                tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
-                                                                                           parameter_schema.type)
+                tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)

     def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
         """
-            validate the format of the credentials of the provider and set the default value if needed
+        validate the format of the credentials of the provider and set the default value if needed

-            :param credentials: the credentials of the tool
+        :param credentials: the credentials of the tool
         """
         credentials_schema = self.credentials_schema
         if credentials_schema is None:
             return
-        
+
         credentials_need_to_validate: dict[str, ToolProviderCredentials] = {}
         for credential_name in credentials_schema:
             credentials_need_to_validate[credential_name] = credentials_schema[credential_name]

         for credential_name in credentials:
             if credential_name not in credentials_need_to_validate:
-                raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')
-            
+                raise ToolProviderCredentialValidationError(
+                    f"credential {credential_name} not found in provider {self.identity.name}"
+                )
+
             # check type
             credential_schema = credentials_need_to_validate[credential_name]
-            if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
-                credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
+            if not credential_schema.required and credentials[credential_name] is None:
+                continue
+
+            if credential_schema.type in {
+                ToolProviderCredentials.CredentialsType.SECRET_INPUT,
+                ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+            }:
                 if not isinstance(credentials[credential_name], str):
-                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
-            
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
+
             elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
                 if not isinstance(credentials[credential_name], str):
-                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
-                
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
+
                 options = credential_schema.options
                 if not isinstance(options, list):
-                    raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')
-                
+                    raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
+
                 if credentials[credential_name] not in [x.value for x in options]:
-                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')
-                
+                    raise ToolProviderCredentialValidationError(
+                        f"credential {credential_name} should be one of {options}"
+                    )
+
             credentials_need_to_validate.pop(credential_name)

         for credential_name in credentials_need_to_validate:
             credential_schema = credentials_need_to_validate[credential_name]
             if credential_schema.required:
-                raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')
-            
+                raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
+
             # the credential is not set currently, set the default value if needed
             if credential_schema.default is not None:
                 default_value = credential_schema.default
                 # parse default value into the correct type
-                if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
-                    credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
-                    credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
+                if credential_schema.type in {
+                    ToolProviderCredentials.CredentialsType.SECRET_INPUT,
+                    ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+                    ToolProviderCredentials.CredentialsType.SELECT,
+                }:
                     default_value = str(default_value)

                 credentials[credential_name] = default_value
-        
\ No newline at end of file
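The behavioral change buried in the reformat of `validate_credentials_format` is the new early `continue` for optional credentials left as `None`, which previously failed the `isinstance(..., str)` check. A reduced sketch of that rule, with hypothetical names:

```python
# Sketch of the new early-exit rule: optional credentials that are None
# now skip the type checks instead of raising. Names are illustrative.
from typing import Any, Optional


def check_credential(required: bool, value: Optional[Any]) -> None:
    if not required and value is None:
        return  # new behavior: tolerate unset optional credentials
    if not isinstance(value, str):
        raise ValueError("credential should be string")


check_credential(required=False, value=None)  # passes after the change
check_credential(required=True, value="key")  # still validated
```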
"zh_Hans": db_provider.description}, + "icon": db_provider.icon, }, - 'icon': db_provider.icon, - }, - 'credentials_schema': {}, - 'provider_id': db_provider.id or '', - }) - + "credentials_schema": {}, + "provider_id": db_provider.id or "", + } + ) + # init tools controller.tools = [controller._get_db_provider_tool(db_provider, app)] @@ -56,33 +61,31 @@ def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderCont @property def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW - + def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ - get db provider tool - :param db_provider: the db provider - :param app: the app - :return: the tool + get db provider tool + :param db_provider: the db provider + :param app: the app + :return: the tool """ - workflow: Workflow = db.session.query(Workflow).filter( - Workflow.app_id == db_provider.app_id, - Workflow.version == db_provider.version - ).first() + workflow = ( + db.session.query(Workflow) + .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) + .first() + ) if not workflow: - raise ValueError('workflow not found') + raise ValueError("workflow not found") # fetch start node - graph: dict = workflow.graph_dict - features_dict: dict = workflow.features_dict - features = WorkflowAppConfigManager.convert_features( - config_dict=features_dict, - app_mode=AppMode.WORKFLOW - ) + graph = workflow.graph_dict + features_dict = workflow.features_dict + features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW) parameters = db_provider.parameter_configurations variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) - def fetch_workflow_variable(variable_name: str) -> VariableEntity: + def fetch_workflow_variable(variable_name: str): return next(filter(lambda x: x.variable == variable_name, variables), None) user = db_provider.user @@ -93,132 +96,100 @@ def fetch_workflow_variable(variable_name: str) -> VariableEntity: if variable: parameter_type = None options = None - if variable.type in [ - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.PARAGRAPH, - ]: - parameter_type = ToolParameter.ToolParameterType.STRING - elif variable.type in [ - VariableEntity.Type.SELECT - ]: - parameter_type = ToolParameter.ToolParameterType.SELECT - elif variable.type in [ - VariableEntity.Type.NUMBER - ]: - parameter_type = ToolParameter.ToolParameterType.NUMBER - else: - raise ValueError(f'unsupported variable type {variable.type}') - - if variable.type == VariableEntity.Type.SELECT and variable.options: + if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: + raise ValueError(f"unsupported variable type {variable.type}") + parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] + + if variable.type == VariableEntityType.SELECT and variable.options: options = [ - ToolParameterOption( - value=option, - label=I18nObject( - en_US=option, - zh_Hans=option - ) - ) for option in variable.options + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in variable.options ] workflow_tool_parameters.append( ToolParameter( name=parameter.name, - label=I18nObject( - en_US=variable.label, - zh_Hans=variable.label - ), - human_description=I18nObject( - en_US=parameter.description, - zh_Hans=parameter.description - ), + label=I18nObject(en_US=variable.label, zh_Hans=variable.label), + 
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), type=parameter_type, form=parameter.form, llm_description=parameter.description, required=variable.required, options=options, - default=variable.default ) ) elif features.file_upload: workflow_tool_parameters.append( ToolParameter( name=parameter.name, - label=I18nObject( - en_US=parameter.name, - zh_Hans=parameter.name - ), - human_description=I18nObject( - en_US=parameter.description, - zh_Hans=parameter.description - ), - type=ToolParameter.ToolParameterType.FILE, + label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name), + human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description), + type=ToolParameter.ToolParameterType.SYSTEM_FILES, llm_description=parameter.description, required=False, form=parameter.form, ) ) else: - raise ValueError('variable not found') + raise ValueError("variable not found") return WorkflowTool( identity=ToolIdentity( - author=user.name if user else '', + author=user.name if user else "", name=db_provider.name, - label=I18nObject( - en_US=db_provider.label, - zh_Hans=db_provider.label - ), + label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label), provider=self.provider_id, icon=db_provider.icon, ), description=ToolDescription( - human=I18nObject( - en_US=db_provider.description, - zh_Hans=db_provider.description - ), + human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description), llm=db_provider.description, ), parameters=workflow_tool_parameters, is_team_authorization=True, workflow_app_id=app.id, workflow_entities={ - 'app': app, - 'workflow': workflow, + "app": app, + "workflow": workflow, }, version=db_provider.version, workflow_call_depth=0, - label=db_provider.label + label=db_provider.label, ) def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]: """ - fetch tools from database + fetch tools from database - :param user_id: the user id - :param tenant_id: the tenant id - :return: the tools + :param user_id: the user id + :param tenant_id: the tenant id + :return: the tools """ if self.tools is not None: return self.tools - - db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == self.provider_id, - ).first() + + db_providers: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.app_id == self.provider_id, + ) + .first() + ) if not db_providers: return [] - + self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] return self.tools - + def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: """ - get tool by name + get tool by name - :param tool_name: the name of the tool - :return: the tool + :param tool_name: the name of the tool + :return: the tool """ if self.tools is None: return None @@ -226,5 +197,5 @@ def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: for tool in self.tools: if tool.identity.name == tool_name: return tool - + return None diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 69e3dfa0612e8a..c779d704c368e5 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -5,15 +5,15 @@ import httpx -import core.helper.ssrf_proxy as ssrf_proxy +from core.helper import ssrf_proxy from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities 
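Replacing the if/elif chain with `VARIABLE_TO_PARAMETER_TYPE_MAPPING` turns type dispatch into a single table lookup, so "unsupported type" becomes the only explicit branch. A reduced sketch with plain strings standing in for the real enums:

```python
# Sketch: dict-based dispatch in place of an if/elif ladder. The string
# keys/values here are placeholders for VariableEntityType / ToolParameterType.
MAPPING = {
    "text-input": "string",
    "paragraph": "string",
    "select": "select",
    "number": "number",
    "file": "file",
    "file-list": "files",
}


def to_parameter_type(variable_type: str) -> str:
    if variable_type not in MAPPING:
        raise ValueError(f"unsupported variable type {variable_type}")
    return MAPPING[variable_type]


assert to_parameter_type("paragraph") == "string"
```

Adding a new variable type is then a one-line table entry instead of another branch.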
diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py
index 69e3dfa0612e8a..c779d704c368e5 100644
--- a/api/core/tools/tool/api_tool.py
+++ b/api/core/tools/tool/api_tool.py
@@ -5,15 +5,15 @@

 import httpx

-import core.helper.ssrf_proxy as ssrf_proxy
+from core.helper import ssrf_proxy
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
 from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
 from core.tools.tool.tool import Tool

 API_TOOL_DEFAULT_TIMEOUT = (
-    int(getenv('API_TOOL_DEFAULT_CONNECT_TIMEOUT', '10')),
-    int(getenv('API_TOOL_DEFAULT_READ_TIMEOUT', '60'))
+    int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
+    int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
 )

@@ -24,31 +24,32 @@ class ApiTool(Tool):
     Api tool
     """

-    def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool':
+    def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
         """
-            fork a new tool with meta data
+        fork a new tool with meta data

-            :param meta: the meta data of a tool call processing, tenant_id is required
-            :return: the new tool
+        :param meta: the meta data of a tool call processing, tenant_id is required
+        :return: the new tool
         """
         return self.__class__(
             identity=self.identity.model_copy() if self.identity else None,
             parameters=self.parameters.copy() if self.parameters else None,
             description=self.description.model_copy() if self.description else None,
             api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
-            runtime=Tool.Runtime(**runtime)
+            runtime=Tool.Runtime(**runtime),
         )

-    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any],
-                             format_only: bool = False) -> str:
+    def validate_credentials(
+        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
+    ) -> str:
         """
-            validate the credentials for Api tool
+        validate the credentials for Api tool
         """
-        # assemble validate request and request parameters 
+        # assemble validate request and request parameters
         headers = self.assembling_request(parameters)

         if format_only:
-            return ''
+            return ""

         response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
         # validate response
@@ -61,30 +62,30 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
         headers = {}
         credentials = self.runtime.credentials or {}

-        if 'auth_type' not in credentials:
-            raise ToolProviderCredentialValidationError('Missing auth_type')
+        if "auth_type" not in credentials:
+            raise ToolProviderCredentialValidationError("Missing auth_type")

-        if credentials['auth_type'] == 'api_key':
-            api_key_header = 'api_key'
+        if credentials["auth_type"] == "api_key":
+            api_key_header = "api_key"

-            if 'api_key_header' in credentials:
-                api_key_header = credentials['api_key_header']
+            if "api_key_header" in credentials:
+                api_key_header = credentials["api_key_header"]

-            if 'api_key_value' not in credentials:
-                raise ToolProviderCredentialValidationError('Missing api_key_value')
-            elif not isinstance(credentials['api_key_value'], str):
-                raise ToolProviderCredentialValidationError('api_key_value must be a string')
+            if "api_key_value" not in credentials:
+                raise ToolProviderCredentialValidationError("Missing api_key_value")
+            elif not isinstance(credentials["api_key_value"], str):
+                raise ToolProviderCredentialValidationError("api_key_value must be a string")

-            if 'api_key_header_prefix' in credentials:
-                api_key_header_prefix = credentials['api_key_header_prefix']
-                if api_key_header_prefix == 'basic' and credentials['api_key_value']:
-                    credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}'
-                elif api_key_header_prefix == 'bearer' and credentials['api_key_value']:
-                    credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}'
-                elif api_key_header_prefix == 'custom':
+            if "api_key_header_prefix" in credentials:
+                api_key_header_prefix = credentials["api_key_header_prefix"]
+                if api_key_header_prefix == "basic" and credentials["api_key_value"]:
+                    credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
+                elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
+                    credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
+                elif api_key_header_prefix == "custom":
                     pass

-            headers[api_key_header] = credentials['api_key_value']
+            headers[api_key_header] = credentials["api_key_value"]

         needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
         for parameter in needed_parameters:
@@ -98,13 +99,13 @@ def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:

     def validate_and_parse_response(self, response: httpx.Response) -> str:
         """
-            validate the response
+        validate the response
         """
         if isinstance(response, httpx.Response):
             if response.status_code >= 400:
                 raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
             if not response.content:
-                return 'Empty response from the tool, please check your parameters and try again.'
+                return "Empty response from the tool, please check your parameters and try again."
             try:
                 response = response.json()
                 try:
@@ -114,21 +115,22 @@ def validate_and_parse_response(self, response: httpx.Response) -> str:
             except Exception as e:
                 return response.text
         else:
-            raise ValueError(f'Invalid response type {type(response)}')
+            raise ValueError(f"Invalid response type {type(response)}")

     @staticmethod
     def get_parameter_value(parameter, parameters):
-        if parameter['name'] in parameters:
-            return parameters[parameter['name']]
-        elif parameter.get('required', False):
+        if parameter["name"] in parameters:
+            return parameters[parameter["name"]]
+        elif parameter.get("required", False):
             raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
         else:
-            return (parameter.get('schema', {}) or {}).get('default', '')
+            return (parameter.get("schema", {}) or {}).get("default", "")

-    def do_http_request(self, url: str, method: str, headers: dict[str, Any],
-                        parameters: dict[str, Any]) -> httpx.Response:
+    def do_http_request(
+        self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
+    ) -> httpx.Response:
         """
-            do http request depending on api bundle
+        do http request depending on api bundle
         """
         method = method.lower()

@@ -138,29 +140,30 @@ def do_http_request(self, url: str, method: str, headers: dict[str, Any],
         cookies = {}

         # check parameters
-        for parameter in self.api_bundle.openapi.get('parameters', []):
+        for parameter in self.api_bundle.openapi.get("parameters", []):
             value = self.get_parameter_value(parameter, parameters)
-            if parameter['in'] == 'path':
-                path_params[parameter['name']] = value
+            if parameter["in"] == "path":
+                path_params[parameter["name"]] = value

-            elif parameter['in'] == 'query':
-                params[parameter['name']] = value
+            elif parameter["in"] == "query":
+                if value != "":
+                    params[parameter["name"]] = value

-            elif parameter['in'] == 'cookie':
-                cookies[parameter['name']] = value
+            elif parameter["in"] == "cookie":
+                cookies[parameter["name"]] = value

-            elif parameter['in'] == 'header':
-                headers[parameter['name']] = value
+            elif parameter["in"] == "header":
+                headers[parameter["name"]] = value

         # check if there is a request body and handle it
-        if 'requestBody' in self.api_bundle.openapi and self.api_bundle.openapi['requestBody'] is not None:
+        if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
             # handle json request body
-            if 'content' in self.api_bundle.openapi['requestBody']:
-                for content_type in self.api_bundle.openapi['requestBody']['content']:
-                    headers['Content-Type'] = content_type
-                    body_schema = self.api_bundle.openapi['requestBody']['content'][content_type]['schema']
-                    required = body_schema.get('required', [])
-                    properties = body_schema.get('properties', {})
+            if "content" in self.api_bundle.openapi["requestBody"]:
+                for content_type in self.api_bundle.openapi["requestBody"]["content"]:
+                    headers["Content-Type"] = content_type
+                    body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
+                    required = body_schema.get("required", [])
+                    properties = body_schema.get("properties", {})
                     for name, property in properties.items():
                         if name in parameters:
                             # convert type
@@ -169,63 +172,71 @@ def do_http_request(self, url: str, method: str, headers: dict[str, Any],
                             raise ToolParameterValidationError(
                                 f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
                             )
-                        elif 'default' in property:
-                            body[name] = property['default']
+                        elif "default" in property:
+                            body[name] = property["default"]
                         else:
                             body[name] = None
                     break

         # replace path parameters
         for name, value in path_params.items():
-            url = url.replace(f'{{{name}}}', f'{value}')
+            url = url.replace(f"{{{name}}}", f"{value}")

         # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
-        if 'Content-Type' in headers:
-            if headers['Content-Type'] == 'application/json':
+        if "Content-Type" in headers:
+            if headers["Content-Type"] == "application/json":
                 body = json.dumps(body)
-            elif headers['Content-Type'] == 'application/x-www-form-urlencoded':
+            elif headers["Content-Type"] == "application/x-www-form-urlencoded":
                 body = urlencode(body)
             else:
                 body = body

-        if method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
-            response = getattr(ssrf_proxy, method)(url, params=params, headers=headers, cookies=cookies, data=body,
-                                                   timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True)
+        if method in {"get", "head", "post", "put", "delete", "patch"}:
+            response = getattr(ssrf_proxy, method)(
+                url,
+                params=params,
+                headers=headers,
+                cookies=cookies,
+                data=body,
+                timeout=API_TOOL_DEFAULT_TIMEOUT,
+                follow_redirects=True,
+            )
             return response
         else:
-            raise ValueError(f'Invalid http method {self.method}')
+            raise ValueError(f"Invalid http method {self.method}")

-    def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]],
-                                      max_recursive=10) -> Any:
+    def _convert_body_property_any_of(
+        self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
+    ) -> Any:
         if max_recursive <= 0:
             raise Exception("Max recursion depth reached")
         for option in any_of or []:
             try:
-                if 'type' in option:
+                if "type" in option:
                     # Attempt to convert the value based on the type.
-                    if option['type'] == 'integer' or option['type'] == 'int':
+                    if option["type"] == "integer" or option["type"] == "int":
                         return int(value)
-                    elif option['type'] == 'number':
-                        if '.' in str(value):
+                    elif option["type"] == "number":
+                        if "." in str(value):
                             return float(value)
                         else:
                             return int(value)
-                    elif option['type'] == 'string':
+                    elif option["type"] == "string":
                         return str(value)
-                    elif option['type'] == 'boolean':
-                        if str(value).lower() in ['true', '1']:
+                    elif option["type"] == "boolean":
+                        if str(value).lower() in {"true", "1"}:
                             return True
-                        elif str(value).lower() in ['false', '0']:
+                        elif str(value).lower() in {"false", "0"}:
                             return False
                         else:
                             continue  # Not a boolean, try next option
-                    elif option['type'] == 'null' and not value:
+                    elif option["type"] == "null" and not value:
                         return None
                     else:
                         continue  # Unsupported type, try next option
-                elif 'anyOf' in option and isinstance(option['anyOf'], list):
+                elif "anyOf" in option and isinstance(option["anyOf"], list):
                     # Recursive call to handle nested anyOf
-                    return self._convert_body_property_any_of(property, value, option['anyOf'], max_recursive - 1)
+                    return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
             except ValueError:
                 continue  # Conversion failed, try next option
         # If no option succeeded, you might want to return the value as is or raise an error
@@ -233,23 +244,23 @@ def _convert_body_property_any_of(self, property: dict[str, Any], value: Any, an

     def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
         try:
-            if 'type' in property:
-                if property['type'] == 'integer' or property['type'] == 'int':
+            if "type" in property:
+                if property["type"] == "integer" or property["type"] == "int":
                     return int(value)
-                elif property['type'] == 'number':
+                elif property["type"] == "number":
                     # check if it is a float
-                    if '.' in str(value):
+                    if "." in str(value):
                         return float(value)
                     else:
                         return int(value)
-                elif property['type'] == 'string':
+                elif property["type"] == "string":
                     return str(value)
-                elif property['type'] == 'boolean':
+                elif property["type"] == "boolean":
                     return bool(value)
-                elif property['type'] == 'null':
+                elif property["type"] == "null":
                     if value is None:
                         return None
-                elif property['type'] == 'object' or property['type'] == 'array':
+                elif property["type"] == "object" or property["type"] == "array":
                     if isinstance(value, str):
                         try:
                             # an array str like '[1,2]' also can convert to list [1,2] through json.loads
@@ -264,8 +275,8 @@ def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> A
                     return value
                 else:
                     raise ValueError(f"Invalid type {property['type']} for property {property}")
-            elif 'anyOf' in property and isinstance(property['anyOf'], list):
-                return self._convert_body_property_any_of(property, value, property['anyOf'])
+            elif "anyOf" in property and isinstance(property["anyOf"], list):
+                return self._convert_body_property_any_of(property, value, property["anyOf"])
         except ValueError as e:
             return value
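`_convert_body_property_any_of` above tries each `anyOf` option in order, skips options whose coercion raises `ValueError`, and bounds recursion for nested `anyOf`. A self-contained sketch of that strategy, independent of the `ApiTool` class (option dicts follow the OpenAPI shape):

```python
# Sketch of ordered anyOf coercion with a recursion bound. Simplified:
# only a few scalar types are handled; unmatched values pass through.
from typing import Any


def coerce_any_of(value: Any, any_of: list[dict], max_recursive: int = 10) -> Any:
    if max_recursive <= 0:
        raise RecursionError("max anyOf depth reached")
    for option in any_of or []:
        try:
            if option.get("type") in ("integer", "int"):
                return int(value)
            if option.get("type") == "number":
                return float(value) if "." in str(value) else int(value)
            if option.get("type") == "string":
                return str(value)
            if isinstance(option.get("anyOf"), list):
                return coerce_any_of(value, option["anyOf"], max_recursive - 1)
        except ValueError:
            continue  # this option did not fit; try the next one
    return value  # nothing matched: pass the value through unchanged


# "3.5" fails int() for the integer option, then matches the number option.
assert coerce_any_of("3.5", [{"type": "integer"}, {"type": "number"}]) == 3.5
```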
diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py
index ad7a88838b1a24..e2a81ed0a36edd 100644
--- a/api/core/tools/tool/builtin_tool.py
+++ b/api/core/tools/tool/builtin_tool.py
@@ -1,3 +1,4 @@
+from typing import Optional

 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@@ -16,40 +17,38 @@

 class BuiltinTool(Tool):
     """
-        Builtin tool
+    Builtin tool

-        :param meta: the meta data of a tool call processing
+    :param meta: the meta data of a tool call processing
     """

-    def invoke_model(
-        self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]
-    ) -> LLMResult:
+    def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
         """
-            invoke model
+        invoke model

-            :param model_config: the model config
-            :param prompt_messages: the prompt messages
-            :param stop: the stop words
-            :return: the model result
+        :param model_config: the model config
+        :param prompt_messages: the prompt messages
+        :param stop: the stop words
+        :return: the model result
         """
         # invoke model
         return ModelInvocationUtils.invoke(
             user_id=user_id,
             tenant_id=self.runtime.tenant_id,
-            tool_type='builtin',
+            tool_type="builtin",
             tool_name=self.identity.name,
             prompt_messages=prompt_messages,
         )
-    
+
     def tool_provider_type(self) -> ToolProviderType:
         return ToolProviderType.BUILT_IN
-    
+
     def get_max_tokens(self) -> int:
         """
-            get max tokens
+        get max tokens

-            :param model_config: the model config
-            :return: the max tokens
+        :param model_config: the model config
+        :return: the max tokens
         """
         return ModelInvocationUtils.get_max_llm_context_tokens(
             tenant_id=self.runtime.tenant_id,
@@ -57,39 +56,34 @@ def get_max_tokens(self) -> int:

     def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
         """
-            get prompt tokens
+        get prompt tokens

-            :param prompt_messages: the prompt messages
-            :return: the tokens
+        :param prompt_messages: the prompt messages
+        :return: the tokens
         """
-        return ModelInvocationUtils.calculate_tokens(
-            tenant_id=self.runtime.tenant_id,
-            prompt_messages=prompt_messages
-        )
+        return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)

     def summary(self, user_id: str, content: str) -> str:
         max_tokens = self.get_max_tokens()

-        if self.get_prompt_tokens(prompt_messages=[
-            UserPromptMessage(content=content)
-        ]) < max_tokens * 0.6:
+        if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6:
             return content
-        
+
         def get_prompt_tokens(content: str) -> int:
-            return self.get_prompt_tokens(prompt_messages=[
-                SystemPromptMessage(content=_SUMMARY_PROMPT),
-                UserPromptMessage(content=content)
-            ])
-        
+            return self.get_prompt_tokens(
+                prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)]
+            )
+
         def summarize(content: str) -> str:
-            summary = self.invoke_model(user_id=user_id, prompt_messages=[
-                SystemPromptMessage(content=_SUMMARY_PROMPT),
-                UserPromptMessage(content=content)
-            ], stop=[])
+            summary = self.invoke_model(
+                user_id=user_id,
+                prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)],
+                stop=[],
+            )

             return summary.message.content

-        lines = content.split('\n')
+        lines = content.split("\n")
         new_lines = []
         # split long line into multiple lines
         for i in range(len(lines)):
@@ -100,8 +94,8 @@ def summarize(content: str) -> str:
                 new_lines.append(line)
             elif get_prompt_tokens(line) > max_tokens * 0.7:
                 while get_prompt_tokens(line) > max_tokens * 0.7:
-                    new_lines.append(line[:int(max_tokens * 0.5)])
-                    line = line[int(max_tokens * 0.5):]
+                    new_lines.append(line[: int(max_tokens * 0.5)])
+                    line = line[int(max_tokens * 0.5) :]
                 new_lines.append(line)
             else:
                 new_lines.append(line)
@@ -125,17 +119,15 @@ def summarize(content: str) -> str:
             summary = summarize(message)
             summaries.append(summary)

-        result = '\n'.join(summaries)
+        result = "\n".join(summaries)

-        if self.get_prompt_tokens(prompt_messages=[
-            UserPromptMessage(content=result)
-        ]) > max_tokens * 0.7:
+        if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7:
             return self.summary(user_id=user_id, content=result)
-        
+
         return result
-    
-    def get_url(self, url: str, user_agent: str = None) -> str:
+
+    def get_url(self, url: str, user_agent: Optional[str] = None) -> str:
         """
-            get url
+        get url
         """
-        return get_url(url, user_agent=user_agent)
\ No newline at end of file
+        return get_url(url, user_agent=user_agent)
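`summary` above is effectively a recursive map-reduce over the content: split under a token budget, summarize each chunk, join, and re-summarize while the joined result is still too large. A hedged sketch of the same control flow, using a character budget and an uppercase stub in place of the model call:

```python
# Sketch only: character counts stand in for model tokens, and
# summarize_stub stands in for the LLM invocation.
def summarize_stub(text: str) -> str:
    return text[:20].upper()  # stand-in for the model invocation


def summary(content: str, budget: int = 100) -> str:
    if len(content) < budget * 0.6:
        return content  # small enough: no summarization needed
    chunks = [content[i : i + budget // 2] for i in range(0, len(content), budget // 2)]
    result = "\n".join(summarize_stub(chunk) for chunk in chunks)
    if len(result) > budget * 0.7:
        return summary(result, budget)  # recurse until the join fits
    return result


print(summary("lorem ipsum " * 50))
```

Each pass shrinks the text, so the recursion terminates once the joined summaries fit the budget.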
" @@ -38,27 +36,26 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): reranking_provider_name: str reranking_model_name: str - @classmethod def from_dataset(cls, dataset_ids: list[str], tenant_id: str, **kwargs): return cls( - name=f"dataset_{tenant_id.replace('-', '_')}", - tenant_id=tenant_id, - dataset_ids=dataset_ids, - **kwargs + name=f"dataset_{tenant_id.replace('-', '_')}", tenant_id=tenant_id, dataset_ids=dataset_ids, **kwargs ) def _run(self, query: str) -> str: threads = [] all_documents = [] for dataset_id in self.dataset_ids: - retrieval_thread = threading.Thread(target=self._retriever, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'all_documents': all_documents, - 'hit_callbacks': self.hit_callbacks - }) + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "all_documents": all_documents, + "hit_callbacks": self.hit_callbacks, + }, + ) threads.append(retrieval_thread) retrieval_thread.start() for thread in threads: @@ -69,7 +66,7 @@ def _run(self, query: str) -> str: tenant_id=self.tenant_id, provider=self.reranking_provider_name, model_type=ModelType.RERANK, - model=self.reranking_model_name + model=self.reranking_model_name, ) rerank_runner = RerankModelRunner(rerank_model_instance) @@ -80,62 +77,61 @@ def _run(self, query: str) -> str: document_score_list = {} for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in all_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') + document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") else: document_context_list.append(segment.get_sign_content()) if self.return_resource: context_list = [] resource_number = 1 for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() if dataset and document: source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 
'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" else: - source['content'] = segment.content + source["content"] = segment.content context_list.append(source) resource_number += 1 @@ -144,13 +140,18 @@ def _run(self, query: str) -> str: return str("\n".join(document_context_list)) - def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: list, - hit_callbacks: list[DatasetIndexToolCallbackHandler]): + def _retriever( + self, + flask_app: Flask, + dataset_id: str, + query: str, + all_documents: list, + hit_callbacks: list[DatasetIndexToolCallbackHandler], + ): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() + ) if not dataset: return [] @@ -159,31 +160,34 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document hit_callback.on_query(query, dataset.id) # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrival_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, + ) if documents: all_documents.extend(documents) else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) - - all_documents.extend(documents) \ No newline at end of file + 
documents = RetrievalService.retrieve( + retrieval_method=retrieval_model["search_method"], + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) + + all_documents.extend(documents) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py index 62e97a02306e58..dad8c773579099 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_base_tool.py @@ -9,6 +9,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): """Tool for querying a Dataset.""" + name: str = "dataset" description: str = "use this to retrieve a dataset. " tenant_id: str diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index a7e70af6286544..987f94a35046e9 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,22 +1,20 @@ - from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.models.document import Document as RetrievalDocument +from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment +from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'reranking_mode': 'reranking_model', - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "reranking_mode": "reranking_model", + "top_k": 2, + "score_threshold_enabled": False, } @@ -26,128 +24,168 @@ class DatasetRetrieverToolInput(BaseModel): class DatasetRetrieverTool(DatasetRetrieverBaseTool): """Tool for querying a Dataset.""" + name: str = "dataset" args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. 
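`_run` above fans out one thread per dataset, each appending into a shared `all_documents` list inside the Flask app context (SQLAlchemy sessions need it off the main thread), then joins all threads before reranking. The skeleton of that pattern, with a fake fetch standing in for `RetrievalService`:

```python
# Sketch of the fan-out/join pattern; the worker body is a stub.
import threading


def retrieve_all(dataset_ids: list[str]) -> list[str]:
    all_documents: list[str] = []

    def worker(dataset_id: str) -> None:
        # real code: with flask_app.app_context(): RetrievalService.retrieve(...)
        all_documents.append(f"doc-from-{dataset_id}")  # list.append is thread-safe

    threads = [threading.Thread(target=worker, args=(d,)) for d in dataset_ids]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # block until every dataset has answered
    return all_documents


print(retrieve_all(["a", "b", "c"]))
```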
" dataset_id: str - @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): description = dataset.description if not description: - description = 'useful for when you want to answer queries about the ' + dataset.name + description = "useful for when you want to answer queries about the " + dataset.name - description = description.replace('\n', '').replace('\r', '') + description = description.replace("\n", "").replace("\r", "") return cls( name=f"dataset_{dataset.id.replace('-', '_')}", tenant_id=dataset.tenant_id, dataset_id=dataset.id, description=description, - **kwargs + **kwargs, ) def _run(self, query: str) -> str: - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == self.dataset_id - ).first() + dataset = ( + db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() + ) if not dataset: - return '' + return "" for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) + if dataset.provider == "external": + results = [] + external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + query=query, + external_retrieval_parameters=dataset.retrieval_model, + ) + for external_document in external_documents: + document = RetrievalDocument( + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", + ) + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset.id + document.metadata["dataset_name"] = dataset.name + results.append(document) + # deal with external documents + context_list = [] + for position, item in enumerate(results, start=1): + source = { + "position": position, + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } + context_list.append(source) + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve(retrival_method='keyword_search', - dataset_id=dataset.id, - query=query, - top_k=self.top_k - ) - return str("\n".join([document.page_content for document in documents])) + return str("\n".join([item.page_content for item in results])) else: - if self.top_k > 0: - # retrieval source - documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + # get retrieval model , if the model is not setting , 
using default + retrieval_model = dataset.retrieval_model or default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query + documents = RetrievalService.retrieve( + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + ) + return str("\n".join([document.page_content for document in documents])) else: - documents = [] - - for hit_callback in self.hit_callbacks: - hit_callback.on_tool_end(documents) - document_score_list = {} - if dataset.indexing_technique != "economy": - for item in documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] - document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in documents] - segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) - for segment in sorted_segments: - if segment.answer: - document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}') - else: - document_context_list.append(segment.get_sign_content()) - if self.return_resource: - context_list = [] - resource_number = 1 + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) + else: + documents = [] + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] + document_context_list = [] + index_node_ids = [document.metadata["doc_id"] for document in documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - context = {} - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() - if dataset and document: - source = { - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'data_source_type': document.data_source_type, - 'segment_id': segment.id, - 
'retriever_from': self.retriever_from, - 'score': document_score_list.get(segment.index_node_id, None) - - } - if self.retriever_from == 'dev': - source['hit_count'] = segment.hit_count - source['word_count'] = segment.word_count - source['segment_position'] = segment.position - source['index_node_hash'] = segment.index_node_hash - if segment.answer: - source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' - else: - source['content'] = segment.content - context_list.append(source) - resource_number += 1 - - for hit_callback in self.hit_callbacks: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join(document_context_list)) \ No newline at end of file + if segment.answer: + document_context_list.append( + f"question:{segment.get_sign_content()} answer:{segment.answer}" + ) + else: + document_context_list.append(segment.get_sign_content()) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + context = {} + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), + } + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + context_list.append(source) + resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 1170e1b7a5f065..3c9295c493c470 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -17,16 +17,17 @@ class DatasetRetrieverTool(Tool): - retrival_tool: DatasetRetrieverBaseTool + retrieval_tool: DatasetRetrieverBaseTool @staticmethod - def get_dataset_tools(tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler - ) -> list['DatasetRetrieverTool']: + def get_dataset_tools( + tenant_id: str, + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler, + ) -> list["DatasetRetrieverTool"]: """ get dataset tool """ @@ -42,29 +43,29 @@ def get_dataset_tools(tenant_id: str, # Agent only support SINGLE mode original_retriever_mode = retrieve_config.retrieve_strategy retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE - retrival_tools = feature.to_dataset_retriever_tool( + retrieval_tools = feature.to_dataset_retriever_tool( tenant_id=tenant_id, dataset_ids=dataset_ids, retrieve_config=retrieve_config, return_resource=return_resource, invoke_from=invoke_from, - 
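The new external-provider branch above normalizes external hits (plain dicts) into retrieval `Document` objects so the downstream citation code handles a single shape. A sketch of that wrapping step, with a hypothetical minimal `Doc` standing in for the real document model:

```python
# Sketch only: Doc is a stand-in for core.rag.models.document.Document,
# and the hit dicts mirror the external retrieval response shape above.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Doc:
    page_content: str
    metadata: dict[str, Any] = field(default_factory=dict)


def wrap_external(hits: list[dict], dataset_id: str, dataset_name: str) -> list[Doc]:
    results = []
    for hit in hits:
        doc = Doc(page_content=hit.get("content", ""), metadata=dict(hit.get("metadata") or {}))
        # score/title/dataset info ride along in metadata for the citation list
        doc.metadata.update(
            score=hit.get("score"),
            title=hit.get("title"),
            dataset_id=dataset_id,
            dataset_name=dataset_name,
        )
        results.append(doc)
    return results


docs = wrap_external([{"content": "hello", "score": 0.9, "title": "t"}], "ds1", "demo")
print(docs[0].metadata["score"])
```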
hit_callback=hit_callback + hit_callback=hit_callback, ) # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode - # convert retrival tools to Tools + # convert retrieval tools to Tools tools = [] - for retrival_tool in retrival_tools: + for retrieval_tool in retrieval_tools: tool = DatasetRetrieverTool( - retrival_tool=retrival_tool, - identity=ToolIdentity(provider='', author='', name=retrival_tool.name, label=I18nObject(en_US='', zh_Hans='')), + retrieval_tool=retrieval_tool, + identity=ToolIdentity( + provider="", author="", name=retrieval_tool.name, label=I18nObject(en_US="", zh_Hans="") + ), parameters=[], is_team_authorization=True, - description=ToolDescription( - human=I18nObject(en_US='', zh_Hans=''), - llm=retrival_tool.description), - runtime=DatasetRetrieverTool.Runtime() + description=ToolDescription(human=I18nObject(en_US="", zh_Hans=""), llm=retrieval_tool.description), + runtime=DatasetRetrieverTool.Runtime(), ) tools.append(tool) @@ -73,16 +74,18 @@ def get_dataset_tools(tenant_id: str, def get_runtime_parameters(self) -> list[ToolParameter]: return [ - ToolParameter(name='query', - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Query for the dataset to be used to retrieve the dataset.', - required=True, - default=''), + ToolParameter( + name="query", + label=I18nObject(en_US="", zh_Hans=""), + human_description=I18nObject(en_US="", zh_Hans=""), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description="Query for the dataset to be used to retrieve the dataset.", + required=True, + default="", + ), ] - + def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.DATASET_RETRIEVAL @@ -90,12 +93,12 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe """ invoke dataset retriever tool """ - query = tool_parameters.get('query') + query = tool_parameters.get("query") if not query: - return self.create_text_message(text='please input query') + return self.create_text_message(text="please input query") # invoke dataset retriever tool - result = self.retrival_tool._run(query=query) + result = self.retrieval_tool._run(query=query) return self.create_text_message(text=result) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index d990131b5fbbfd..6cb6e18b6d4e84 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -20,10 +20,9 @@ ToolRuntimeVariablePool, ) from core.tools.tool_file_manager import ToolFileManager -from core.tools.utils.tool_parameter_converter import ToolParameterConverter if TYPE_CHECKING: - from core.file.file_obj import FileVar + from core.file.models import File class Tool(BaseModel, ABC): @@ -35,15 +34,16 @@ class Tool(BaseModel, ABC): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - @field_validator('parameters', mode='before') + @field_validator("parameters", mode="before") @classmethod def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]: return v or [] class Runtime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing """ + def __init__(self, **data: Any): super().__init__(**data) if not self.runtime_parameters: @@ -62,15 +62,19 @@ def __init__(self, **data: Any): def __init__(self, **data: Any): super().__init__(**data) - class 
VARIABLE_KEY(Enum): - IMAGE = 'image' + class VariableKey(str, Enum): + IMAGE = "image" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" + CUSTOM = "custom" - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=self.identity.model_copy() if self.identity else None, @@ -82,22 +86,22 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'Tool': @abstractmethod def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ def load_variables(self, variables: ToolRuntimeVariablePool): """ - load variables from database + load variables from database - :param conversation_id: the conversation id + :param conversation_id: the conversation id """ self.variables = variables def set_image_variable(self, variable_name: str, image_key: str) -> None: """ - set an image variable + set an image variable """ if not self.variables: return @@ -106,7 +110,7 @@ def set_image_variable(self, variable_name: str, image_key: str) -> None: def set_text_variable(self, variable_name: str, text: str) -> None: """ - set a text variable + set a text variable """ if not self.variables: return @@ -115,10 +119,10 @@ def set_text_variable(self, variable_name: str, text: str) -> None: def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ - get a variable + get a variable - :param name: the name of the variable - :return: the variable + :param name: the name of the variable + :return: the variable """ if not self.variables: return None @@ -134,21 +138,21 @@ def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: """ - get the default image variable + get the default image variable - :return: the image variable + :return: the image variable """ if not self.variables: return None - return self.get_variable(self.VARIABLE_KEY.IMAGE) + return self.get_variable(self.VariableKey.IMAGE) def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ - get a variable file + get a variable file - :param name: the name of the variable - :return: the variable file + :param name: the name of the variable + :return: the variable file """ variable = self.get_variable(name) if not variable: @@ -167,9 +171,9 @@ def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: def list_variables(self) -> list[ToolRuntimeVariable]: """ - list all variables + list all variables - :return: the variables + :return: the variables """ if not self.variables: return [] @@ -178,9 +182,9 @@ def list_variables(self) -> list[ToolRuntimeVariable]: def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ - list all image variables + list all image variables - :return: the image variables + :return: the image variables """ if not self.variables: return [] @@ -188,7 +192,7 @@ def list_default_image_variables(self) -> list[ToolRuntimeVariable]: result = [] for variable in self.variables.pool: - if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): + if 
variable.name.startswith(self.VariableKey.IMAGE.value): result.append(variable) return result @@ -220,38 +224,40 @@ def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> result = deepcopy(tool_parameters) for parameter in self.parameters or []: if parameter.name in tool_parameters: - result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(tool_parameters[parameter.name], parameter.type) + result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) return result @abstractmethod - def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ - validate the credentials + validate the credentials - :param credentials: the credentials - :param parameters: the parameters + :param credentials: the credentials + :param parameters: the parameters """ pass def get_runtime_parameters(self) -> list[ToolParameter]: """ - get the runtime parameters + get the runtime parameters - interface for developer to dynamic change the parameters of a tool depends on the variables pool + interface for developer to dynamic change the parameters of a tool depends on the variables pool - :return: the runtime parameters + :return: the runtime parameters """ return self.parameters or [] def get_all_runtime_parameters(self) -> list[ToolParameter]: """ - get all runtime parameters + get all runtime parameters - :return: all runtime parameters + :return: all runtime parameters """ parameters = self.parameters or [] parameters = parameters.copy() @@ -281,67 +287,47 @@ def get_all_runtime_parameters(self) -> list[ToolParameter]: return parameters - def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: + def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage: """ - create an image message + create an image message - :param image: the url of the image - :return: the image message + :param image: the url of the image + :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, - save_as=save_as) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as) - def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, - message='', - meta={ - 'file_var': file_var - }, - save_as='') + def create_file_message(self, file: "File") -> ToolInvokeMessage: + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="") - def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: + def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage: """ - create a link message + create a link message - :param link: the url of the link - :return: the link message + :param link: the url of the link + :return: the link message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, - save_as=save_as) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as) - def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: + def create_text_message(self, text: str, save_as: str = 
"") -> ToolInvokeMessage: """ - create a text message + create a text message - :param text: the text - :return: the text message + :param text: the text + :return: the text message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=text, - save_as=save_as - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as) - def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: + def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage: """ - create a blob message + create a blob message - :param blob: the blob - :return: the blob message + :param blob: the blob + :return: the blob message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BLOB, - message=blob, meta=meta, - save_as=save_as - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as) def create_json_message(self, object: dict) -> ToolInvokeMessage: """ - create a json message + create a json message """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.JSON, - message=object - ) + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 12e498e76d8cd5..2ab72213ff90dc 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -1,9 +1,9 @@ import json import logging from copy import deepcopy -from typing import Any, Union +from typing import Any, Optional, Union -from core.file.file_obj import FileTransferMethod, FileVar +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType from core.tools.tool.tool import Tool from extensions.ext_database import db @@ -13,22 +13,25 @@ logger = logging.getLogger(__name__) + class WorkflowTool(Tool): workflow_app_id: str version: str workflow_entities: dict[str, Any] workflow_call_depth: int + thread_pool_id: Optional[str] = None label: str """ Workflow tool. 
""" + def tool_provider_type(self) -> ToolProviderType: """ - get the tool provider type + get the tool provider type - :return: the tool provider type + :return: the tool provider type """ return ToolProviderType.WORKFLOW @@ -36,41 +39,45 @@ def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ - invoke the tool + invoke the tool """ app = self._get_app(app_id=self.workflow_app_id) workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version) # transform the tool parameters - tool_parameters, files = self._transform_args(tool_parameters) + tool_parameters, files = self._transform_args(tool_parameters=tool_parameters) from core.app.apps.workflow.app_generator import WorkflowAppGenerator + generator = WorkflowAppGenerator() + assert self.runtime is not None + assert self.runtime.invoke_from is not None result = generator.generate( - app_model=app, - workflow=workflow, - user=self._get_user(user_id), - args={ - 'inputs': tool_parameters, - 'files': files - }, + app_model=app, + workflow=workflow, + user=self._get_user(user_id), + args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, stream=False, call_depth=self.workflow_call_depth + 1, + workflow_thread_pool_id=self.thread_pool_id, ) - data = result.get('data', {}) + data = result.get("data", {}) + + if data.get("error"): + raise Exception(data.get("error")) - if data.get('error'): - raise Exception(data.get('error')) - result = [] - outputs = data.get('outputs', {}) - outputs, files = self._extract_files(outputs) - for file in files: - result.append(self.create_file_var_message(file)) - + outputs = data.get("outputs") + if outputs == None: + outputs = {} + else: + outputs, files = self._extract_files(outputs) + for file in files: + result.append(self.create_file_message(file)) + result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) result.append(self.create_json_message(outputs)) @@ -78,7 +85,7 @@ def _invoke( def _get_user(self, user_id: str) -> Union[EndUser, Account]: """ - get the user by user id + get the user by user id """ user = db.session.query(EndUser).filter(EndUser.id == user_id).first() @@ -86,16 +93,16 @@ def _get_user(self, user_id: str) -> Union[EndUser, Account]: user = db.session.query(Account).filter(Account.id == user_id).first() if not user: - raise ValueError('user not found') + raise ValueError("user not found") return user - def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': + def fork_tool_runtime(self, runtime: dict[str, Any]) -> "WorkflowTool": """ - fork a new tool with meta data + fork a new tool with meta data - :param meta: the meta data of a tool call processing, tenant_id is required - :return: the new tool + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool """ return self.__class__( identity=deepcopy(self.identity), @@ -106,66 +113,65 @@ def fork_tool_runtime(self, runtime: dict[str, Any]) -> 'WorkflowTool': workflow_entities=self.workflow_entities, workflow_call_depth=self.workflow_call_depth, version=self.version, - label=self.label + label=self.label, ) - + def _get_workflow(self, app_id: str, version: str) -> Workflow: """ - get the workflow by app id and version + get the workflow by app id and version """ if not version: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version != 'draft' - ).order_by(Workflow.created_at.desc()).first() + 
workflow = ( + db.session.query(Workflow) + .filter(Workflow.app_id == app_id, Workflow.version != "draft") + .order_by(Workflow.created_at.desc()) + .first() + ) else: - workflow = db.session.query(Workflow).filter( - Workflow.app_id == app_id, - Workflow.version == version - ).first() + workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first() if not workflow: - raise ValueError('workflow not found or not published') + raise ValueError("workflow not found or not published") return workflow - + def _get_app(self, app_id: str) -> App: """ - get the app by app id + get the app by app id """ app = db.session.query(App).filter(App.id == app_id).first() if not app: - raise ValueError('app not found') + raise ValueError("app not found") return app - + def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: """ - transform the tool parameters + transform the tool parameters - :param tool_parameters: the tool parameters - :return: tool_parameters, files + :param tool_parameters: the tool parameters + :return: tool_parameters, files """ parameter_rules = self.get_all_runtime_parameters() parameters_result = {} files = [] for parameter in parameter_rules: - if parameter.type == ToolParameter.ToolParameterType.FILE: + if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES: file = tool_parameters.get(parameter.name) if file: try: - file_var_list = [FileVar(**f) for f in file] - for file_var in file_var_list: - file_dict = { - 'transfer_method': file_var.transfer_method.value, - 'type': file_var.type.value, + file_var_list = [File.model_validate(f) for f in file] + for file in file_var_list: + file_dict: dict[str, str | None] = { + "transfer_method": file.transfer_method.value, + "type": file.type.value, } - if file_var.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict['tool_file_id'] = file_var.related_id - elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict['upload_file_id'] = file_var.related_id - elif file_var.transfer_method == FileTransferMethod.REMOTE_URL: - file_dict['url'] = file_var.preview_url + if file.transfer_method == FileTransferMethod.TOOL_FILE: + file_dict["tool_file_id"] = file.related_id + elif file.transfer_method == FileTransferMethod.LOCAL_FILE: + file_dict["upload_file_id"] = file.related_id + elif file.transfer_method == FileTransferMethod.REMOTE_URL: + file_dict["url"] = file.generate_url() files.append(file_dict) except Exception as e: @@ -174,29 +180,25 @@ def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]: parameters_result[parameter.name] = tool_parameters.get(parameter.name) return parameters_result, files - - def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]: + + def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]: """ - extract files from the result + extract files from the result - :param result: the result - :return: the result, files + :param result: the result + :return: the result, files """ files = [] result = {} for key, value in outputs.items(): if isinstance(value, list): - has_file = False for item in value: - if isinstance(item, dict) and item.get('__variant') == 'FileVar': - try: - files.append(FileVar(**item)) - has_file = True - except Exception as e: - pass - if has_file: - continue + if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY: + file = File.model_validate(item) + files.append(file) + elif isinstance(value, dict) and 
value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + file = File.model_validate(value) + files.append(file) result[key] = value - - return result, files \ No newline at end of file + return result, files diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 0e15151aa49cba..9e290c36515d5e 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -10,7 +10,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.file_obj import FileTransferMethod +from core.file import FileType +from core.file.models import FileTransferMethod from core.ops.ops_trace_manager import TraceQueueManager from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter from core.tools.errors import ( @@ -26,6 +27,7 @@ from core.tools.tool.workflow_tool import WorkflowTool from core.tools.utils.message_transformer import ToolFileMessageTransformer from extensions.ext_database import db +from models.enums import CreatedByRole from models.model import Message, MessageFile @@ -33,12 +35,17 @@ class ToolEngine: """ Tool runtime engine take care of the tool executions. """ + @staticmethod def agent_invoke( - tool: Tool, tool_parameters: Union[str, dict], - user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom, + tool: Tool, + tool_parameters: Union[str, dict], + user_id: str, + tenant_id: str, + message: Message, + invoke_from: InvokeFrom, agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None + trace_manager: Optional[TraceQueueManager] = None, ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]: """ Agent invokes the tool with the given arguments. 
@@ -47,40 +54,30 @@ def agent_invoke( if isinstance(tool_parameters, str): # check if this tool has only one parameter parameters = [ - parameter for parameter in tool.get_runtime_parameters() or [] + parameter + for parameter in tool.get_runtime_parameters() or [] if parameter.form == ToolParameter.ToolParameterForm.LLM ] if parameters and len(parameters) == 1: - tool_parameters = { - parameters[0].name: tool_parameters - } + tool_parameters = {parameters[0].name: tool_parameters} else: raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool try: # hit the callback handler - agent_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + agent_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) meta, response = ToolEngine._invoke(tool, tool_parameters, user_id) response = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=response, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=message.conversation_id + messages=response, user_id=user_id, tenant_id=tenant_id, conversation_id=message.conversation_id ) # extract binary data from tool invoke message binary_files = ToolEngine._extract_tool_response_binary(response) # create message file message_files = ToolEngine._create_message_files( - tool_messages=binary_files, - agent_message=message, - invoke_from=invoke_from, - user_id=user_id + tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id ) plain_text = ToolEngine._convert_tool_response_to_str(response) @@ -91,7 +88,7 @@ def agent_invoke( tool_inputs=tool_parameters, tool_outputs=plain_text, message_id=message.id, - trace_manager=trace_manager + trace_manager=trace_manager, ) # transform tool invoke message to get LLM friendly message @@ -99,14 +96,10 @@ def agent_invoke( except ToolProviderCredentialValidationError as e: error_response = "Please check your tool provider credentials" agent_tool_callback.on_tool_error(e) - except ( - ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError - ) as e: + except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: error_response = f"there is not a tool named {tool.identity.name}" agent_tool_callback.on_tool_error(e) - except ( - ToolParameterValidationError - ) as e: + except ToolParameterValidationError as e: error_response = f"tool parameters validation error: {e}, please check your tool parameters" agent_tool_callback.on_tool_error(e) except ToolInvokeError as e: @@ -124,23 +117,25 @@ def agent_invoke( return error_response, [], ToolInvokeMeta.error_instance(error_response) @staticmethod - def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], - user_id: str, - workflow_tool_callback: DifyWorkflowCallbackHandler, - workflow_call_depth: int, - ) -> list[ToolInvokeMessage]: + def workflow_invoke( + tool: Tool, + tool_parameters: Mapping[str, Any], + user_id: str, + workflow_tool_callback: DifyWorkflowCallbackHandler, + workflow_call_depth: int, + thread_pool_id: Optional[str] = None, + ) -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. 
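+        :param thread_pool_id: optional id of an existing workflow thread pool; it is
+            attached to each WorkflowTool below so the nested workflow run reuses the
+            caller's pool.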
""" try: # hit the callback handler - workflow_tool_callback.on_tool_start( - tool_name=tool.identity.name, - tool_inputs=tool_parameters - ) + assert tool.identity is not None + workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters) if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 + tool.thread_pool_id = thread_pool_id if tool.runtime and tool.runtime.runtime_parameters: tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} @@ -157,21 +152,24 @@ def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any], except Exception as e: workflow_tool_callback.on_tool_error(e) raise e - + @staticmethod - def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ - -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: + def _invoke(tool: Tool, tool_parameters: dict, user_id: str) -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]: """ Invoke the tool with the given arguments. """ started_at = datetime.now(timezone.utc) - meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={ - 'tool_name': tool.identity.name, - 'tool_provider': tool.identity.provider, - 'tool_provider_type': tool.tool_provider_type().value, - 'tool_parameters': deepcopy(tool.runtime.runtime_parameters), - 'tool_icon': tool.identity.icon - }) + meta = ToolInvokeMeta( + time_cost=0.0, + error=None, + tool_config={ + "tool_name": tool.identity.name, + "tool_provider": tool.identity.provider, + "tool_provider_type": tool.tool_provider_type().value, + "tool_parameters": deepcopy(tool.runtime.runtime_parameters), + "tool_icon": tool.identity.icon, + }, + ) try: response = tool.invoke(user_id, tool_parameters) except Exception as e: @@ -182,28 +180,30 @@ def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \ meta.time_cost = (ended_at - started_at).total_seconds() return meta, response - + @staticmethod def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: """ Handle tool response """ - result = '' + result = "" for response in tool_response: if response.type == ToolInvokeMessage.MessageType.TEXT: result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: result += f"result link: {response.message}. please tell user to check it." - elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: - result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now." + elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + result += ( + "image has been created and sent to user already, you do not need to create it," + " just tell the user to check it now." + ) elif response.type == ToolInvokeMessage.MessageType.JSON: result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}." else: result += f"tool response: {response.message}." 
return result - + @staticmethod def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: """ @@ -212,52 +212,59 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result = [] for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: mimetype = None - if response.meta.get('mime_type'): - mimetype = response.meta.get('mime_type') + if response.meta.get("mime_type"): + mimetype = response.meta.get("mime_type") else: try: url = URL(response.message) extension = url.suffix - guess_type_result, _ = guess_type(f'a{extension}') + guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: mimetype = guess_type_result except Exception: pass - + if not mimetype: - mimetype = 'image/jpeg' - - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'image/jpeg'), - url=response.message, - save_as=response.save_as, - )) + mimetype = "image/jpeg" + + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "image/jpeg"), + url=response.message, + save_as=response.save_as, + ) + ) elif response.type == ToolInvokeMessage.MessageType.BLOB: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream'), - url=response.message, - save_as=response.save_as, - )) - elif response.type == ToolInvokeMessage.MessageType.LINK: - # check if there is a mime type in meta - if response.meta and 'mime_type' in response.meta: - result.append(ToolInvokeMessageBinary( - mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream', + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "octet/stream"), url=response.message, save_as=response.save_as, - )) + ) + ) + elif response.type == ToolInvokeMessage.MessageType.LINK: + # check if there is a mime type in meta + if response.meta and "mime_type" in response.meta: + result.append( + ToolInvokeMessageBinary( + mimetype=response.meta.get("mime_type", "octet/stream") + if response.meta + else "octet/stream", + url=response.message, + save_as=response.save_as, + ) + ) return result - + @staticmethod def _create_message_files( tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, - user_id: str + user_id: str, ) -> list[tuple[Any, str]]: """ Create message file @@ -268,29 +275,31 @@ def _create_message_files( result = [] for message in tool_messages: - file_type = 'bin' - if 'image' in message.mimetype: - file_type = 'image' - elif 'video' in message.mimetype: - file_type = 'video' - elif 'audio' in message.mimetype: - file_type = 'audio' - elif 'text' in message.mimetype: - file_type = 'text' - elif 'pdf' in message.mimetype: - file_type = 'pdf' - elif 'zip' in message.mimetype: - file_type = 'archive' - # ... 
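+            # Coarse mimetype-to-FileType mapping: image/video/audio keep their own
+            # enum values, "text" and "pdf" both collapse into DOCUMENT, and anything
+            # unrecognized falls back to CUSTOM.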
+ if "image" in message.mimetype: + file_type = FileType.IMAGE + elif "video" in message.mimetype: + file_type = FileType.VIDEO + elif "audio" in message.mimetype: + file_type = FileType.AUDIO + elif "text" in message.mimetype or "pdf" in message.mimetype: + file_type = FileType.DOCUMENT + else: + file_type = FileType.CUSTOM + # extract tool file id from url + tool_file_id = message.url.split("/")[-1].split(".")[0] message_file = MessageFile( message_id=agent_message.id, type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE.value, - belongs_to='assistant', + transfer_method=FileTransferMethod.TOOL_FILE, + belongs_to="assistant", url=message.url, - upload_file_id=None, - created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), + upload_file_id=tool_file_id, + created_by_role=( + CreatedByRole.ACCOUNT + if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatedByRole.END_USER + ), created_by=user_id, ) @@ -298,11 +307,8 @@ def _create_message_files( db.session.commit() db.session.refresh(message_file) - result.append(( - message_file.id, - message.save_as - )) + result.append((message_file.id, message.save_as)) db.session.close() - return result \ No newline at end of file + return result diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index f9f7c7d78a7f28..ff56e20e8758ea 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -4,7 +4,6 @@ import logging import os import time -from collections.abc import Generator from mimetypes import guess_extension, guess_type from typing import Optional, Union from uuid import uuid4 @@ -27,24 +26,24 @@ def sign_file(tool_file_id: str, extension: str) -> str: sign file to get a temporary url """ base_url = dify_config.FILES_URL - file_preview_url = f'{base_url}/files/tools/{tool_file_id}{extension}' + file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" timestamp = str(int(time.time())) nonce = os.urandom(16).hex() - data_to_sign = f'file-preview|{tool_file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f'{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}' + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" @staticmethod def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature """ - data_to_sign = f'file-preview|{file_id}|{timestamp}|{nonce}' - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() @@ -57,22 +56,32 @@ def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: @staticmethod def create_file_by_raw( - user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str], + file_binary: 
bytes, + mimetype: str, ) -> ToolFile: - """ - create file - """ - extension = guess_extension(mimetype) or '.bin' + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' - storage.save(filename, file_binary) + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, file_binary) tool_file = ToolFile( - user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + name=filename, + size=len(file_binary), ) db.session.add(tool_file) db.session.commit() + db.session.refresh(tool_file) return tool_file @@ -80,29 +89,34 @@ def create_file_by_raw( def create_file_by_url( user_id: str, tenant_id: str, - conversation_id: str, + conversation_id: str | None, file_url: str, ) -> ToolFile: - """ - create file - """ # try to download image - response = get(file_url) - response.raise_for_status() - blob = response.content - mimetype = guess_type(file_url)[0] or 'octet/stream' - extension = guess_extension(mimetype) or '.bin' + try: + response = get(file_url) + response.raise_for_status() + blob = response.content + except Exception as e: + logger.exception(f"Failed to download file from {file_url}: {e}") + raise + + mimetype = guess_type(file_url)[0] or "octet/stream" + extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex - filename = f'tools/{tenant_id}/{unique_name}{extension}' - storage.save(filename, blob) + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, blob) tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, - file_key=filename, + file_key=filepath, mimetype=mimetype, original_url=file_url, + name=filename, + size=len(blob), ) db.session.add(tool_file) @@ -110,18 +124,6 @@ def create_file_by_url( return tool_file - @staticmethod - def create_file_by_key( - user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str - ) -> ToolFile: - """ - create file - """ - tool_file = ToolFile( - user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype - ) - return tool_file - @staticmethod def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: """ @@ -131,7 +133,7 @@ def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: :return: the binary of the file, mime type """ - tool_file: ToolFile = ( + tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == id, @@ -155,7 +157,7 @@ def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None :return: the binary of the file, mime type """ - message_file: MessageFile = ( + message_file = ( db.session.query(MessageFile) .filter( MessageFile.id == id, @@ -166,14 +168,16 @@ def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None # Check if message_file is not None if message_file is not None: # get tool file id - tool_file_id = message_file.url.split('/')[-1] - # trim extension - tool_file_id = tool_file_id.split('.')[0] + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None else: tool_file_id = None - - tool_file: ToolFile = ( + tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == 
tool_file_id, @@ -189,7 +193,7 @@ def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None return blob, tool_file.mimetype @staticmethod - def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]: + def get_file_generator_by_tool_file_id(tool_file_id: str): """ get file binary @@ -197,7 +201,7 @@ def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generat :return: the binary of the file, mime type """ - tool_file: ToolFile = ( + tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == tool_file_id, @@ -206,14 +210,14 @@ def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generat ) if not tool_file: - return None + return None, None - generator = storage.load_stream(tool_file.file_key) + stream = storage.load_stream(tool_file.file_key) - return generator, tool_file.mimetype + return stream, tool_file # init tool_file_parser from core.file.tool_file_parser import tool_file_manager -tool_file_manager['manager'] = ToolFileManager +tool_file_manager["manager"] = ToolFileManager diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 97788a7a07dfb0..2a5a2944ef8471 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -15,7 +15,7 @@ def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]: """ tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] return list(set(tool_labels)) - + @classmethod def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): """ @@ -26,20 +26,20 @@ def update_tool_labels(cls, controller: ToolProviderController, labels: list[str if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): provider_id = controller.provider_id else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id == provider_id - ).delete() + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() # insert new labels for label in labels: - db.session.add(ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type.value, - label_name=label, - )) + db.session.add( + ToolLabelBinding( + tool_id=provider_id, + tool_type=controller.provider_type.value, + label_name=label, + ) + ) db.session.commit() @@ -53,12 +53,16 @@ def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: elif isinstance(controller, BuiltinToolProviderController): return controller.tool_labels else: - raise ValueError('Unsupported tool type') + raise ValueError("Unsupported tool type") - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding.label_name) + .filter( + ToolLabelBinding.tool_id == provider_id, + ToolLabelBinding.tool_type == controller.provider_type.value, + ) + .all() + ) return [label.label_name for label in labels] @@ -75,22 +79,20 @@ def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[ """ if not tool_providers: return {} - + for controller in tool_providers: if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - raise ValueError('Unsupported tool type') - + 
raise ValueError("Unsupported tool type") + provider_ids = [controller.provider_id for controller in tool_providers] - labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter( - ToolLabelBinding.tool_id.in_(provider_ids) - ).all() + labels: list[ToolLabelBinding] = ( + db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() + ) - tool_labels = { - label.tool_id: [] for label in labels - } + tool_labels = {label.tool_id: [] for label in labels} for label in labels: tool_labels[label.tool_id].append(label.label_name) - return tool_labels \ No newline at end of file + return tool_labels diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d7ddb40e6be310..bf2ad13620b629 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -3,21 +3,18 @@ import mimetypes from collections.abc import Generator from os import listdir, path -from threading import Lock -from typing import Any, Union +from threading import Lock, Thread +from typing import Any, Optional, Union from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, -) +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiToolProviderController from core.tools.provider.builtin._positions import BuiltinToolProviderSort @@ -26,18 +23,14 @@ from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ToolConfigurationManager, - ToolParameterConfigurationManager, -) -from core.tools.utils.tool_parameter_converter import ToolParameterConverter -from core.workflow.nodes.tool.entities import ToolEntity +from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) + class ToolManager: _builtin_provider_lock = Lock() _builtin_providers = {} @@ -47,29 +40,29 @@ class ToolManager: @classmethod def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController: """ - get the builtin provider + get the builtin provider - :param provider: the name of the provider - :return: the provider + :param provider: the name of the provider + :return: the provider """ if len(cls._builtin_providers) == 0: # init the builtin providers cls.load_builtin_providers_cache() if provider not in cls._builtin_providers: - raise ToolProviderNotFoundError(f'builtin provider {provider} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} not found") return cls._builtin_providers[provider] @classmethod def get_builtin_tool(cls, provider: str, tool_name: str) -> 
BuiltinTool: """ - get the builtin tool + get the builtin tool - :param provider: the name of the provider - :param tool_name: the name of the tool + :param provider: the name of the provider + :param tool_name: the name of the tool - :return: the provider, the tool + :return: the provider, the tool """ provider_controller = cls.get_builtin_provider(provider) tool = provider_controller.get_tool(tool_name) @@ -77,67 +70,76 @@ def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool: return tool @classmethod - def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \ - -> Union[BuiltinTool, ApiTool]: + def get_tool( + cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None + ) -> Union[BuiltinTool, ApiTool]: """ - get the tool + get the tool - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ - if provider_type == 'builtin': + if provider_type == "builtin": return cls.get_builtin_tool(provider_id, tool_name) - elif provider_type == 'api': + elif provider_type == "api": if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id) return api_provider.get_tool(tool_name) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') + elif provider_type == "app": + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod - def get_tool_runtime(cls, provider_type: str, - provider_id: str, - tool_name: str, - tenant_id: str, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \ - -> Union[BuiltinTool, ApiTool]: + def get_tool_runtime( + cls, + provider_type: str, + provider_id: str, + tool_name: str, + tenant_id: str, + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + ) -> Union[BuiltinTool, ApiTool]: """ - get the tool runtime + get the tool runtime - :param provider_type: the type of the provider - :param provider_name: the name of the provider - :param tool_name: the name of the tool + :param provider_type: the type of the provider + :param provider_name: the name of the provider + :param tool_name: the name of the tool - :return: the tool + :return: the tool """ - if provider_type == 'builtin': + if provider_type == "builtin": builtin_tool = cls.get_builtin_tool(provider_id, tool_name) # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id) if not provider_controller.need_credentials: - return builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) + return builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) # get credentials - builtin_provider: BuiltinToolProvider = 
db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_id, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_id, + ) + .first() + ) if builtin_provider is None: - raise ToolProviderNotFoundError(f'builtin provider {provider_id} not found') + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") # decrypt the credentials credentials = builtin_provider.credentials @@ -146,17 +148,19 @@ def get_tool_runtime(cls, provider_type: str, decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return builtin_tool.fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'runtime_parameters': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) + return builtin_tool.fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "runtime_parameters": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) - elif provider_type == 'api': + elif provider_type == "api": if tenant_id is None: - raise ValueError('tenant id is required for api provider') + raise ValueError("tenant id is required for api provider") api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) @@ -164,40 +168,43 @@ def get_tool_runtime(cls, provider_type: str, tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) - return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': decrypted_credentials, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) - elif provider_type == 'workflow': - workflow_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() + return api_provider.get_tool(tool_name).fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": decrypted_credentials, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) + elif provider_type == "workflow": + workflow_provider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) if workflow_provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - controller = ToolTransformService.workflow_provider_to_controller( - db_provider=workflow_provider - ) + controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) - return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={ - 'tenant_id': tenant_id, - 'credentials': {}, - 'invoke_from': invoke_from, - 'tool_invoke_from': tool_invoke_from, - }) - elif provider_type == 'app': - raise NotImplementedError('app provider not implemented') + return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( + runtime={ + "tenant_id": tenant_id, + "credentials": {}, + "invoke_from": invoke_from, + "tool_invoke_from": tool_invoke_from, + } + ) 
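+            # The [0] above relies on a workflow tool provider wrapping exactly one
+            # workflow, and therefore exposing exactly one tool via get_tools(); that
+            # assumption is implicit in this diff rather than stated in it.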
+ elif provider_type == "app": + raise NotImplementedError("app provider not implemented") else: - raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + raise ToolProviderNotFoundError(f"provider type {provider_type} not found") @classmethod - def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]: + def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict): """ - init runtime parameter + init runtime parameter """ parameter_value = parameters.get(parameter_rule.name) if not parameter_value and parameter_value != 0: @@ -211,14 +218,17 @@ def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict options = [x.value for x in parameter_rule.options] if parameter_value is not None and parameter_value not in options: raise ValueError( - f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") + f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" + ) - return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) + return parameter_rule.type.cast_value(parameter_value) @classmethod - def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_agent_tool_runtime( + cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER + ) -> Tool: """ - get the agent tool runtime + get the agent tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=agent_tool.provider_type, @@ -226,13 +236,21 @@ def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentTo tool_name=agent_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT + tool_invoke_from=ToolInvokeFrom.AGENT, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: # check file types - if parameter.type == ToolParameter.ToolParameterType.FILE: + if ( + parameter.type + in { + ToolParameter.ToolParameterType.SYSTEM_FILES, + ToolParameter.ToolParameterType.FILE, + ToolParameter.ToolParameterType.FILES, + } + and parameter.required + ): raise ValueError(f"file type parameter {parameter.name} not supported in agent") if parameter.form == ToolParameter.ToolParameterForm.FORM: @@ -246,7 +264,7 @@ def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentTo tool_runtime=tool_entity, provider_name=agent_tool.provider_id, provider_type=agent_tool.provider_type, - identity_id=f'AGENT.{app_id}' + identity_id=f"AGENT.{app_id}", ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) @@ -254,9 +272,16 @@ def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentTo return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_workflow_tool_runtime( + cls, + tenant_id: str, + app_id: str, + node_id: str, + workflow_tool: "ToolEntity", + invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, + ) -> Tool: """ - get the workflow tool runtime + get the workflow tool runtime """ tool_entity = cls.get_tool_runtime( provider_type=workflow_tool.provider_type, @@ -264,7 +289,7 @@ def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, wo 
tool_name=workflow_tool.tool_name, tenant_id=tenant_id, invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW + tool_invoke_from=ToolInvokeFrom.WORKFLOW, ) runtime_parameters = {} parameters = tool_entity.get_all_runtime_parameters() @@ -281,7 +306,7 @@ def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, wo tool_runtime=tool_entity, provider_name=workflow_tool.provider_id, provider_type=workflow_tool.provider_type, - identity_id=f'WORKFLOW.{app_id}.{node_id}' + identity_id=f"WORKFLOW.{app_id}.{node_id}", ) if runtime_parameters: @@ -293,24 +318,30 @@ def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, wo @classmethod def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]: """ - get the absolute path of the icon of the builtin provider + get the absolute path of the icon of the builtin provider - :param provider: the name of the provider + :param provider: the name of the provider - :return: the absolute path of the icon, the mime type of the icon + :return: the absolute path of the icon, the mime type of the icon """ # get provider provider_controller = cls.get_builtin_provider(provider) - absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets', - provider_controller.identity.icon) + absolute_path = path.join( + path.dirname(path.realpath(__file__)), + "provider", + "builtin", + provider, + "_assets", + provider_controller.identity.icon, + ) # check if the icon exists if not path.exists(absolute_path): - raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found') + raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") # get the mime type mime_type, _ = mimetypes.guess_type(absolute_path) - mime_type = mime_type or 'application/octet-stream' + mime_type = mime_type or "application/octet-stream" return absolute_path, mime_type @@ -331,23 +362,25 @@ def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None @classmethod def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: """ - list all the builtin providers + list all the builtin providers """ - for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')): - if provider.startswith('__'): + for provider in listdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin")): + if provider.startswith("__"): continue - if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)): - if provider.startswith('__'): + if path.isdir(path.join(path.dirname(path.realpath(__file__)), "provider", "builtin", provider)): + if provider.startswith("__"): continue # init provider try: provider_class = load_single_subclass_from_source( - module_name=f'core.tools.provider.builtin.{provider}.{provider}', - script_path=path.join(path.dirname(path.realpath(__file__)), - 'provider', 'builtin', provider, f'{provider}.py'), - parent_type=BuiltinToolProviderController) + module_name=f"core.tools.provider.builtin.{provider}.{provider}", + script_path=path.join( + path.dirname(path.realpath(__file__)), "provider", "builtin", provider, f"{provider}.py" + ), + parent_type=BuiltinToolProviderController, + ) provider: BuiltinToolProviderController = provider_class() cls._builtin_providers[provider.identity.name] = provider for tool in provider.get_tools(): @@ -355,7 +388,7 @@ def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, Non yield 
provider except Exception as e: - logger.error(f'load builtin provider {provider} error: {e}') + logger.exception(f"load builtin provider {provider} error: {e}") continue # set builtin providers loaded cls._builtin_providers_loaded = True @@ -373,11 +406,11 @@ def clear_builtin_providers_cache(cls): @classmethod def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: """ - get the tool label + get the tool label - :param tool_name: the name of the tool + :param tool_name: the name of the tool - :return: the label of the tool + :return: the label of the tool """ if len(cls._builtin_tools_labels) == 0: # init the builtin providers @@ -389,66 +422,78 @@ def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: return cls._builtin_tools_labels[tool_name] @classmethod - def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral) -> list[UserToolProvider]: + def user_list_providers( + cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral + ) -> list[UserToolProvider]: result_providers: dict[str, UserToolProvider] = {} filters = [] if not typ: - filters.extend(['builtin', 'api', 'workflow']) + filters.extend(["builtin", "api", "workflow"]) else: filters.append(typ) - if 'builtin' in filters: - + if "builtin" in filters: # get builtin providers builtin_providers = cls.list_builtin_providers() # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \ - filter(BuiltinToolProvider.tenant_id == tenant_id).all() + db_builtin_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() + ) find_db_builtin_provider = lambda provider: next( - (x for x in db_builtin_providers if x.provider == provider), - None + (x for x in db_builtin_providers if x.provider == provider), None ) # append builtin providers for provider in builtin_providers: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider, + name_func=lambda x: x.identity.name, + ): + continue + user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, db_provider=find_db_builtin_provider(provider.identity.name), - decrypt_credentials=False + decrypt_credentials=False, ) result_providers[provider.identity.name] = user_provider # get db api providers - if 'api' in filters: - db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). 
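# is_filtered() is defined elsewhere in the codebase and is not part of this diff;
# a plausible reading of its include/exclude semantics, for illustration only (the
# actual implementation may differ):
#
#     def is_filtered(include_set, exclude_set, data, name_func) -> bool:
#         name = name_func(data)
#         if include_set and name not in include_set:
#             return True  # an include list is configured and this provider is not on it
#         if exclude_set and name in exclude_set:
#             return True  # this provider is explicitly excluded
#         return False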
\ - filter(ApiToolProvider.tenant_id == tenant_id).all() + if "api" in filters: + db_api_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() + ) - api_provider_controllers = [{ - 'provider': provider, - 'controller': ToolTransformService.api_provider_to_controller(provider) - } for provider in db_api_providers] + api_provider_controllers = [ + {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} + for provider in db_api_providers + ] # get labels - labels = ToolLabelManager.get_tools_labels([x['controller'] for x in api_provider_controllers]) + labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) for api_provider_controller in api_provider_controllers: user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller['controller'], - db_provider=api_provider_controller['provider'], + provider_controller=api_provider_controller["controller"], + db_provider=api_provider_controller["provider"], decrypt_credentials=False, - labels=labels.get(api_provider_controller['controller'].provider_id, []) + labels=labels.get(api_provider_controller["controller"].provider_id, []), ) - result_providers[f'api_provider.{user_provider.name}'] = user_provider + result_providers[f"api_provider.{user_provider.name}"] = user_provider - if 'workflow' in filters: + if "workflow" in filters: # get workflow providers - workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \ - filter(WorkflowToolProvider.tenant_id == tenant_id).all() + workflow_providers: list[WorkflowToolProvider] = ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() + ) workflow_provider_controllers = [] for provider in workflow_providers: @@ -467,32 +512,36 @@ def user_list_providers(cls, user_id: str, tenant_id: str, typ: UserToolProvider provider_controller=provider_controller, labels=labels.get(provider_controller.provider_id, []), ) - result_providers[f'workflow_provider.{user_provider.name}'] = user_provider + result_providers[f"workflow_provider.{user_provider.name}"] = user_provider return BuiltinToolProviderSort.sort(list(result_providers.values())) @classmethod - def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ - ApiToolProviderController, dict[str, Any]]: + def get_api_provider_controller( + cls, tenant_id: str, provider_id: str + ) -> tuple[ApiToolProviderController, dict[str, Any]]: """ - get the api provider + get the api provider - :param provider_name: the name of the provider + :param provider_name: the name of the provider - :return: the provider controller, the credentials + :return: the provider controller, the credentials """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id, - ApiToolProvider.tenant_id == tenant_id, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.id == provider_id, + ApiToolProvider.tenant_id == tenant_id, + ) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'api provider {provider_id} not found') + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") controller = ApiToolProviderController.from_db( provider, - ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + 
ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) controller.load_bundled_tools(provider.tools) @@ -501,18 +550,22 @@ def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[ @classmethod def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: """ - get api provider + get api provider """ """ get tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') + raise ValueError(f"you have not added provider {provider}") try: credentials = json.loads(provider.credentials_str) or {} @@ -521,7 +574,7 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: # package tool provider controller controller = ApiToolProviderController.from_db( - provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE + provider, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE ) # init tool configuration tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) @@ -532,65 +585,67 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: try: icon = json.loads(provider.icon) except: - icon = { - "background": "#252525", - "content": "\ud83d\ude01" - } + icon = {"background": "#252525", "content": "\ud83d\ude01"} # add tool labels labels = ToolLabelManager.get_tool_labels(controller) - return jsonable_encoder({ - 'schema_type': provider.schema_type, - 'schema': provider.schema, - 'tools': provider.tools, - 'icon': icon, - 'description': provider.description, - 'credentials': masked_credentials, - 'privacy_policy': provider.privacy_policy, - 'custom_disclaimer': provider.custom_disclaimer, - 'labels': labels, - }) + return jsonable_encoder( + { + "schema_type": provider.schema_type, + "schema": provider.schema, + "tools": provider.tools, + "icon": icon, + "description": provider.description, + "credentials": masked_credentials, + "privacy_policy": provider.privacy_policy, + "custom_disclaimer": provider.custom_disclaimer, + "labels": labels, + } + ) @classmethod def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]: """ - get the tool icon + get the tool icon - :param tenant_id: the id of the tenant - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :return: + :param tenant_id: the id of the tenant + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :return: """ provider_type = provider_type provider_id = provider_id - if provider_type == 'builtin': - return (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/" - + provider_id - + "/icon") - elif provider_type == 'api': + if provider_type == "builtin": + return ( + dify_config.CONSOLE_API_URL + + "/console/api/workspaces/current/tool-provider/builtin/" + + provider_id + + "/icon" + ) + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.id == 
provider_id - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .first() + ) return json.loads(provider.icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } - elif provider_type == 'workflow': - provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == provider_id - ).first() + return {"background": "#252525", "content": "\ud83d\ude01"} + elif provider_type == "workflow": + provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) if provider is None: - raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found') + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") return json.loads(provider.icon) else: raise ValueError(f"provider type {provider_type} not found") -ToolManager.load_builtin_providers_cache() + +# preload builtin tool providers +Thread(target=ToolManager.load_builtin_providers_cache, name="pre_load_builtin_providers_cache", daemon=True).start() diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index b213879e960b14..83600d21c13dc2 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -56,12 +56,13 @@ def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]: if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT: if field_name in credentials: if len(credentials[field_name]) > 6: - credentials[field_name] = \ - credentials[field_name][:2] + \ - '*' * (len(credentials[field_name]) - 4) + \ - credentials[field_name][-2:] + credentials[field_name] = ( + credentials[field_name][:2] + + "*" * (len(credentials[field_name]) - 4) + + credentials[field_name][-2:] + ) else: - credentials[field_name] = '*' * len(credentials[field_name]) + credentials[field_name] = "*" * len(credentials[field_name]) return credentials @@ -72,9 +73,9 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return a deep copy of credentials with decrypted values """ cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cached_credentials = cache.get() if cached_credentials: @@ -95,16 +96,18 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}', - cache_type=ToolProviderCredentialsCacheType.PROVIDER + tenant_id=self.tenant_id, + identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}", + cache_type=ToolProviderCredentialsCacheType.PROVIDER, ) cache.delete() + class ToolParameterConfigurationManager(BaseModel): """ Tool parameter configuration manager """ + tenant_id: str tool_runtime: Tool provider_name: str @@ 
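# The masking rule in mask_tool_credentials above keeps the first and last two
# characters of secrets longer than six characters and stars the rest; shorter
# secrets are fully starred. A worked example with a made-up value:
#
#     secret = "sk-1234567890"  # 13 characters
#     masked = secret[:2] + "*" * (len(secret) - 4) + secret[-2:]
#     # masked == "sk*********90", same length as the original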
-152,15 +155,19 @@ def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: current_parameters = self._merge_parameters() for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: if len(parameters[parameter.name]) > 6: - parameters[parameter.name] = \ - parameters[parameter.name][:2] + \ - '*' * (len(parameters[parameter.name]) - 4) + \ - parameters[parameter.name][-2:] + parameters[parameter.name] = ( + parameters[parameter.name][:2] + + "*" * (len(parameters[parameter.name]) - 4) + + parameters[parameter.name][-2:] + ) else: - parameters[parameter.name] = '*' * len(parameters[parameter.name]) + parameters[parameter.name] = "*" * len(parameters[parameter.name]) return parameters @@ -176,7 +183,10 @@ def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: parameters = self._deep_copy(parameters) for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) parameters[parameter.name] = encrypted @@ -191,10 +201,10 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: """ cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f"{self.provider_type}.{self.provider_name}", tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cached_parameters = cache.get() if cached_parameters: @@ -205,7 +215,10 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: has_secret_input = False for parameter in current_parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if ( + parameter.form == ToolParameter.ToolParameterForm.FORM + and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT + ): if parameter.name in parameters: try: has_secret_input = True @@ -221,9 +234,9 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: def delete_tool_parameters_cache(self): cache = ToolParameterCache( tenant_id=self.tenant_id, - provider=f'{self.provider_type}.{self.provider_name}', + provider=f"{self.provider_type}.{self.provider_name}", tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER, - identity_id=self.identity_id + identity_id=self.identity_id, ) cache.delete() diff --git a/api/core/tools/utils/feishu_api_utils.py b/api/core/tools/utils/feishu_api_utils.py new file mode 100644 index 00000000000000..ea28037df03720 --- /dev/null +++ b/api/core/tools/utils/feishu_api_utils.py @@ -0,0 +1,888 @@ +import json +from typing import Optional + +import httpx + +from core.tools.errors import ToolProviderCredentialValidationError +from extensions.ext_redis import redis_client + + +def auth(credentials): + app_id = credentials.get("app_id") + app_secret = 
credentials.get("app_secret") + if not app_id or not app_secret: + raise ToolProviderCredentialValidationError("app_id and app_secret are required") + try: + assert FeishuRequest(app_id, app_secret).tenant_access_token is not None + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) + + +def convert_add_records(json_str): + try: + data = json.loads(json_str) + if not isinstance(data, list): + raise ValueError("Parsed data must be a list") + converted_data = [{"fields": json.dumps(item, ensure_ascii=False)} for item in data] + return converted_data + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + except Exception as e: + raise ValueError(f"An error occurred while processing the data: {e}") + + +def convert_update_records(json_str): + try: + data = json.loads(json_str) + if not isinstance(data, list): + raise ValueError("Parsed data must be a list") + + converted_data = [ + {"fields": json.dumps(record["fields"], ensure_ascii=False), "record_id": record["record_id"]} + for record in data + if "fields" in record and "record_id" in record + ] + + if len(converted_data) != len(data): + raise ValueError("Each record must contain 'fields' and 'record_id'") + + return converted_data + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + except Exception as e: + raise ValueError(f"An error occurred while processing the data: {e}") + + +class FeishuRequest: + API_BASE_URL = "https://lark-plugin-api.solutionsuite.cn/lark-plugin" + + def __init__(self, app_id: str, app_secret: str): + self.app_id = app_id + self.app_secret = app_secret + + @property + def tenant_access_token(self): + feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" + if redis_client.exists(feishu_tenant_access_token): + return redis_client.get(feishu_tenant_access_token).decode() + res = self.get_tenant_access_token(self.app_id, self.app_secret) + redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) + return res.get("tenant_access_token") + + def _send_request( + self, + url: str, + method: str = "post", + require_token: bool = True, + payload: Optional[dict] = None, + params: Optional[dict] = None, + ): + headers = { + "Content-Type": "application/json", + "user-agent": "Dify", + } + if require_token: + headers["tenant-access-token"] = f"{self.tenant_access_token}" + res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json() + if res.get("code") != 0: + raise Exception(res) + return res + + def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: + """ + API url: https://open.feishu.cn/document/server-docs/authentication-management/access-token/tenant_access_token_internal + Example Response: + { + "code": 0, + "msg": "ok", + "tenant_access_token": "t-caecc734c2e3328a62489fe0648c4b98779515d3", + "expire": 7200 + } + """ + url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" + payload = {"app_id": app_id, "app_secret": app_secret} + res = self._send_request(url, require_token=False, payload=payload) + return res + + def create_document(self, title: str, content: str, folder_token: str) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/create + Example Response: + { + "data": { + "title": "title", + "url": "https://svi136aogf123.feishu.cn/docx/VWbvd4fEdoW0WSxaY1McQTz8n7d", + "type": "docx", + "token": "VWbvd4fEdoW0WSxaY1McQTz8n7d"
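# The tenant_access_token property above caches the token in Redis under
# "tools:{app_id}:feishu_tenant_access_token" for the TTL the API reports in
# "expire" (7200 s in the example response), so repeated tool invocations reuse
# one token per app. A usage sketch; the credentials are placeholders:
#
#     request = FeishuRequest(app_id="cli_xxx", app_secret="***")
#     token = request.tenant_access_token  # first access calls the API and fills the cache
#     token = request.tenant_access_token  # later accesses are served from Redis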
}, + "log_id": "021721281231575fdbddc0200ff00060a9258ec0000103df61b5d", + "code": 0, + "msg": "创建飞书文档成功,请查看" + } + """ + url = f"{self.API_BASE_URL}/document/create_document" + payload = { + "title": title, + "content": content, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def write_document(self, document_id: str, content: str, position: str = "end") -> dict: + url = f"{self.API_BASE_URL}/document/write_document" + payload = {"document_id": document_id, "content": content, "position": position} + res = self._send_request(url, payload=payload) + return res + + def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/raw_content + Example Response: + { + "code": 0, + "msg": "success", + "data": { + "content": "云文档\n多人实时协同,插入一切元素。不仅是在线文档,更是强大的创作和互动工具\n云文档:专为协作而生\n" + } + } + """ # noqa: E501 + params = { + "document_id": document_id, + "mode": mode, + "lang": lang, + } + url = f"{self.API_BASE_URL}/document/get_document_content" + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data").get("content") + return "" + + def list_document_blocks( + self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500 + ) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/docs/docs/docx-v1/document/list + """ + params = { + "user_id_type": user_id_type, + "document_id": document_id, + "page_size": page_size, + "page_token": page_token, + } + url = f"{self.API_BASE_URL}/document/list_document_blocks" + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/im-v1/message/create + """ + url = f"{self.API_BASE_URL}/message/send_bot_message" + params = { + "receive_id_type": receive_id_type, + } + payload = { + "receive_id": receive_id, + "msg_type": msg_type, + "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: + url = f"{self.API_BASE_URL}/message/send_webhook_message" + payload = { + "webhook": webhook, + "msg_type": msg_type, + "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), + } + res = self._send_request(url, require_token=False, payload=payload) + return res + + def get_chat_messages( + self, + container_id: str, + start_time: str, + end_time: str, + page_token: str, + sort_type: str = "ByCreateTimeAsc", + page_size: int = 20, + ) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/im-v1/message/list + """ + url = f"{self.API_BASE_URL}/message/get_chat_messages" + params = { + "container_id": container_id, + "start_time": start_time, + "end_time": end_time, + "sort_type": sort_type, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def get_thread_messages( + self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", 
page_size: int = 20 + ) -> dict: + """ + API url: https://open.larkoffice.com/document/server-docs/im-v1/message/list + """ + url = f"{self.API_BASE_URL}/message/get_thread_messages" + params = { + "container_id": container_id, + "sort_type": sort_type, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: + # create a task + url = f"{self.API_BASE_URL}/task/create_task" + payload = { + "summary": summary, + "start_time": start_time, + "end_time": end_time, + "completed_at": completed_time, + "description": description, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_task( + self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str + ) -> dict: + # update a task + url = f"{self.API_BASE_URL}/task/update_task" + payload = { + "task_guid": task_guid, + "summary": summary, + "start_time": start_time, + "end_time": end_time, + "completed_time": completed_time, + "description": description, + } + res = self._send_request(url, method="PATCH", payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_task(self, task_guid: str) -> dict: + # delete a task + url = f"{self.API_BASE_URL}/task/delete_task" + payload = { + "task_guid": task_guid, + } + res = self._send_request(url, method="DELETE", payload=payload) + return res + + def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: + # add members to a task + url = f"{self.API_BASE_URL}/task/add_members" + payload = { + "task_guid": task_guid, + "member_phone_or_email": member_phone_or_email, + "member_role": member_role, + } + res = self._send_request(url, payload=payload) + return res + + def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: + # list all child nodes of a wiki space + url = f"{self.API_BASE_URL}/wiki/get_wiki_nodes" + payload = { + "space_id": space_id, + "parent_node_token": parent_node_token, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: + url = f"{self.API_BASE_URL}/calendar/get_primary_calendar" + params = { + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_event( + self, + summary: str, + description: str, + start_time: str, + end_time: str, + attendee_ability: str, + need_notification: bool = True, + auto_record: bool = False, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/create_event" + payload = { + "summary": summary, + "description": description, + "need_notification": need_notification, + "start_time": start_time, + "end_time": end_time, + "auto_record": auto_record, + "attendee_ability": attendee_ability, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_event( + self, + event_id: str, + summary: str, + description: str, + need_notification: bool, + start_time: str, + end_time: str, + auto_record: bool, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" + payload = {} + if summary: +
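# The update_event helper around this point builds a sparse PATCH payload: only
# truthy arguments are included, so omitted fields stay untouched on the server.
# One caveat: falsy values such as need_notification=False are skipped as well, so
# they cannot be explicitly unset through this helper. Illustrated with
# hypothetical values:
#
#     payload = {}
#     if "New title":  # truthy -> sent
#         payload["summary"] = "New title"
#     if False:  # falsy -> silently dropped
#         payload["need_notification"] = False
#     # payload == {"summary": "New title"}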
payload["summary"] = summary + if description: + payload["description"] = description + if start_time: + payload["start_time"] = start_time + if end_time: + payload["end_time"] = end_time + if need_notification: + payload["need_notification"] = need_notification + if auto_record: + payload["auto_record"] = auto_record + res = self._send_request(url, method="PATCH", payload=payload) + return res + + def delete_event(self, event_id: str, need_notification: bool = True) -> dict: + url = f"{self.API_BASE_URL}/calendar/delete_event/{event_id}" + params = { + "need_notification": need_notification, + } + res = self._send_request(url, method="DELETE", params=params) + return res + + def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: + url = f"{self.API_BASE_URL}/calendar/list_events" + params = { + "start_time": start_time, + "end_time": end_time, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def search_events( + self, + query: str, + start_time: str, + end_time: str, + page_token: str, + user_id_type: str = "open_id", + page_size: int = 20, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/search_events" + payload = { + "query": query, + "start_time": start_time, + "end_time": end_time, + "page_token": page_token, + "user_id_type": user_id_type, + "page_size": page_size, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: + # 参加日程参会人 + url = f"{self.API_BASE_URL}/calendar/add_event_attendees" + payload = { + "event_id": event_id, + "attendee_phone_or_email": attendee_phone_or_email, + "need_notification": need_notification, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def create_spreadsheet( + self, + title: str, + folder_token: str, + ) -> dict: + # 创建电子表格 + url = f"{self.API_BASE_URL}/spreadsheet/create_spreadsheet" + payload = { + "title": title, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_spreadsheet( + self, + spreadsheet_token: str, + user_id_type: str = "open_id", + ) -> dict: + # 获取电子表格信息 + url = f"{self.API_BASE_URL}/spreadsheet/get_spreadsheet" + params = { + "spreadsheet_token": spreadsheet_token, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def list_spreadsheet_sheets( + self, + spreadsheet_token: str, + ) -> dict: + # 列出电子表格的所有工作表 + url = f"{self.API_BASE_URL}/spreadsheet/list_spreadsheet_sheets" + params = { + "spreadsheet_token": spreadsheet_token, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def add_rows( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + length: int, + values: str, + ) -> dict: + # 增加行,在工作表最后添加 + url = f"{self.API_BASE_URL}/spreadsheet/add_rows" + payload = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "length": length, + "values": values, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def 
add_cols( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + length: int, + values: str, + ) -> dict: + # add columns, appended to the end of the sheet + url = f"{self.API_BASE_URL}/spreadsheet/add_cols" + payload = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "length": length, + "values": values, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def read_rows( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + start_row: int, + num_rows: int, + user_id_type: str = "open_id", + ) -> dict: + # read row data from the sheet + url = f"{self.API_BASE_URL}/spreadsheet/read_rows" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "start_row": start_row, + "num_rows": num_rows, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def read_cols( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + start_col: int, + num_cols: int, + user_id_type: str = "open_id", + ) -> dict: + # read column data from the sheet + url = f"{self.API_BASE_URL}/spreadsheet/read_cols" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "start_col": start_col, + "num_cols": num_cols, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def read_table( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + num_range: str, + query: str, + user_id_type: str = "open_id", + ) -> dict: + # read a custom range of rows and columns + url = f"{self.API_BASE_URL}/spreadsheet/read_table" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "range": num_range, + "query": query, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_base( + self, + name: str, + folder_token: str, + ) -> dict: + # create a base (multi-dimensional table) + url = f"{self.API_BASE_URL}/base/create_base" + payload = { + "name": name, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_records( + self, + app_token: str, + table_id: str, + table_name: str, + records: str, + user_id_type: str = "open_id", + ) -> dict: + # add multiple records + url = f"{self.API_BASE_URL}/base/add_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + "user_id_type": user_id_type, + } + payload = { + "records": convert_add_records(records), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_records( + self, + app_token: str, + table_id: str, + table_name: str, + records: str, + user_id_type: str, + ) -> dict: + # update multiple records + url = f"{self.API_BASE_URL}/base/update_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + "user_id_type": user_id_type, + } + payload = { + "records": convert_update_records(records), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_records( + self, + app_token: str, + table_id: str, + table_name: str, + record_ids: str, + ) -> dict: + # delete multiple records + url =
f"{self.API_BASE_URL}/base/delete_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + } + if not record_ids: + record_id_list = [] + else: + try: + record_id_list = json.loads(record_ids) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + payload = { + "records": record_id_list, + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def search_record( + self, + app_token: str, + table_id: str, + table_name: str, + view_id: str, + field_names: str, + sort: str, + filters: str, + page_token: str, + automatic_fields: bool = False, + user_id_type: str = "open_id", + page_size: int = 20, + ) -> dict: + # 查询记录,单次最多查询 500 行记录。 + url = f"{self.API_BASE_URL}/base/search_record" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + "user_id_type": user_id_type, + "page_token": page_token, + "page_size": page_size, + } + + if not field_names: + field_name_list = [] + else: + try: + field_name_list = json.loads(field_names) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + if not sort: + sort_list = [] + else: + try: + sort_list = json.loads(sort) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + if not filters: + filter_dict = {} + else: + try: + filter_dict = json.loads(filters) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + payload = {} + + if view_id: + payload["view_id"] = view_id + if field_names: + payload["field_names"] = field_name_list + if sort: + payload["sort"] = sort_list + if filters: + payload["filter"] = filter_dict + if automatic_fields: + payload["automatic_fields"] = automatic_fields + res = self._send_request(url, params=params, payload=payload) + + if "data" in res: + return res.get("data") + return res + + def get_base_info( + self, + app_token: str, + ) -> dict: + # 获取多维表格元数据 + url = f"{self.API_BASE_URL}/base/get_base_info" + params = { + "app_token": app_token, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_table( + self, + app_token: str, + table_name: str, + default_view_name: str, + fields: str, + ) -> dict: + # 新增一个数据表 + url = f"{self.API_BASE_URL}/base/create_table" + params = { + "app_token": app_token, + } + if not fields: + fields_list = [] + else: + try: + fields_list = json.loads(fields) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + payload = { + "name": table_name, + "fields": fields_list, + } + if default_view_name: + payload["default_view_name"] = default_view_name + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_tables( + self, + app_token: str, + table_ids: str, + table_names: str, + ) -> dict: + # 删除多个数据表 + url = f"{self.API_BASE_URL}/base/delete_tables" + params = { + "app_token": app_token, + } + if not table_ids: + table_id_list = [] + else: + try: + table_id_list = json.loads(table_ids) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + if not table_names: + table_name_list = [] + else: + try: + table_name_list = json.loads(table_names) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + payload = { + "table_ids": table_id_list, + 
"table_names": table_name_list, + } + + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def list_tables( + self, + app_token: str, + page_token: str, + page_size: int = 20, + ) -> dict: + # 列出多维表格下的全部数据表 + url = f"{self.API_BASE_URL}/base/list_tables" + params = { + "app_token": app_token, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def read_records( + self, + app_token: str, + table_id: str, + table_name: str, + record_ids: str, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/base/read_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + } + if not record_ids: + record_id_list = [] + else: + try: + record_id_list = json.loads(record_ids) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + payload = { + "record_ids": record_id_list, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params, payload=payload) + if "data" in res: + return res.get("data") + return res diff --git a/api/core/tools/utils/lark_api_utils.py b/api/core/tools/utils/lark_api_utils.py new file mode 100644 index 00000000000000..30cb0cb141d9a6 --- /dev/null +++ b/api/core/tools/utils/lark_api_utils.py @@ -0,0 +1,820 @@ +import json +from typing import Optional + +import httpx + +from core.tools.errors import ToolProviderCredentialValidationError +from extensions.ext_redis import redis_client + + +def lark_auth(credentials): + app_id = credentials.get("app_id") + app_secret = credentials.get("app_secret") + if not app_id or not app_secret: + raise ToolProviderCredentialValidationError("app_id and app_secret is required") + try: + assert LarkRequest(app_id, app_secret).tenant_access_token is not None + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) + + +class LarkRequest: + API_BASE_URL = "https://lark-plugin-api.solutionsuite.ai/lark-plugin" + + def __init__(self, app_id: str, app_secret: str): + self.app_id = app_id + self.app_secret = app_secret + + def convert_add_records(self, json_str): + try: + data = json.loads(json_str) + if not isinstance(data, list): + raise ValueError("Parsed data must be a list") + converted_data = [{"fields": json.dumps(item, ensure_ascii=False)} for item in data] + return converted_data + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + except Exception as e: + raise ValueError(f"An error occurred while processing the data: {e}") + + def convert_update_records(self, json_str): + try: + data = json.loads(json_str) + if not isinstance(data, list): + raise ValueError("Parsed data must be a list") + + converted_data = [ + {"fields": json.dumps(record["fields"], ensure_ascii=False), "record_id": record["record_id"]} + for record in data + if "fields" in record and "record_id" in record + ] + + if len(converted_data) != len(data): + raise ValueError("Each record must contain 'fields' and 'record_id'") + + return converted_data + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + except Exception as e: + raise ValueError(f"An error occurred while processing the data: {e}") + + @property + def tenant_access_token(self) -> str: + feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token" + if 
redis_client.exists(feishu_tenant_access_token): + return redis_client.get(feishu_tenant_access_token).decode() + res = self.get_tenant_access_token(self.app_id, self.app_secret) + if "tenant_access_token" in res: + redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token")) + return res.get("tenant_access_token") + return "" + + def _send_request( + self, + url: str, + method: str = "post", + require_token: bool = True, + payload: Optional[dict] = None, + params: Optional[dict] = None, + ): + headers = { + "Content-Type": "application/json", + "user-agent": "Dify", + } + if require_token: + headers["tenant-access-token"] = f"{self.tenant_access_token}" + res = httpx.request(method=method, url=url, headers=headers, json=payload, params=params, timeout=30).json() + if res.get("code") != 0: + raise Exception(res) + return res + + def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict: + url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token" + payload = {"app_id": app_id, "app_secret": app_secret} + res = self._send_request(url, require_token=False, payload=payload) + return res + + def create_document(self, title: str, content: str, folder_token: str) -> dict: + url = f"{self.API_BASE_URL}/document/create_document" + payload = { + "title": title, + "content": content, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def write_document(self, document_id: str, content: str, position: str = "end") -> dict: + url = f"{self.API_BASE_URL}/document/write_document" + payload = {"document_id": document_id, "content": content, "position": position} + res = self._send_request(url, payload=payload) + return res + + def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict: + params = { + "document_id": document_id, + "mode": mode, + "lang": lang, + } + url = f"{self.API_BASE_URL}/document/get_document_content" + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data").get("content") + return "" + + def list_document_blocks( + self, document_id: str, page_token: str, user_id_type: str = "open_id", page_size: int = 500 + ) -> dict: + params = { + "user_id_type": user_id_type, + "document_id": document_id, + "page_size": page_size, + "page_token": page_token, + } + url = f"{self.API_BASE_URL}/document/list_document_blocks" + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict: + url = f"{self.API_BASE_URL}/message/send_bot_message" + params = { + "receive_id_type": receive_id_type, + } + payload = { + "receive_id": receive_id, + "msg_type": msg_type, + "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict: + url = f"{self.API_BASE_URL}/message/send_webhook_message" + payload = { + "webhook": webhook, + "msg_type": msg_type, + "content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"), + } + res = self._send_request(url, require_token=False, payload=payload) + return res + + def get_chat_messages( + self, + container_id: str, + start_time: str, +
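# Every wrapper in both clients funnels through _send_request above, which returns
# the parsed JSON body and raises when the gateway reports a non-zero "code". The
# resulting calling convention, sketched with a placeholder document id:
#
#     try:
#         res = request._send_request(url, method="GET", params={"document_id": "doc_xxx"})
#         data = res.get("data", res)  # most wrappers unwrap "data" when present
#     except Exception as e:
#         ...  # e wraps the full error response dict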
end_time: str, + page_token: str, + sort_type: str = "ByCreateTimeAsc", + page_size: int = 20, + ) -> dict: + url = f"{self.API_BASE_URL}/message/get_chat_messages" + params = { + "container_id": container_id, + "start_time": start_time, + "end_time": end_time, + "sort_type": sort_type, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def get_thread_messages( + self, container_id: str, page_token: str, sort_type: str = "ByCreateTimeAsc", page_size: int = 20 + ) -> dict: + url = f"{self.API_BASE_URL}/message/get_thread_messages" + params = { + "container_id": container_id, + "sort_type": sort_type, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict: + url = f"{self.API_BASE_URL}/task/create_task" + payload = { + "summary": summary, + "start_time": start_time, + "end_time": end_time, + "completed_at": completed_time, + "description": description, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_task( + self, task_guid: str, summary: str, start_time: str, end_time: str, completed_time: str, description: str + ) -> dict: + url = f"{self.API_BASE_URL}/task/update_task" + payload = { + "task_guid": task_guid, + "summary": summary, + "start_time": start_time, + "end_time": end_time, + "completed_time": completed_time, + "description": description, + } + res = self._send_request(url, method="PATCH", payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_task(self, task_guid: str) -> dict: + url = f"{self.API_BASE_URL}/task/delete_task" + payload = { + "task_guid": task_guid, + } + res = self._send_request(url, method="DELETE", payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict: + url = f"{self.API_BASE_URL}/task/add_members" + payload = { + "task_guid": task_guid, + "member_phone_or_email": member_phone_or_email, + "member_role": member_role, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict: + url = f"{self.API_BASE_URL}/wiki/get_wiki_nodes" + payload = { + "space_id": space_id, + "parent_node_token": parent_node_token, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_primary_calendar(self, user_id_type: str = "open_id") -> dict: + url = f"{self.API_BASE_URL}/calendar/get_primary_calendar" + params = { + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_event( + self, + summary: str, + description: str, + start_time: str, + end_time: str, + attendee_ability: str, + need_notification: bool = True, + auto_record: bool = False, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/create_event" + payload = { + "summary": summary, + "description": description, + "need_notification": 
need_notification, + "start_time": start_time, + "end_time": end_time, + "auto_record": auto_record, + "attendee_ability": attendee_ability, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_event( + self, + event_id: str, + summary: str, + description: str, + need_notification: bool, + start_time: str, + end_time: str, + auto_record: bool, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}" + payload = {} + if summary: + payload["summary"] = summary + if description: + payload["description"] = description + if start_time: + payload["start_time"] = start_time + if end_time: + payload["end_time"] = end_time + if need_notification: + payload["need_notification"] = need_notification + if auto_record: + payload["auto_record"] = auto_record + res = self._send_request(url, method="PATCH", payload=payload) + return res + + def delete_event(self, event_id: str, need_notification: bool = True) -> dict: + url = f"{self.API_BASE_URL}/calendar/delete_event/{event_id}" + params = { + "need_notification": need_notification, + } + res = self._send_request(url, method="DELETE", params=params) + return res + + def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict: + url = f"{self.API_BASE_URL}/calendar/list_events" + params = { + "start_time": start_time, + "end_time": end_time, + "page_token": page_token, + "page_size": page_size, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def search_events( + self, + query: str, + start_time: str, + end_time: str, + page_token: str, + user_id_type: str = "open_id", + page_size: int = 20, + ) -> dict: + url = f"{self.API_BASE_URL}/calendar/search_events" + payload = { + "query": query, + "start_time": start_time, + "end_time": end_time, + "page_token": page_token, + "user_id_type": user_id_type, + "page_size": page_size, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict: + url = f"{self.API_BASE_URL}/calendar/add_event_attendees" + payload = { + "event_id": event_id, + "attendee_phone_or_email": attendee_phone_or_email, + "need_notification": need_notification, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def create_spreadsheet( + self, + title: str, + folder_token: str, + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/create_spreadsheet" + payload = { + "title": title, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_spreadsheet( + self, + spreadsheet_token: str, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/get_spreadsheet" + params = { + "spreadsheet_token": spreadsheet_token, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def list_spreadsheet_sheets( + self, + spreadsheet_token: str, + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/list_spreadsheet_sheets" + params = { + "spreadsheet_token": spreadsheet_token, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return 
res + + def add_rows( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + length: int, + values: str, + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/add_rows" + payload = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "length": length, + "values": values, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_cols( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + length: int, + values: str, + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/add_cols" + payload = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "length": length, + "values": values, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def read_rows( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + start_row: int, + num_rows: int, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/read_rows" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "start_row": start_row, + "num_rows": num_rows, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def read_cols( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + start_col: int, + num_cols: int, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/read_cols" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "start_col": start_col, + "num_cols": num_cols, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def read_table( + self, + spreadsheet_token: str, + sheet_id: str, + sheet_name: str, + num_range: str, + query: str, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/spreadsheet/read_table" + params = { + "spreadsheet_token": spreadsheet_token, + "sheet_id": sheet_id, + "sheet_name": sheet_name, + "range": num_range, + "query": query, + "user_id_type": user_id_type, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_base( + self, + name: str, + folder_token: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/create_base" + payload = { + "name": name, + "folder_token": folder_token, + } + res = self._send_request(url, payload=payload) + if "data" in res: + return res.get("data") + return res + + def add_records( + self, + app_token: str, + table_id: str, + table_name: str, + records: str, + user_id_type: str = "open_id", + ) -> dict: + url = f"{self.API_BASE_URL}/base/add_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + "user_id_type": user_id_type, + } + payload = { + "records": self.convert_add_records(records), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def update_records( + self, + app_token: str, + table_id: str, + table_name: str, + records: str, + user_id_type: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/update_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": 
table_name, + "user_id_type": user_id_type, + } + payload = { + "records": self.convert_update_records(records), + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_records( + self, + app_token: str, + table_id: str, + table_name: str, + record_ids: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/delete_records" + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + } + if not record_ids: + record_id_list = [] + else: + try: + record_id_list = json.loads(record_ids) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + payload = { + "records": record_id_list, + } + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def search_record( + self, + app_token: str, + table_id: str, + table_name: str, + view_id: str, + field_names: str, + sort: str, + filters: str, + page_token: str, + automatic_fields: bool = False, + user_id_type: str = "open_id", + page_size: int = 20, + ) -> dict: + url = f"{self.API_BASE_URL}/base/search_record" + + params = { + "app_token": app_token, + "table_id": table_id, + "table_name": table_name, + "user_id_type": user_id_type, + "page_token": page_token, + "page_size": page_size, + } + + if not field_names: + field_name_list = [] + else: + try: + field_name_list = json.loads(field_names) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + if not sort: + sort_list = [] + else: + try: + sort_list = json.loads(sort) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + if not filters: + filter_dict = {} + else: + try: + filter_dict = json.loads(filters) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + + payload = {} + + if view_id: + payload["view_id"] = view_id + if field_names: + payload["field_names"] = field_name_list + if sort: + payload["sort"] = sort_list + if filters: + payload["filter"] = filter_dict + if automatic_fields: + payload["automatic_fields"] = automatic_fields + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def get_base_info( + self, + app_token: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/get_base_info" + params = { + "app_token": app_token, + } + res = self._send_request(url, method="GET", params=params) + if "data" in res: + return res.get("data") + return res + + def create_table( + self, + app_token: str, + table_name: str, + default_view_name: str, + fields: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/create_table" + params = { + "app_token": app_token, + } + if not fields: + fields_list = [] + else: + try: + fields_list = json.loads(fields) + except json.JSONDecodeError: + raise ValueError("The input string is not valid JSON") + payload = { + "name": table_name, + "fields": fields_list, + } + if default_view_name: + payload["default_view_name"] = default_view_name + res = self._send_request(url, params=params, payload=payload) + if "data" in res: + return res.get("data") + return res + + def delete_tables( + self, + app_token: str, + table_ids: str, + table_names: str, + ) -> dict: + url = f"{self.API_BASE_URL}/base/delete_tables" + params = { + "app_token": app_token, + } + if not table_ids: + table_id_list = [] + else: + try: + table_id_list = json.loads(table_ids) + except 
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index 564b9d3e14c15e..1812d245712189 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -1,109 +1,125 @@
 import logging
 from mimetypes import guess_extension
+from typing import Optional
 
-from core.file.file_obj import FileTransferMethod, FileType
+from core.file import File, FileTransferMethod, FileType
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool_file_manager import ToolFileManager
 
 logger = logging.getLogger(__name__)
 
+
 class ToolFileMessageTransformer:
     @classmethod
-    def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
-                                       user_id: str,
-                                       tenant_id: str,
-                                       conversation_id: str) -> list[ToolInvokeMessage]:
+    def transform_tool_invoke_messages(
+        cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None
+    ) -> list[ToolInvokeMessage]:
         """
        Transform tool message and handle file download
        """
        result = []

        for message in messages:
-            if message.type == ToolInvokeMessage.MessageType.TEXT:
-                result.append(message)
-            elif message.type == ToolInvokeMessage.MessageType.LINK:
+            if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
                 result.append(message)
-            elif message.type == ToolInvokeMessage.MessageType.IMAGE:
+            elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str):
                 # try to download image
                 try:
                     file = ToolFileManager.create_file_by_url(
-                        user_id=user_id,
-                        tenant_id=tenant_id,
-                        conversation_id=conversation_id,
-                        file_url=message.message
+                        user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message
                     )

                     url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'

-                    result.append(ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=url,
-                        save_as=message.save_as,
-                        meta=message.meta.copy() if message.meta is not None else {},
-                    ))
+                    result.append(
+                        ToolInvokeMessage(
+                            type=ToolInvokeMessage.MessageType.IMAGE_LINK,
+                            message=url,
+                            save_as=message.save_as,
+                            meta=message.meta.copy() if message.meta is not None else {},
+                        )
+                    )
                 except Exception as e:
                     logger.exception(e)
-                    result.append(ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.TEXT,
-                        message=f"Failed to download image: {message.message}, you can try to download it yourself.",
-                        meta=message.meta.copy() if message.meta is not None else {},
-                        save_as=message.save_as,
-                    ))
+                    result.append(
+                        ToolInvokeMessage(
+                            type=ToolInvokeMessage.MessageType.TEXT,
+                            message=f"Failed to download image: {message.message}, please try to download it manually.",
+                            meta=message.meta.copy() if message.meta is not None else {},
+                            save_as=message.save_as,
+                        )
+                    )
             elif message.type == ToolInvokeMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
-                mimetype = message.meta.get('mime_type', 'octet/stream')
+                assert message.meta is not None
+                mimetype = message.meta.get("mime_type", "octet/stream")
                 # if message is str, encode it to bytes
                 if isinstance(message.message, str):
-                    message.message = message.message.encode('utf-8')
+                    message.message = message.message.encode("utf-8")
+                # FIXME: should do a type check here.
+                assert isinstance(message.message, bytes)

                 file = ToolFileManager.create_file_by_raw(
-                    user_id=user_id, tenant_id=tenant_id,
+                    user_id=user_id,
+                    tenant_id=tenant_id,
                     conversation_id=conversation_id,
                     file_binary=message.message,
-                    mimetype=mimetype
+                    mimetype=mimetype,
                 )

-                url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
+                url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))

                 # check if file is image
-                if 'image' in mimetype:
-                    result.append(ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=url,
-                        save_as=message.save_as,
-                        meta=message.meta.copy() if message.meta is not None else {},
-                    ))
+                if "image" in mimetype:
+                    result.append(
+                        ToolInvokeMessage(
+                            type=ToolInvokeMessage.MessageType.IMAGE_LINK,
+                            message=url,
+                            save_as=message.save_as,
+                            meta=message.meta.copy() if message.meta is not None else {},
+                        )
+                    )
                 else:
-                    result.append(ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.LINK,
-                        message=url,
-                        save_as=message.save_as,
-                        meta=message.meta.copy() if message.meta is not None else {},
-                    ))
-            elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
-                file_var = message.meta.get('file_var')
-                if file_var:
-                    if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
-                        url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
-                        if file_var.type == FileType.IMAGE:
-                            result.append(ToolInvokeMessage(
-                                type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                                message=url,
-                                save_as=message.save_as,
-                                meta=message.meta.copy() if message.meta is not None else {},
-                            ))
+                    result.append(
+                        ToolInvokeMessage(
+                            type=ToolInvokeMessage.MessageType.LINK,
+                            message=url,
+                            save_as=message.save_as,
+                            meta=message.meta.copy() if message.meta is not None else {},
+                        )
+                    )
+            elif message.type == ToolInvokeMessage.MessageType.FILE:
+                assert message.meta is not None
+                file = message.meta.get("file")
+                if isinstance(file, File):
+                    if file.transfer_method == FileTransferMethod.TOOL_FILE:
+                        assert file.related_id is not None
+                        url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
+                        if file.type == FileType.IMAGE:
+                            result.append(
+                                ToolInvokeMessage(
+                                    type=ToolInvokeMessage.MessageType.IMAGE_LINK,
+                                    message=url,
+                                    save_as=message.save_as,
+                                    meta=message.meta.copy() if message.meta is not None else {},
+                                )
+                            )
                         else:
-                            result.append(ToolInvokeMessage(
-                                type=ToolInvokeMessage.MessageType.LINK,
-                                message=url,
-                                save_as=message.save_as,
-                                meta=message.meta.copy() if message.meta is not None else {},
-                            ))
+                            result.append(
+                                ToolInvokeMessage(
+                                    type=ToolInvokeMessage.MessageType.LINK,
+                                    message=url,
+                                    save_as=message.save_as,
+                                    meta=message.meta.copy() if message.meta is not None else {},
+                                )
+                            )
+                    else:
+                        result.append(message)
                 else:
                     result.append(message)
             else:
                 result.append(message)

        return result

    @classmethod
-    def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
+    def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
        return f'/files/tools/{tool_file_id}{extension or ".bin"}'
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 9e8ef478237d6e..4e226810d6ac90 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -1,7 +1,7 @@
 """
-    For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
+For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
 
-    Therefore, a model manager is needed to list/invoke/validate models.
+Therefore, a model manager is needed to list/invoke/validate models.
 """
 
 import json
@@ -27,52 +27,49 @@ class InvokeModelError(Exception):
     pass
 
+
 class ModelInvocationUtils:
     @staticmethod
     def get_max_llm_context_tokens(
         tenant_id: str,
     ) -> int:
         """
-            get max llm context tokens of the model
+        get max llm context tokens of the model
         """
         model_manager = ModelManager()
         model_instance = model_manager.get_default_model_instance(
-            tenant_id=tenant_id, model_type=ModelType.LLM,
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
         )

         if not model_instance:
-            raise InvokeModelError('Model not found')
-
+            raise InvokeModelError("Model not found")
+
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
         schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)

         if not schema:
-            raise InvokeModelError('No model schema found')
+            raise InvokeModelError("No model schema found")

         max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
         if max_tokens is None:
             return 2048
-
+
         return max_tokens

     @staticmethod
-    def calculate_tokens(
-        tenant_id: str,
-        prompt_messages: list[PromptMessage]
-    ) -> int:
+    def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
         """
-            calculate tokens from prompt messages and model parameters
+        calculate tokens from prompt messages and model parameters
         """

         # get model instance
         model_manager = ModelManager()
-        model_instance = model_manager.get_default_model_instance(
-            tenant_id=tenant_id, model_type=ModelType.LLM
-        )
+        model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)

         if not model_instance:
-            raise InvokeModelError('Model not found')
-
+            raise InvokeModelError("Model not found")
+
         # get tokens
         tokens = model_instance.get_llm_num_tokens(prompt_messages)

@@ -80,9 +77,7 @@ def calculate_tokens(
 
     @staticmethod
     def invoke(
-        user_id: str, tenant_id: str,
-        tool_type: str, tool_name: str,
-        prompt_messages: list[PromptMessage]
+        user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
     ) -> LLMResult:
         """
         invoke model with parameters in user's own context
@@ -103,15 +98,16 @@ def invoke(
         model_manager = ModelManager()
         # get model instance
         model_instance = model_manager.get_default_model_instance(
-            tenant_id=tenant_id, model_type=ModelType.LLM,
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
         )

         # get prompt tokens
         prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

         model_parameters = {
-            'temperature': 0.8,
-            'top_p': 0.8,
+            "temperature": 0.8,
+            "top_p": 0.8,
         }

         # create tool model invoke
@@ -123,14 +119,14 @@ def invoke(
             tool_name=tool_name,
             model_parameters=json.dumps(model_parameters),
             prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
-            model_response='',
+            model_response="",
             prompt_tokens=prompt_tokens,
             answer_tokens=0,
             answer_unit_price=0,
             answer_price_unit=0,
             provider_response_latency=0,
             total_price=0,
-            currency='USD',
+            currency="USD",
         )

         db.session.add(tool_model_invoke)
@@ -140,20 +136,24 @@ def invoke(
             response: LLMResult = model_instance.invoke_llm(
                 prompt_messages=prompt_messages,
                 model_parameters=model_parameters,
-                tools=[], stop=[], stream=False, user=user_id, callbacks=[]
+                tools=[],
+                stop=[],
+                stream=False,
+                user=user_id,
+                callbacks=[],
             )
         except InvokeRateLimitError as e:
-            raise InvokeModelError(f'Invoke rate limit error: {e}')
+            raise InvokeModelError(f"Invoke rate limit error: {e}")
         except InvokeBadRequestError as e:
-            raise InvokeModelError(f'Invoke bad request error: {e}')
+            raise InvokeModelError(f"Invoke bad request error: {e}")
         except InvokeConnectionError as e:
-            raise InvokeModelError(f'Invoke connection error: {e}')
+            raise InvokeModelError(f"Invoke connection error: {e}")
         except InvokeAuthorizationError as e:
-            raise InvokeModelError('Invoke authorization error')
+            raise InvokeModelError("Invoke authorization error")
         except InvokeServerUnavailableError as e:
-            raise InvokeModelError(f'Invoke server unavailable error: {e}')
+            raise InvokeModelError(f"Invoke server unavailable error: {e}")
         except Exception as e:
-            raise InvokeModelError(f'Invoke error: {e}')
+            raise InvokeModelError(f"Invoke error: {e}")

         # update tool model invoke
         tool_model_invoke.model_response = response.message.content
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index f711f7c9f3c2e8..5867a11bb39da1 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -1,9 +1,9 @@
-
 import re
 import uuid
 from json import dumps as json_dumps
 from json import loads as json_loads
 from json.decoder import JSONDecodeError
+from typing import Optional
 
 from requests import get
 from yaml import YAMLError, safe_load
@@ -16,54 +16,56 @@
 class ApiBasedToolSchemaParser:
     @staticmethod
-    def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
+    def parse_openapi_to_tool_bundle(
+        openapi: dict, extra_info: Optional[dict], warning: Optional[dict]
+    ) -> list[ApiToolBundle]:
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}

         # set description to extra_info
-        extra_info['description'] = openapi['info'].get('description', '')
+        extra_info["description"] = openapi["info"].get("description", "")

-        if len(openapi['servers']) == 0:
-            raise ToolProviderNotFoundError('No server found in the openapi yaml.')
+        if len(openapi["servers"]) == 0:
+            raise ToolProviderNotFoundError("No server found in the openapi yaml.")

-        server_url = openapi['servers'][0]['url']
+        server_url = openapi["servers"][0]["url"]

         # list all interfaces
         interfaces = []
-        for path, path_item in openapi['paths'].items():
-            methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace']
+        for path, path_item in openapi["paths"].items():
+            methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
             for method in methods:
                 if method in path_item:
-                    interfaces.append({
-                        'path': path,
-                        'method': method,
-                        'operation': path_item[method],
-                    })
+                    interfaces.append(
+                        {
+                            "path": path,
+                            "method": method,
+                            "operation": path_item[method],
+                        }
+                    )

         # get all parameters
         bundles = []
         for interface in interfaces:
             # convert parameters
             parameters = []
-            if 'parameters' in interface['operation']:
-                for parameter in interface['operation']['parameters']:
+            if "parameters" in interface["operation"]:
+                for parameter in interface["operation"]["parameters"]:
                     tool_parameter = ToolParameter(
-                        name=parameter['name'],
-                        label=I18nObject(
-                            en_US=parameter['name'],
-                            zh_Hans=parameter['name']
-                        ),
+                        name=parameter["name"],
+                        label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
                         human_description=I18nObject(
-                            en_US=parameter.get('description', ''),
-                            zh_Hans=parameter.get('description', '')
+                            en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
                         ),
                         type=ToolParameter.ToolParameterType.STRING,
-                        required=parameter.get('required', False),
+                        required=parameter.get("required", False),
                         form=ToolParameter.ToolParameterForm.LLM,
-                        llm_description=parameter.get('description'),
-                        default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None,
+                        llm_description=parameter.get("description"),
+                        default=parameter["schema"]["default"]
+                        if "schema" in parameter and "default" in parameter["schema"]
+                        else None,
                     )
-
+
                     # check if there is a type
                     typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
                     if typ:
@@ -72,44 +74,40 @@ def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning
                     parameters.append(tool_parameter)
             # create tool bundle
             # check if there is a request body
-            if 'requestBody' in interface['operation']:
-                request_body = interface['operation']['requestBody']
-                if 'content' in request_body:
-                    for content_type, content in request_body['content'].items():
+            if "requestBody" in interface["operation"]:
+                request_body = interface["operation"]["requestBody"]
+                if "content" in request_body:
+                    for content_type, content in request_body["content"].items():
                         # if there is a reference, get the reference and overwrite the content
-                        if 'schema' not in content:
+                        if "schema" not in content:
                             continue

-                        if '$ref' in content['schema']:
+                        if "$ref" in content["schema"]:
                             # get the reference
                             root = openapi
-                            reference = content['schema']['$ref'].split('/')[1:]
+                            reference = content["schema"]["$ref"].split("/")[1:]
                             for ref in reference:
                                 root = root[ref]
                             # overwrite the content
-                            interface['operation']['requestBody']['content'][content_type]['schema'] = root
+                            interface["operation"]["requestBody"]["content"][content_type]["schema"] = root

                         # parse body parameters
-                        if 'schema' in interface['operation']['requestBody']['content'][content_type]:
-                            body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
-                            required = body_schema.get('required', [])
-                            properties = body_schema.get('properties', {})
+                        if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
+                            body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
+                            required = body_schema.get("required", [])
+                            properties = body_schema.get("properties", {})
                             for name, property in properties.items():
                                 tool = ToolParameter(
                                     name=name,
-                                    label=I18nObject(
-                                        en_US=name,
-                                        zh_Hans=name
-                                    ),
+                                    label=I18nObject(en_US=name, zh_Hans=name),
                                     human_description=I18nObject(
-                                        en_US=property.get('description', ''),
-                                        zh_Hans=property.get('description', '')
+                                        en_US=property.get("description", ""), zh_Hans=property.get("description", "")
                                     ),
                                     type=ToolParameter.ToolParameterType.STRING,
                                     required=name in required,
                                     form=ToolParameter.ToolParameterForm.LLM,
-                                    llm_description=property.get('description', ''),
-                                    default=property.get('default', None),
+                                    llm_description=property.get("description", ""),
+                                    default=property.get("default", None),
                                 )

                                 # check if there is a type
@@ -127,172 +125,176 @@ def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning
                 parameters_count[parameter.name] += 1
             for name, count in parameters_count.items():
                 if count > 1:
-                    warning['duplicated_parameter'] = f'Parameter {name} is duplicated.'
+                    warning["duplicated_parameter"] = f"Parameter {name} is duplicated."

             # check if there is a operation id, use $path_$method as operation id if not
-            if 'operationId' not in interface['operation']:
+            if "operationId" not in interface["operation"]:
                 # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
-                path = interface['path']
-                if interface['path'].startswith('/'):
-                    path = interface['path'][1:]
+                path = interface["path"]
+                if interface["path"].startswith("/"):
+                    path = interface["path"][1:]
                 # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
-                path = re.sub(r'[^a-zA-Z0-9_-]', '', path)
+                path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
                 if not path:
                     path = str(uuid.uuid4())
-
-                interface['operation']['operationId'] = f'{path}_{interface["method"]}'
-
-            bundles.append(ApiToolBundle(
-                server_url=server_url + interface['path'],
-                method=interface['method'],
-                summary=interface['operation']['description'] if 'description' in interface['operation'] else
-                interface['operation'].get('summary', None),
-                operation_id=interface['operation']['operationId'],
-                parameters=parameters,
-                author='',
-                icon=None,
-                openapi=interface['operation'],
-            ))
+
+                interface["operation"]["operationId"] = f'{path}_{interface["method"]}'
+
+            bundles.append(
+                ApiToolBundle(
+                    server_url=server_url + interface["path"],
+                    method=interface["method"],
+                    summary=interface["operation"]["description"]
+                    if "description" in interface["operation"]
+                    else interface["operation"].get("summary", None),
+                    operation_id=interface["operation"]["operationId"],
+                    parameters=parameters,
+                    author="",
+                    icon=None,
+                    openapi=interface["operation"],
+                )
+            )

         return bundles
-
+
     @staticmethod
     def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
         parameter = parameter or {}
         typ = None
-        if 'type' in parameter:
-            typ = parameter['type']
-        elif 'schema' in parameter and 'type' in parameter['schema']:
-            typ = parameter['schema']['type']
-
-        if typ == 'integer' or typ == 'number':
+        if "type" in parameter:
+            typ = parameter["type"]
+        elif "schema" in parameter and "type" in parameter["schema"]:
+            typ = parameter["schema"]["type"]
+
+        if typ in {"integer", "number"}:
             return ToolParameter.ToolParameterType.NUMBER
-        elif typ == 'boolean':
+        elif typ == "boolean":
             return ToolParameter.ToolParameterType.BOOLEAN
-        elif typ == 'string':
+        elif typ == "string":
             return ToolParameter.ToolParameterType.STRING

     @staticmethod
-    def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
+    def parse_openapi_yaml_to_tool_bundle(
+        yaml: str, extra_info: Optional[dict], warning: Optional[dict]
+    ) -> list[ApiToolBundle]:
         """
-            parse openapi yaml to tool bundle
+        parse openapi yaml to tool bundle

-            :param yaml: the yaml string
-            :return: the tool bundle
+        :param yaml: the yaml string
+        :return: the tool bundle
         """
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}

         openapi: dict = safe_load(yaml)
         if openapi is None:
-            raise ToolApiSchemaError('Invalid openapi yaml.')
+            raise ToolApiSchemaError("Invalid openapi yaml.")
         return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
-
+
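(For context, a minimal sketch of driving the YAML entry point above — hypothetical schema, assumes the api package is importable; not part of this diff:

# usage sketch (hypothetical schema, not part of this PR)
from core.tools.utils.parser import ApiBasedToolSchemaParser

MINIMAL_OPENAPI = """
openapi: 3.0.0
info: {title: Demo, description: demo, version: 1.0.0}
servers: [{url: "https://api.example.com"}]
paths:
  /ping:
    get:
      operationId: ping
      summary: health check
      responses: {"200": {description: ok}}
"""

bundles = ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(MINIMAL_OPENAPI, extra_info={}, warning={})
print([b.operation_id for b in bundles])  # expected: ['ping']
)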
     @staticmethod
-    def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
+    def parse_swagger_to_openapi(swagger: dict, extra_info: Optional[dict], warning: Optional[dict]) -> dict:
         """
-            parse swagger to openapi
+        parse swagger to openapi
 
-            :param swagger: the swagger dict
-            :return: the openapi dict
+        :param swagger: the swagger dict
+        :return: the openapi dict
         """
         # convert swagger to openapi
-        info = swagger.get('info', {
-            'title': 'Swagger',
-            'description': 'Swagger',
-            'version': '1.0.0'
-        })
+        info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})

-        servers = swagger.get('servers', [])
+        servers = swagger.get("servers", [])

         if len(servers) == 0:
-            raise ToolApiSchemaError('No server found in the swagger yaml.')
+            raise ToolApiSchemaError("No server found in the swagger yaml.")

         openapi = {
-            'openapi': '3.0.0',
-            'info': {
-                'title': info.get('title', 'Swagger'),
-                'description': info.get('description', 'Swagger'),
-                'version': info.get('version', '1.0.0')
+            "openapi": "3.0.0",
+            "info": {
+                "title": info.get("title", "Swagger"),
+                "description": info.get("description", "Swagger"),
+                "version": info.get("version", "1.0.0"),
             },
-            'servers': swagger['servers'],
-            'paths': {},
-            'components': {
-                'schemas': {}
-            }
+            "servers": swagger["servers"],
+            "paths": {},
+            "components": {"schemas": {}},
         }

         # check paths
-        if 'paths' not in swagger or len(swagger['paths']) == 0:
-            raise ToolApiSchemaError('No paths found in the swagger yaml.')
+        if "paths" not in swagger or len(swagger["paths"]) == 0:
+            raise ToolApiSchemaError("No paths found in the swagger yaml.")

         # convert paths
-        for path, path_item in swagger['paths'].items():
-            openapi['paths'][path] = {}
+        for path, path_item in swagger["paths"].items():
+            openapi["paths"][path] = {}
             for method, operation in path_item.items():
-                if 'operationId' not in operation:
-                    raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.')
-
-                if ('summary' not in operation or len(operation['summary']) == 0) and \
-                    ('description' not in operation or len(operation['description']) == 0):
-                    warning['missing_summary'] = f'No summary or description found in operation {method} {path}.'
-
-                openapi['paths'][path][method] = {
-                    'operationId': operation['operationId'],
-                    'summary': operation.get('summary', ''),
-                    'description': operation.get('description', ''),
-                    'parameters': operation.get('parameters', []),
-                    'responses': operation.get('responses', {}),
+                if "operationId" not in operation:
+                    raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
+
+                if ("summary" not in operation or len(operation["summary"]) == 0) and (
+                    "description" not in operation or len(operation["description"]) == 0
+                ):
+                    warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
+
+                openapi["paths"][path][method] = {
+                    "operationId": operation["operationId"],
+                    "summary": operation.get("summary", ""),
+                    "description": operation.get("description", ""),
+                    "parameters": operation.get("parameters", []),
+                    "responses": operation.get("responses", {}),
                 }

-                if 'requestBody' in operation:
-                    openapi['paths'][path][method]['requestBody'] = operation['requestBody']
+                if "requestBody" in operation:
+                    openapi["paths"][path][method]["requestBody"] = operation["requestBody"]

         # convert definitions
-        for name, definition in swagger['definitions'].items():
-            openapi['components']['schemas'][name] = definition
+        for name, definition in swagger["definitions"].items():
+            openapi["components"]["schemas"][name] = definition

         return openapi

     @staticmethod
-    def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
+    def parse_openai_plugin_json_to_tool_bundle(
+        json: str, extra_info: Optional[dict], warning: Optional[dict]
+    ) -> list[ApiToolBundle]:
         """
-            parse openapi plugin yaml to tool bundle
+        parse openapi plugin yaml to tool bundle

-            :param json: the json string
-            :return: the tool bundle
+        :param json: the json string
+        :return: the tool bundle
         """
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}

         try:
             openai_plugin = json_loads(json)
-            api = openai_plugin['api']
-            api_url = api['url']
-            api_type = api['type']
+            api = openai_plugin["api"]
+            api_url = api["url"]
+            api_type = api["type"]
         except:
-            raise ToolProviderNotFoundError('Invalid openai plugin json.')
-
-        if api_type != 'openapi':
-            raise ToolNotSupportedError('Only openapi is supported now.')
-
+            raise ToolProviderNotFoundError("Invalid openai plugin json.")
+
+        if api_type != "openapi":
+            raise ToolNotSupportedError("Only openapi is supported now.")
+
         # get openapi yaml
-        response = get(api_url, headers={
-            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
-        }, timeout=5)
+        response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)

         if response.status_code != 200:
-            raise ToolProviderNotFoundError('cannot get openapi yaml from url.')
-
-        return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
-
+            raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
+
+        return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
+            response.text, extra_info=extra_info, warning=warning
+        )
+
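(The auto_parse_to_tool_bundle method just below chains the three parsers: it tries OpenAPI first, then falls back to Swagger, then to an OpenAI plugin manifest. A sketch, reusing MINIMAL_OPENAPI from the note above under the same assumptions:

# usage sketch (not part of this PR)
bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(MINIMAL_OPENAPI)
assert schema_type == "openapi"  # i.e. ApiProviderSchemaType.OPENAPI.value
)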
     @staticmethod
-    def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
+    def auto_parse_to_tool_bundle(
+        content: str, extra_info: Optional[dict] = None, warning: Optional[dict] = None
+    ) -> tuple[list[ApiToolBundle], str]:
         """
-            auto parse to tool bundle
+        auto parse to tool bundle
 
-            :param content: the content
-            :return: tools bundle, schema_type
+        :param content: the content
+        :return: tools bundle, schema_type
         """
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}
@@ -301,7 +303,7 @@ def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: di
         loaded_content = None
         json_error = None
         yaml_error = None
-
+
         try:
             loaded_content = json_loads(content)
         except JSONDecodeError as e:
@@ -313,34 +315,48 @@ def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: di
         except YAMLError as e:
             yaml_error = e
         if loaded_content is None:
-            raise ToolApiSchemaError(f'Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}')
+            raise ToolApiSchemaError(
+                f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
+                f" yaml error: {str(yaml_error)}"
+            )

         swagger_error = None
         openapi_error = None
         openapi_plugin_error = None
         schema_type = None
-
+
         try:
-            openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(loaded_content, extra_info=extra_info, warning=warning)
+            openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
+                loaded_content, extra_info=extra_info, warning=warning
+            )
             schema_type = ApiProviderSchemaType.OPENAPI.value
             return openapi, schema_type
         except ToolApiSchemaError as e:
             openapi_error = e
-
+
         # openai parse error, fallback to swagger
         try:
-            converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(loaded_content, extra_info=extra_info, warning=warning)
+            converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
+                loaded_content, extra_info=extra_info, warning=warning
+            )
             schema_type = ApiProviderSchemaType.SWAGGER.value
-            return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(converted_swagger, extra_info=extra_info, warning=warning), schema_type
+            return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
+                converted_swagger, extra_info=extra_info, warning=warning
+            ), schema_type
         except ToolApiSchemaError as e:
             swagger_error = e
-
+
         # swagger parse error, fallback to openai plugin
         try:
-            openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(json_dumps(loaded_content), extra_info=extra_info, warning=warning)
+            openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
+                json_dumps(loaded_content), extra_info=extra_info, warning=warning
+            )
             return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
         except ToolNotSupportedError as e:
             # maybe it's not plugin at all
             openapi_plugin_error = e

-        raise ToolApiSchemaError(f'Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}')
+        raise ToolApiSchemaError(
+            f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
+            f" openapi plugin error: {str(openapi_plugin_error)}"
+        )
diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py
deleted file mode 100644
index 6f88eeaa0a8a98..00000000000000
--- a/api/core/tools/utils/tool_parameter_converter.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from typing import Any
-
-from core.tools.entities.tool_entities import ToolParameter
-
-
-class ToolParameterConverter:
-    @staticmethod
-    def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
-        match parameter_type:
-            case ToolParameter.ToolParameterType.STRING \
-                | ToolParameter.ToolParameterType.SECRET_INPUT \
-                | ToolParameter.ToolParameterType.SELECT:
-                return 'string'
-
-            case ToolParameter.ToolParameterType.BOOLEAN:
-                return 'boolean'
-
-            case ToolParameter.ToolParameterType.NUMBER:
-                return 'number'
-
-            case _:
-                raise ValueError(f"Unsupported parameter type {parameter_type}")
-
-    @staticmethod
-    def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
-        # convert tool parameter config to correct type
-        try:
-            match parameter_type:
-                case ToolParameter.ToolParameterType.STRING \
-                    | ToolParameter.ToolParameterType.SECRET_INPUT \
-                    | ToolParameter.ToolParameterType.SELECT:
-                    if value is None:
-                        return ''
-                    else:
-                        return value if isinstance(value, str) else str(value)
-
-                case ToolParameter.ToolParameterType.BOOLEAN:
-                    if value is None:
-                        return False
-                    elif isinstance(value, str):
-                        # Allowed YAML boolean value strings: https://yaml.org/type/bool.html
-                        # and also '0' for False and '1' for True
-                        match value.lower():
-                            case 'true' | 'yes' | 'y' | '1':
-                                return True
-                            case 'false' | 'no' | 'n' | '0':
-                                return False
-                            case _:
-                                return bool(value)
-                    else:
-                        return value if isinstance(value, bool) else bool(value)
-
-                case ToolParameter.ToolParameterType.NUMBER:
-                    if isinstance(value, int) | isinstance(value, float):
-                        return value
-                    elif isinstance(value, str) and value != '':
-                        if '.' in value:
-                            return float(value)
-                        else:
-                            return int(value)
-                case ToolParameter.ToolParameterType.FILE:
-                    return value
-                case _:
-                    return str(value)
-
-        except Exception:
-            raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index a461328ae6fad8..5807d61b9409a6 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -8,6 +8,8 @@
 import tempfile
 import unicodedata
 from contextlib import contextmanager
+from pathlib import Path
+from typing import Optional
 from urllib.parse import unquote
 
 import chardet
@@ -32,13 +34,14 @@
 def page_result(text: str, cursor: int, max_length: int) -> str:
     """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
-    return text[cursor: cursor + max_length]
+    return text[cursor : cursor + max_length]
 
 
-def get_url(url: str, user_agent: str = None) -> str:
+def get_url(url: str, user_agent: Optional[str] = None) -> str:
     """Fetch URL and return the contents as a string."""
     headers = {
-        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
+        " Chrome/91.0.4472.124 Safari/537.36"
     }
     if user_agent:
         headers["User-Agent"] = user_agent
@@ -49,15 +52,15 @@ def get_url(url: str, user_agent: str = None) -> str:
 
     if response.status_code == 200:
         # check content-type
-        content_type = response.headers.get('Content-Type')
+        content_type = response.headers.get("Content-Type")
         if content_type:
-            main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
+            main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
         else:
-            content_disposition = response.headers.get('Content-Disposition', '')
+            content_disposition = response.headers.get("Content-Disposition", "")
             filename_match = re.search(r'filename="([^"]+)"', content_disposition)
             if filename_match:
                 filename = unquote(filename_match.group(1))
-                extension = re.search(r'\.(\w+)$', filename)
+                extension = re.search(r"\.(\w+)$", filename)
                 if extension:
                     main_content_type = mimetypes.guess_type(filename)[0]
@@ -78,7 +81,7 @@ def get_url(url: str, user_agent: str = None) -> str:
 
     # Detect encoding using chardet
     detected_encoding = chardet.detect(response.content)
-    encoding = detected_encoding['encoding']
+    encoding = detected_encoding["encoding"]
     if encoding:
         try:
             content = response.content.decode(encoding)
@@ -89,35 +92,34 @@ def get_url(url: str, user_agent: str = None) -> str:
 
     a = extract_using_readabilipy(content)
 
-    if not a['plain_text'] or not a['plain_text'].strip():
-        return ''
+    if not a["plain_text"] or not a["plain_text"].strip():
+        return ""

     res = FULL_TEMPLATE.format(
-        title=a['title'],
-        authors=a['byline'],
-        publish_date=a['date'],
+        title=a["title"],
+        authors=a["byline"],
+        publish_date=a["date"],
         top_image="",
-        text=a['plain_text'] if a['plain_text'] else "",
+        text=a["plain_text"] or "",
     )

     return res


 def extract_using_readabilipy(html):
-    with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
+    with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
         f_html.write(html)
         f_html.close()
     html_path = f_html.name

     # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
     article_json_path = html_path + ".json"
-    jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
+    jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
     with chdir(jsdir):
         subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

     # Read output of call to Readability.parse() from JSON file and return as Python dictionary
-    with open(article_json_path, encoding="utf-8") as json_file:
-        input_json = json.loads(json_file.read())
+    input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))

     # Deleting files after processing
     os.unlink(article_json_path)
@@ -129,7 +131,7 @@ def extract_using_readabilipy(html):
         "date": None,
         "content": None,
         "plain_content": None,
-        "plain_text": None
+        "plain_text": None,
     }
     # Populate article fields from readability fields where present
     if input_json:
@@ -145,7 +147,7 @@ def extract_using_readabilipy(html):
             article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
         if input_json.get("textContent"):
             article_json["plain_text"] = input_json["textContent"]
-            article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
+            article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])

     return article_json

@@ -158,6 +160,7 @@ def find_module_path(module_name):
 
     return None
 
+
 @contextmanager
 def chdir(path):
     """Change directory in context and return to original on exit"""
@@ -172,12 +175,14 @@ def chdir(path):
 
 def extract_text_blocks_as_plain_text(paragraph_html):
     # Load article as DOM
-    soup = BeautifulSoup(paragraph_html, 'html.parser')
+    soup = BeautifulSoup(paragraph_html, "html.parser")
     # Select all lists
-    list_elements = soup.find_all(['ul', 'ol'])
+    list_elements = soup.find_all(["ul", "ol"])
     # Prefix text in all list items with "* " and make lists paragraphs
     for list_element in list_elements:
-        plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
+        plain_items = "".join(
+            list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
+        )
         list_element.string = plain_items
         list_element.name = "p"
     # Select all text blocks
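(The normalise_*-to-normalize_* renames in the hunks that follow are behavior-preserving; the text pipeline is still strip_control_characters, then unicode normalization, then whitespace collapsing. A self-contained sketch of the same idea, using stdlib re rather than the regex package this file imports:

# standalone sketch of the normalization pipeline (not part of this PR)
import re
import unicodedata

def normalize_text_sketch(text: str) -> str:
    text = unicodedata.normalize("NFKC", text)  # fold visually-equivalent characters
    return re.sub(r"\s+", " ", text).strip()  # collapse whitespace runs to single spaces

assert normalize_text_sketch("\ufb01ne\u00a0 text\n") == "fine text"  # "ﬁ" ligature folds to "fi"
)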
@@ -189,8 +194,8 @@ def extract_text_blocks_as_plain_text(paragraph_html):
 
 
 def plain_text_leaf_node(element):
-    # Extract all text, stripped of any child HTML elements and normalise it
-    plain_text = normalise_text(element.get_text())
+    # Extract all text, stripped of any child HTML elements and normalize it
+    plain_text = normalize_text(element.get_text())
     if plain_text != "" and element.name == "li":
         plain_text = "* {}, ".format(plain_text)
     if plain_text == "":
@@ -204,7 +209,7 @@ def plain_text_leaf_node(element):
 
 def plain_content(readability_content, content_digests, node_indexes):
     # Load article as DOM
-    soup = BeautifulSoup(readability_content, 'html.parser')
+    soup = BeautifulSoup(readability_content, "html.parser")
     # Make all elements plain
     elements = plain_elements(soup.contents, content_digests, node_indexes)
     if node_indexes:
@@ -217,8 +222,7 @@ def plain_content(readability_content, content_digests, node_indexes):
 
 def plain_elements(elements, content_digests, node_indexes):
     # Get plain content versions of all elements
-    elements = [plain_element(element, content_digests, node_indexes)
-                for element in elements]
+    elements = [plain_element(element, content_digests, node_indexes) for element in elements]
     if content_digests:
         # Add content digest attribute to nodes
         elements = [add_content_digest(element) for element in elements]
@@ -231,8 +235,8 @@ def plain_element(element, content_digests, node_indexes):
     # For leaf node elements, extract the text content, discarding any HTML tags
     # 1. Get element contents as text
     plain_text = element.get_text()
-    # 2. Normalise the extracted text string to a canonical representation
-    plain_text = normalise_text(plain_text)
+    # 2. Normalize the extracted text string to a canonical representation
+    plain_text = normalize_text(plain_text)
     # 3. Update element content to be plain text
     element.string = plain_text
 elif is_text(element):
@@ -243,7 +247,7 @@ def plain_element(element, content_digests, node_indexes):
             element = type(element)("")
         else:
             plain_text = element.string
-            plain_text = normalise_text(plain_text)
+            plain_text = normalize_text(plain_text)
             element = type(element)(plain_text)
     else:
         # If not a leaf node or leaf type call recursively on child nodes, replacing
@@ -258,21 +262,19 @@ def add_node_indexes(element, node_index="0"):
     # Add index to current element
     element["data-node-index"] = node_index
     # Add index to child elements
-    for local_idx, child in enumerate(
-            [c for c in element.contents if not is_text(c)], start=1):
+    for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
         # Can't add attributes to leaf string types
-        child_index = "{stem}.{local}".format(
-            stem=node_index, local=local_idx)
+        child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
         add_node_indexes(child, node_index=child_index)
     return element
 
 
-def normalise_text(text):
-    """Normalise unicode and whitespace."""
-    # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them
+def normalize_text(text):
+    """Normalize unicode and whitespace."""
+    # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them
     text = strip_control_characters(text)
-    text = normalise_unicode(text)
-    text = normalise_whitespace(text)
+    text = normalize_unicode(text)
+    text = normalize_whitespace(text)
     return text
 
@@ -284,29 +286,35 @@ def strip_control_characters(text):
     # [Cn]: Other, Not Assigned
     # [Co]: Other, Private Use
     # [Cs]: Other, Surrogate
-    control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'}
-    retained_chars = ['\t', '\n', '\r', '\f']
+    control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
+    retained_chars = ["\t", "\n", "\r", "\f"]

     # Remove non-printing control characters
-    return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
+    return "".join(
+        [
+            "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
+            for char in text
+        ]
+    )


-def normalise_unicode(text):
-    """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible."""
+def normalize_unicode(text):
+    """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
     normal_form = "NFKC"
     text = unicodedata.normalize(normal_form, text)
     return text


-def normalise_whitespace(text):
+def normalize_whitespace(text):
     """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed."""
     text = regex.sub(r"\s+", " ", text)
     # Remove leading and trailing whitespace
     text = text.strip()
     return text

+
 def is_leaf(element):
-    return (element.name in ['p', 'li'])
+    return element.name in {"p", "li"}


 def is_text(element):
@@ -330,7 +338,7 @@ def content_digest(element):
         if trimmed_string == "":
             digest = ""
         else:
-            digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
+            digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
     else:
         contents = element.contents
         num_contents = len(contents)
@@ -343,9 +351,8 @@ def content_digest(element):
         else:
             # Build content digest from the "non-empty" digests of child nodes
             digest = hashlib.sha256()
-            child_digests = list(
-                filter(lambda x: x != "", [content_digest(content) for content in contents]))
+            child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
             for child in child_digests:
-                digest.update(child.encode('utf-8'))
+                digest.update(child.encode("utf-8"))
             digest = digest.hexdigest()
     return digest
diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py
index ff5505bbbfb9f5..d92bfb9b90a9aa 100644
--- a/api/core/tools/utils/workflow_configuration_sync.py
+++ b/api/core/tools/utils/workflow_configuration_sync.py
@@ -1,36 +1,33 @@
+from collections.abc import Mapping, Sequence
+from typing import Any
+
 from core.app.app_config.entities import VariableEntity
 from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
 
 
 class WorkflowToolConfigurationUtils:
     @classmethod
-    def check_parameter_configurations(cls, configurations: list[dict]):
-        """
-        check parameter configurations
-        """
+    def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
         for configuration in configurations:
-            if not WorkflowToolParameterConfiguration(**configuration):
-                raise ValueError('invalid parameter configuration')
+            WorkflowToolParameterConfiguration.model_validate(configuration)

     @classmethod
-    def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
+    def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
         """
         get workflow graph variables
         """
-        nodes = graph.get('nodes', [])
-        start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
+        nodes = graph.get("nodes", [])
+        start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)

         if not start_node:
             return []

-        return [
-            VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
-        ]
-
+        return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
+
     @classmethod
-    def check_is_synced(cls,
-                        variables: list[VariableEntity],
-                        tool_configurations: list[WorkflowToolParameterConfiguration]) -> None:
+    def check_is_synced(
+        cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
+    ) -> None:
         """
         check is synced
@@ -39,10 +36,10 @@ def check_is_synced(cls,
         variable_names = [variable.variable for variable in variables]
 
         if len(tool_configurations) != len(variables):
-            raise ValueError('parameter configuration mismatch, please republish the tool to update')
-
+            raise ValueError("parameter configuration mismatch, please republish the tool to update")
+
         for parameter in tool_configurations:
             if parameter.name not in variable_names:
-                raise ValueError('parameter configuration mismatch, please republish the tool to update')
+                raise ValueError("parameter configuration mismatch, please republish the tool to update")
 
-        return True
\ No newline at end of file
+        return True
diff --git a/api/core/tools/utils/yaml_utils.py b/api/core/tools/utils/yaml_utils.py
index 21155a696031f6..42c7f85bc6daeb 100644
--- a/api/core/tools/utils/yaml_utils.py
+++ b/api/core/tools/utils/yaml_utils.py
@@ -1,4 +1,5 @@
 import logging
+from pathlib import Path
 from typing import Any
 
 import yaml
@@ -17,16 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
     :param default_value: the value returned when errors ignored
     :return: an object of the YAML content
     """
-    try:
-        with open(file_path, encoding='utf-8') as yaml_file:
-            try:
-                yaml_content = yaml.safe_load(yaml_file)
-                return yaml_content if yaml_content else default_value
-            except Exception as e:
-                raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
-    except Exception as e:
+    if not file_path or not Path(file_path).exists():
         if ignore_error:
-            logger.debug(f'Failed to load YAML file {file_path}: {e}')
             return default_value
         else:
-            raise e
+            raise FileNotFoundError(f"File not found: {file_path}")
+
+    with open(file_path, encoding="utf-8") as yaml_file:
+        try:
+            yaml_content = yaml.safe_load(yaml_file)
+            return yaml_content or default_value
+        except Exception as e:
+            if ignore_error:
+                return default_value
+            else:
+                raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py
new file mode 100644
index 00000000000000..87f9e3ed45c7cb
--- /dev/null
+++ b/api/core/variables/__init__.py
@@ -0,0 +1,61 @@
+from .segment_group import SegmentGroup
+from .segments import (
+    ArrayAnySegment,
+    ArrayFileSegment,
+    ArrayNumberSegment,
+    ArrayObjectSegment,
+    ArraySegment,
+    ArrayStringSegment,
+    FileSegment,
+    FloatSegment,
+    IntegerSegment,
+    NoneSegment,
+    ObjectSegment,
+    Segment,
+    StringSegment,
+)
+from .types import SegmentType
+from .variables import (
+    ArrayAnyVariable,
+    ArrayNumberVariable,
+    ArrayObjectVariable,
+    ArrayStringVariable,
+    FileVariable,
+    FloatVariable,
+    IntegerVariable,
+    NoneVariable,
+    ObjectVariable,
+    SecretVariable,
+    StringVariable,
+    Variable,
+)
+
+__all__ = [
+    "IntegerVariable",
+    "FloatVariable",
+    "ObjectVariable",
+    "SecretVariable",
+    "StringVariable",
+    "ArrayAnyVariable",
+    "Variable",
+    "SegmentType",
+    "SegmentGroup",
+    "Segment",
+    "NoneSegment",
+    "NoneVariable",
+    "IntegerSegment",
+    "FloatSegment",
+    "ObjectSegment",
+    "ArrayAnySegment",
+    "StringSegment",
+    "ArrayStringVariable",
+    "ArrayNumberVariable",
+    "ArrayObjectVariable",
+    "ArraySegment",
+    "ArrayFileSegment",
+    "ArrayNumberSegment",
+    "ArrayObjectSegment",
+    "ArrayStringSegment",
+    "FileSegment",
+    "FileVariable",
+]
diff --git a/api/core/variables/exc.py b/api/core/variables/exc.py
new file mode 100644
index 00000000000000..5cf67c3baccacc
--- /dev/null
+++ b/api/core/variables/exc.py
@@ -0,0 +1,2 @@
+class VariableError(ValueError):
+    pass
diff --git a/api/core/variables/segment_group.py b/api/core/variables/segment_group.py
new file mode 100644
index 00000000000000..b363255b2cae9e
--- /dev/null
+++ b/api/core/variables/segment_group.py
@@ -0,0 +1,22 @@
+from .segments import Segment
+from .types import SegmentType
+
+
+class SegmentGroup(Segment):
+    value_type: SegmentType = SegmentType.GROUP
+    value: list[Segment]
+
+    @property
+    def text(self):
+        return "".join([segment.text for segment in self.value])
+
+    @property
+    def log(self):
+        return "".join([segment.log for segment in self.value])
+
+    @property
+    def markdown(self):
+        return "".join([segment.markdown for segment in self.value])
+
+    def to_object(self):
+        return [segment.to_object() for segment in self.value]
diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py
new file mode 100644
index 00000000000000..b71882b043ecdf
--- /dev/null
+++ b/api/core/variables/segments.py
@@ -0,0 +1,157 @@
+import json
+import sys
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, field_validator
+
+from core.file import File
+
+from .types import SegmentType
+
+
+class Segment(BaseModel):
+    model_config = ConfigDict(frozen=True)
+
+    value_type: SegmentType
+    value: Any
+
+    @field_validator("value_type")
+    @classmethod
+    def validate_value_type(cls, value):
+        """
+        This validator checks if the provided value is equal to the default value of the 'value_type' field.
+        If the value is different, a ValueError is raised.
+        """
+        if value != cls.model_fields["value_type"].default:
+            raise ValueError("Cannot modify 'value_type'")
+        return value
+
+    @property
+    def text(self) -> str:
+        return str(self.value)
+
+    @property
+    def log(self) -> str:
+        return str(self.value)
+
+    @property
+    def markdown(self) -> str:
+        return str(self.value)
+
+    @property
+    def size(self) -> int:
+        """
+        Return the size of the value in bytes.
+        """
+        return sys.getsizeof(self.value)
+
+    def to_object(self) -> Any:
+        return self.value
+
+
+class NoneSegment(Segment):
+    value_type: SegmentType = SegmentType.NONE
+    value: None = None
+
+    @property
+    def text(self) -> str:
+        return ""
+
+    @property
+    def log(self) -> str:
+        return ""
+
+    @property
+    def markdown(self) -> str:
+        return ""
+
+
+class StringSegment(Segment):
+    value_type: SegmentType = SegmentType.STRING
+    value: str
+
+
+class FloatSegment(Segment):
+    value_type: SegmentType = SegmentType.NUMBER
+    value: float
+
+
+class IntegerSegment(Segment):
+    value_type: SegmentType = SegmentType.NUMBER
+    value: int
+
+
+class ObjectSegment(Segment):
+    value_type: SegmentType = SegmentType.OBJECT
+    value: Mapping[str, Any]
+
+    @property
+    def text(self) -> str:
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False)
+
+    @property
+    def log(self) -> str:
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
+
+    @property
+    def markdown(self) -> str:
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
+
+
+class ArraySegment(Segment):
+    @property
+    def markdown(self) -> str:
+        items = []
+        for item in self.value:
+            items.append(str(item))
+        return "\n".join(items)
+
+
+class FileSegment(Segment):
+    value_type: SegmentType = SegmentType.FILE
+    value: File
+
+    @property
+    def markdown(self) -> str:
+        return self.value.markdown
+
+    @property
+    def log(self) -> str:
+        return str(self.value)
+
+    @property
+    def text(self) -> str:
+        return str(self.value)
+
+
+class ArrayAnySegment(ArraySegment):
+    value_type: SegmentType = SegmentType.ARRAY_ANY
+    value: Sequence[Any]
+
+
+class ArrayStringSegment(ArraySegment):
+    value_type: SegmentType = SegmentType.ARRAY_STRING
+    value: Sequence[str]
+
+
+class ArrayNumberSegment(ArraySegment):
+    value_type: SegmentType = SegmentType.ARRAY_NUMBER
+    value: Sequence[float | int]
+
+
+class ArrayObjectSegment(ArraySegment):
+    value_type: SegmentType = SegmentType.ARRAY_OBJECT
+    value: Sequence[Mapping[str, Any]]
+
+
+class ArrayFileSegment(ArraySegment):
+    value_type: SegmentType = SegmentType.ARRAY_FILE
+    value: Sequence[File]
+
+    @property
+    def markdown(self) -> str:
+        items = []
+        for item in self.value:
+            items.append(item.markdown)
+        return "\n".join(items)
diff --git a/api/core/variables/types.py b/api/core/variables/types.py
new file mode 100644
index 00000000000000..53c2e8a3aa6ddc
--- /dev/null
+++ b/api/core/variables/types.py
@@ -0,0 +1,17 @@
+from enum import Enum
+
+
+class SegmentType(str, Enum):
+    NONE = "none"
+    NUMBER = "number"
+    STRING = "string"
+    SECRET = "secret"
+    ARRAY_ANY = "array[any]"
+    ARRAY_STRING = "array[string]"
+    ARRAY_NUMBER = "array[number]"
+    ARRAY_OBJECT = "array[object]"
+    OBJECT = "object"
+    FILE = "file"
+    ARRAY_FILE = "array[file]"
+
+    GROUP = "group"
a/api/core/app/segments/variables.py b/api/core/variables/variables.py similarity index 88% rename from api/core/app/segments/variables.py rename to api/core/variables/variables.py index 8fef707fcf298b..ddc69141928c83 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/variables/variables.py @@ -7,6 +7,7 @@ ArrayNumberSegment, ArrayObjectSegment, ArrayStringSegment, + FileSegment, FloatSegment, IntegerSegment, NoneSegment, @@ -23,11 +24,11 @@ class Variable(Segment): """ id: str = Field( - default='', + default="", description="Unique identity for variable. It's only used by environment variables now.", ) name: str - description: str = Field(default='', description='Description of the variable.') + description: str = Field(default="", description="Description of the variable.") class StringVariable(StringSegment, Variable): @@ -62,7 +63,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable): pass - class SecretVariable(StringVariable): value_type: SegmentType = SegmentType.SECRET @@ -74,3 +74,7 @@ def log(self) -> str: class NoneVariable(NoneSegment, Variable): value_type: SegmentType = SegmentType.NONE value: None = None + + +class FileVariable(FileSegment, Variable): + pass diff --git a/api/core/workflow/callbacks/__init__.py b/api/core/workflow/callbacks/__init__.py index e69de29bb2d1d6..403fbbaa2fa616 100644 --- a/api/core/workflow/callbacks/__init__.py +++ b/api/core/workflow/callbacks/__init__.py @@ -0,0 +1,7 @@ +from .base_workflow_callback import WorkflowCallback +from .workflow_logging_callback import WorkflowLoggingCallback + +__all__ = [ + "WorkflowLoggingCallback", + "WorkflowCallback", +] diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py index 6db8adf4c21d72..83086d1afc9018 100644 --- a/api/core/workflow/callbacks/base_workflow_callback.py +++ b/api/core/workflow/callbacks/base_workflow_callback.py @@ -1,116 +1,12 @@ from abc import ABC, abstractmethod -from typing import Any, Optional -from core.app.entities.queue_entities import AppQueueEvent -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeType +from core.workflow.graph_engine.entities.event import GraphEngineEvent class WorkflowCallback(ABC): @abstractmethod - def on_workflow_run_started(self) -> None: + def on_event(self, event: GraphEngineEvent) -> None: """ - Workflow run started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_succeeded(self) -> None: - """ - Workflow run succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_run_failed(self, error: str) -> None: - """ - Workflow run failed - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_started(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - node_run_index: int = 1, - predecessor_node_id: Optional[str] = None) -> None: - """ - Workflow node execute started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_succeeded(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - inputs: Optional[dict] = None, - process_data: Optional[dict] = None, - outputs: Optional[dict] = None, - execution_metadata: Optional[dict] = None) -> None: - """ - Workflow node execute succeeded - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_node_execute_failed(self, node_id: str, - node_type: NodeType, - node_data: BaseNodeData, - error: str, - 
inputs: Optional[dict] = None, - outputs: Optional[dict] = None, - process_data: Optional[dict] = None) -> None: - """ - Workflow node execute failed - """ - raise NotImplementedError - - @abstractmethod - def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None: - """ - Publish text chunk - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_started(self, - node_id: str, - node_type: NodeType, - node_run_index: int = 1, - node_data: Optional[BaseNodeData] = None, - inputs: Optional[dict] = None, - predecessor_node_id: Optional[str] = None, - metadata: Optional[dict] = None) -> None: - """ - Publish iteration started - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_next(self, node_id: str, - node_type: NodeType, - index: int, - node_run_index: int, - output: Optional[Any], - ) -> None: - """ - Publish iteration next - """ - raise NotImplementedError - - @abstractmethod - def on_workflow_iteration_completed(self, node_id: str, - node_type: NodeType, - node_run_index: int, - outputs: dict) -> None: - """ - Publish iteration completed - """ - raise NotImplementedError - - @abstractmethod - def on_event(self, event: AppQueueEvent) -> None: - """ - Publish event + Handle the published event """ raise NotImplementedError diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py new file mode 100644 index 00000000000000..17913de7b0d2ce --- /dev/null +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -0,0 +1,221 @@ +from typing import Optional + +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) + +from .base_workflow_callback import WorkflowCallback + +_TEXT_COLOR_MAPPING = { + "blue": "36;1", + "yellow": "33;1", + "pink": "38;5;200", + "green": "32;1", + "red": "31;1", +} + + +class WorkflowLoggingCallback(WorkflowCallback): + def __init__(self) -> None: + self.current_node_id = None + + def on_event(self, event: GraphEngineEvent) -> None: + if isinstance(event, GraphRunStartedEvent): + self.print_text("\n[GraphRunStartedEvent]", color="pink") + elif isinstance(event, GraphRunSucceededEvent): + self.print_text("\n[GraphRunSucceededEvent]", color="green") + elif isinstance(event, GraphRunFailedEvent): + self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") + elif isinstance(event, NodeRunStartedEvent): + self.on_workflow_node_execute_started(event=event) + elif isinstance(event, NodeRunSucceededEvent): + self.on_workflow_node_execute_succeeded(event=event) + elif isinstance(event, NodeRunFailedEvent): + self.on_workflow_node_execute_failed(event=event) + elif isinstance(event, NodeRunStreamChunkEvent): + self.on_node_text_chunk(event=event) + elif isinstance(event, ParallelBranchRunStartedEvent): + self.on_workflow_parallel_started(event=event) + elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent): + self.on_workflow_parallel_completed(event=event) + elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(event=event) + elif isinstance(event, IterationRunNextEvent): + self.on_workflow_iteration_next(event=event) + elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent): + self.on_workflow_iteration_completed(event=event) + else: + self.print_text(f"\n[{event.__class__.__name__}]", color="blue") + + def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None: + """ + Workflow node execute started + """ + self.print_text("\n[NodeRunStartedEvent]", color="yellow") + self.print_text(f"Node ID: {event.node_id}", color="yellow") + self.print_text(f"Node Title: {event.node_data.title}", color="yellow") + self.print_text(f"Type: {event.node_type.value}", color="yellow") + + def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None: + """ + Workflow node execute succeeded + """ + route_node_state = event.route_node_state + + self.print_text("\n[NodeRunSucceededEvent]", color="green") + self.print_text(f"Node ID: {event.node_id}", color="green") + self.print_text(f"Node Title: {event.node_data.title}", color="green") + self.print_text(f"Type: {event.node_type.value}", color="green") + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color="green", + ) + self.print_text( + f"Process Data: " + f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color="green", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="green", + ) + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}", + color="green", + ) + + def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None: + """ + Workflow node execute failed + """ + route_node_state = event.route_node_state + + self.print_text("\n[NodeRunFailedEvent]", color="red") + self.print_text(f"Node ID: {event.node_id}", color="red") + self.print_text(f"Node Title: {event.node_data.title}", color="red") + self.print_text(f"Type: {event.node_type.value}", color="red") + + if route_node_state.node_run_result: + node_run_result = route_node_state.node_run_result + self.print_text(f"Error: {node_run_result.error}", color="red") + self.print_text( + f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", + color="red", + ) + self.print_text( + f"Process Data: " + f"{jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}", + color="red", + ) + self.print_text( + f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", + color="red", + ) + + def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None: + """ + Publish text chunk + """ + route_node_state = event.route_node_state + if not self.current_node_id or self.current_node_id != route_node_state.node_id: + self.current_node_id = route_node_state.node_id + self.print_text("\n[NodeRunStreamChunkEvent]") + self.print_text(f"Node ID: {route_node_state.node_id}") + + node_run_result = route_node_state.node_run_result + if node_run_result: + self.print_text( + f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}" + ) + + self.print_text(event.chunk_content, color="pink", end="") + + def on_workflow_parallel_started(self, event: 
ParallelBranchRunStartedEvent) -> None: + """ + Publish parallel started + """ + self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue") + self.print_text(f"Parallel ID: {event.parallel_id}", color="blue") + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue") + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue") + + def on_workflow_parallel_completed( + self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent + ) -> None: + """ + Publish parallel completed + """ + if isinstance(event, ParallelBranchRunSucceededEvent): + color = "blue" + elif isinstance(event, ParallelBranchRunFailedEvent): + color = "red" + + self.print_text( + "\n[ParallelBranchRunSucceededEvent]" + if isinstance(event, ParallelBranchRunSucceededEvent) + else "\n[ParallelBranchRunFailedEvent]", + color=color, + ) + self.print_text(f"Parallel ID: {event.parallel_id}", color=color) + self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color) + if event.in_iteration_id: + self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color) + + if isinstance(event, ParallelBranchRunFailedEvent): + self.print_text(f"Error: {event.error}", color=color) + + def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None: + """ + Publish iteration started + """ + self.print_text("\n[IterationRunStartedEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") + + def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None: + """ + Publish iteration next + """ + self.print_text("\n[IterationRunNextEvent]", color="blue") + self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue") + self.print_text(f"Iteration Index: {event.index}", color="blue") + + def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None: + """ + Publish iteration completed + """ + self.print_text( + "\n[IterationRunSucceededEvent]" + if isinstance(event, IterationRunSucceededEvent) + else "\n[IterationRunFailedEvent]", + color="blue", + ) + self.print_text(f"Node ID: {event.iteration_id}", color="blue") + + def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None: + """Print text with optional color highlighting and a configurable end character.""" + text_to_print = self._get_colored_text(text, color) if color else text + print(f"{text_to_print}", end=end) + + def _get_colored_text(self, text: str, color: str) -> str: + """Get colored text.""" + color_str = _TEXT_COLOR_MAPPING[color] + return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" diff --git a/api/core/workflow/constants.py b/api/core/workflow/constants.py new file mode 100644 index 00000000000000..e3fe17c2845837 --- /dev/null +++ b/api/core/workflow/constants.py @@ -0,0 +1,3 @@ +SYSTEM_VARIABLE_NODE_ID = "sys" +ENVIRONMENT_VARIABLE_NODE_ID = "env" +CONVERSATION_VARIABLE_NODE_ID = "conversation" diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py deleted file mode 100644 index 6bf0c11c7d723f..00000000000000 --- a/api/core/workflow/entities/base_node_data_entities.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import ABC -from typing import Optional - -from pydantic import BaseModel - - -class BaseNodeData(ABC, BaseModel): - title: str - desc: Optional[str] = None - -class BaseIterationNodeData(BaseNodeData): - start_node_id: str - -class BaseIterationState(BaseModel): - 
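With the callback surface collapsed to the single on_event method above, integrators now dispatch on event types themselves, as WorkflowLoggingCallback does. A minimal hypothetical custom callback along the same lines:

# Sketch of a consumer-side callback; it records only graph-level failures.
from core.workflow.callbacks import WorkflowCallback
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent


class FailureCollectorCallback(WorkflowCallback):
    def __init__(self) -> None:
        self.errors: list[str] = []

    def on_event(self, event: GraphEngineEvent) -> None:
        if isinstance(event, GraphRunFailedEvent):
            self.errors.append(event.error)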
iteration_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData \ No newline at end of file diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 025453567bfc1b..7e10cddc712baa 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -4,58 +4,26 @@ from pydantic import BaseModel -from models import WorkflowNodeExecutionStatus +from core.model_runtime.entities.llm_entities import LLMUsage +from models.workflow import WorkflowNodeExecutionStatus -class NodeType(Enum): - """ - Node Types. - """ - - START = 'start' - END = 'end' - ANSWER = 'answer' - LLM = 'llm' - KNOWLEDGE_RETRIEVAL = 'knowledge-retrieval' - IF_ELSE = 'if-else' - CODE = 'code' - TEMPLATE_TRANSFORM = 'template-transform' - QUESTION_CLASSIFIER = 'question-classifier' - HTTP_REQUEST = 'http-request' - TOOL = 'tool' - VARIABLE_AGGREGATOR = 'variable-aggregator' - # TODO: merge this into VARIABLE_AGGREGATOR - VARIABLE_ASSIGNER = 'variable-assigner' - LOOP = 'loop' - ITERATION = 'iteration' - PARAMETER_EXTRACTOR = 'parameter-extractor' - CONVERSATION_VARIABLE_ASSIGNER = 'assigner' - - @classmethod - def value_of(cls, value: str) -> 'NodeType': - """ - Get value of given node type. - - :param value: node type value - :return: node type - """ - for node_type in cls: - if node_type.value == value: - return node_type - raise ValueError(f'invalid node type value {value}') - - -class NodeRunMetadataKey(Enum): +class NodeRunMetadataKey(str, Enum): """ Node Run Metadata Key. """ - TOTAL_TOKENS = 'total_tokens' - TOTAL_PRICE = 'total_price' - CURRENCY = 'currency' - TOOL_INFO = 'tool_info' - ITERATION_ID = 'iteration_id' - ITERATION_INDEX = 'iteration_index' + TOTAL_TOKENS = "total_tokens" + TOTAL_PRICE = "total_price" + CURRENCY = "currency" + TOOL_INFO = "tool_info" + ITERATION_ID = "iteration_id" + ITERATION_INDEX = "iteration_index" + PARALLEL_ID = "parallel_id" + PARALLEL_START_NODE_ID = "parallel_start_node_id" + PARENT_PARALLEL_ID = "parent_parallel_id" + PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" + PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" class NodeRunResult(BaseModel): @@ -66,9 +34,10 @@ class NodeRunResult(BaseModel): status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING inputs: Optional[Mapping[str, Any]] = None # node inputs - process_data: Optional[dict] = None # process data - outputs: Optional[Mapping[str, Any]] = None # node outputs + process_data: Optional[dict[str, Any]] = None # process data + outputs: Optional[dict[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata + llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py index 19d9af2a6171a4..8f4c2d797552ca 100644 --- a/api/core/workflow/entities/variable_entities.py +++ b/api/core/workflow/entities/variable_entities.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + from pydantic import BaseModel @@ -5,5 +7,6 @@ class VariableSelector(BaseModel): """ Variable Selector. 
""" + variable: str - value_selector: list[str] + value_selector: Sequence[str] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 9fe3356faa2ef5..3dc3395da1e3af 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,62 +1,87 @@ +import re from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Union -from typing_extensions import deprecated +from pydantic import BaseModel, Field + +from core.file import File, FileAttribute, file_manager +from core.variables import Segment, SegmentGroup, Variable +from core.variables.segments import FileSegment +from factories import variable_factory + +from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from ..enums import SystemVariableKey + +VariableValue = Union[str, int, float, dict, list, File] + + +VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") + + +class VariablePool(BaseModel): + # Variable dictionary is a dictionary for looking up variables by their selector. + # The first element of the selector is the node id, it's the first-level key in the dictionary. + # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the + # elements of the selector except the first one. + variable_dictionary: dict[str, dict[int, Segment]] = Field( + description="Variables mapping", + default=defaultdict(dict), + ) + # TODO: These user inputs are not yet used by the pool. + user_inputs: Mapping[str, Any] = Field( + description="User inputs", + ) + system_variables: Mapping[SystemVariableKey, Any] = Field( + description="System variables", + ) + environment_variables: Sequence[Variable] = Field( + description="Environment variables.", + default_factory=list, + ) + conversation_variables: Sequence[Variable] = Field( + description="Conversation variables.", + default_factory=list, + ) -from core.app.segments import Segment, Variable, factory -from core.file.file_obj import FileVar -from core.workflow.enums import SystemVariable -VariableValue = Union[str, int, float, dict, list, FileVar] - - -SYSTEM_VARIABLE_NODE_ID = 'sys' -ENVIRONMENT_VARIABLE_NODE_ID = 'env' -CONVERSATION_VARIABLE_NODE_ID = 'conversation' - - -class VariablePool: def __init__( self, - system_variables: Mapping[SystemVariable, Any], - user_inputs: Mapping[str, Any], - environment_variables: Sequence[Variable], + *, + system_variables: Mapping[SystemVariableKey, Any] | None = None, + user_inputs: Mapping[str, Any] | None = None, + environment_variables: Sequence[Variable] | None = None, conversation_variables: Sequence[Variable] | None = None, - ) -> None: - # system variables - # for example: - # { - # 'query': 'abc', - # 'files': [] - # } - - # Varaible dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict) - - # TODO: This user inputs is not used for pool. 
- self.user_inputs = user_inputs - - # Add system variables to the variable pool - self.system_variables = system_variables - for key, value in system_variables.items(): + **kwargs, + ): + environment_variables = environment_variables or [] + conversation_variables = conversation_variables or [] + user_inputs = user_inputs or {} + system_variables = system_variables or {} + + super().__init__( + system_variables=system_variables, + user_inputs=user_inputs, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + **kwargs, + ) + + for key, value in self.system_variables.items(): self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) - # Add environment variables to the variable pool - for var in environment_variables: + for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool - for var in conversation_variables or []: + for var in self.conversation_variables: self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) def add(self, selector: Sequence[str], value: Any, /) -> None: """ Adds a variable to the variable pool. + NOTE: You should not add a non-Segment value to the variable pool, + even though it is currently allowed. + Args: selector (Sequence[str]): The selector for the variable. value (VariableValue): The value of the variable. @@ -68,18 +93,15 @@ def add(self, selector: Sequence[str], value: Any, /) -> None: None """ if len(selector) < 2: - raise ValueError('Invalid selector') - - if value is None: - return + raise ValueError("Invalid selector") if isinstance(value, Segment): v = value else: - v = factory.build_segment(value) + v = variable_factory.build_segment(value) hash_key = hash(tuple(selector[1:])) - self._variable_dictionary[selector[0]][hash_key] = v + self.variable_dictionary[selector[0]][hash_key] = v def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -95,31 +117,24 @@ def get(self, selector: Sequence[str], /) -> Segment | None: ValueError: If the selector is invalid. """ if len(selector) < 2: - raise ValueError('Invalid selector') - hash_key = hash(tuple(selector[1:])) - value = self._variable_dictionary[selector[0]].get(hash_key) + return None - return value - - @deprecated('This method is deprecated, use `get` instead.') - def get_any(self, selector: Sequence[str], /) -> Any | None: - """ - Retrieves the value from the variable pool based on the given selector. - - Args: - selector (Sequence[str]): The selector used to identify the variable. + hash_key = hash(tuple(selector[1:])) + value = self.variable_dictionary[selector[0]].get(hash_key) - Returns: - Any: The value associated with the given selector. + if value is None: + selector, attr = selector[:-1], selector[-1] + # Python supports `attr in FileAttribute` only since 3.12 + if attr not in {item.value for item in FileAttribute}: + return None + value = self.get(selector) + if not isinstance(value, FileSegment): + return None + attr = FileAttribute(attr) + attr_value = file_manager.get_attr(file=value.value, attr=attr) + return variable_factory.build_segment(attr_value) - Raises: - ValueError: If the selector is invalid.
- """ - if len(selector) < 2: - raise ValueError('Invalid selector') - hash_key = hash(tuple(selector[1:])) - value = self._variable_dictionary[selector[0]].get(hash_key) - return value.to_object() if value else None + return value def remove(self, selector: Sequence[str], /): """ @@ -134,7 +149,23 @@ def remove(self, selector: Sequence[str], /): if not selector: return if len(selector) == 1: - self._variable_dictionary[selector[0]] = {} + self.variable_dictionary[selector[0]] = {} return hash_key = hash(tuple(selector[1:])) - self._variable_dictionary[selector[0]].pop(hash_key, None) + self.variable_dictionary[selector[0]].pop(hash_key, None) + + def convert_template(self, template: str, /): + parts = VARIABLE_PATTERN.split(template) + segments = [] + for part in filter(lambda x: x, parts): + if "." in part and (variable := self.get(part.split("."))): + segments.append(variable) + else: + segments.append(variable_factory.build_segment(part)) + return SegmentGroup(value=segments) + + def get_file(self, selector: Sequence[str], /) -> FileSegment | None: + segment = self.get(selector) + if isinstance(segment, FileSegment): + return segment + return None diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 9b35b8df8aa8e9..da56af1407d94f 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -3,12 +3,13 @@ from pydantic import BaseModel from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.base_node_data_entities import BaseIterationState -from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode, UserFrom +from core.workflow.nodes.base import BaseIterationState, BaseNode +from models.enums import UserFrom from models.workflow import Workflow, WorkflowType +from .node_entities import NodeRunResult +from .variable_pool import VariablePool + class WorkflowNodeAndResult: node: BaseNode @@ -46,13 +47,16 @@ class NodeRun(BaseModel): current_iteration_state: Optional[BaseIterationState] - def __init__(self, workflow: Workflow, - start_at: float, - variable_pool: VariablePool, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - workflow_call_depth: int): + def __init__( + self, + workflow: Workflow, + start_at: float, + variable_pool: VariablePool, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + workflow_call_depth: int, + ): self.workflow_id = workflow.id self.tenant_id = workflow.tenant_id self.app_id = workflow.app_id @@ -66,8 +70,7 @@ def __init__(self, workflow: Workflow, self.variable_pool = variable_pool self.total_tokens = 0 - self.workflow_nodes_and_results = [] - self.current_iteration_state = None self.workflow_node_steps = 1 - self.workflow_node_runs = [] \ No newline at end of file + self.workflow_node_runs = [] + self.current_iteration_state = None diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 4757cf32f88988..213ed57f570968 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -1,25 +1,16 @@ from enum import Enum -class SystemVariable(str, Enum): +class SystemVariableKey(str, Enum): """ System Variables. 
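Taken together, the reworked VariablePool is constructed from keyword arguments and wraps raw values into segments on add. A short hypothetical sketch of the API shown above (illustrative only; it assumes the imports from the diff):

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey

pool = VariablePool(
    system_variables={SystemVariableKey.QUERY: "hello"},
    user_inputs={},
)
pool.add(("node_1", "answer"), "42")  # plain values are wrapped via variable_factory.build_segment

segment = pool.get(("node_1", "answer"))
assert segment is not None and segment.text == "42"

# {{#node_id.key#}} references resolve into a SegmentGroup:
group = pool.convert_template("Result: {{#node_1.answer#}}")
assert group.value[-1] is segment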
""" - QUERY = 'query' - FILES = 'files' - CONVERSATION_ID = 'conversation_id' - USER_ID = 'user_id' - DIALOGUE_COUNT = 'dialogue_count' - @classmethod - def value_of(cls, value: str): - """ - Get value of given system variable. - - :param value: system variable value - :return: system variable - """ - for system_variable in cls: - if system_variable.value == value: - return system_variable - raise ValueError(f'invalid system variable value {value}') + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" + APP_ID = "app_id" + WORKFLOW_ID = "workflow_id" + WORKFLOW_RUN_ID = "workflow_run_id" diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index fe79fadf66b876..bd4ccc1072a2a3 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -1,10 +1,8 @@ -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str): - self.node_id = node_id - self.node_type = node_type - self.node_title = node_title + def __init__(self, node_instance: BaseNode, error: str): + self.node_instance = node_instance self.error = error - super().__init__(f"Node {node_title} run failed: {error}") + super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py new file mode 100644 index 00000000000000..2fee3d7fad8644 --- /dev/null +++ b/api/core/workflow/graph_engine/__init__.py @@ -0,0 +1,3 @@ +from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState + +__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/condition_handlers/__init__.py b/api/core/workflow/graph_engine/condition_handlers/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py new file mode 100644 index 00000000000000..697392b2a3c23f --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod + +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class RunConditionHandler(ABC): + def __init__(self, init_params: GraphInitParams, graph: Graph, condition: RunCondition): + self.init_params = init_params + self.graph = graph + self.condition = condition + + @abstractmethod + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + raise NotImplementedError diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py new file mode 100644 index 
00000000000000..af695df7d84607 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -0,0 +1,25 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState + + +class BranchIdentifyRunConditionHandler(RunConditionHandler): + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + if not self.condition.branch_identify: + raise Exception("Branch identify is required") + + run_result = previous_route_node_state.node_run_result + if not run_result: + return False + + if not run_result.edge_source_handle: + return False + + return self.condition.branch_identify == run_result.edge_source_handle diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py new file mode 100644 index 00000000000000..bc3a15bd004ace --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -0,0 +1,27 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.utils.condition.processor import ConditionProcessor + + +class ConditionRunConditionHandlerHandler(RunConditionHandler): + def check(self, graph_runtime_state: GraphRuntimeState, previous_route_node_state: RouteNodeState) -> bool: + """ + Check if the condition can be executed + + :param graph_runtime_state: graph runtime state + :param previous_route_node_state: previous route node state + :return: bool + """ + if not self.condition.conditions: + return True + + # process condition + condition_processor = ConditionProcessor() + _, _, final_result = condition_processor.process_conditions( + variable_pool=graph_runtime_state.variable_pool, + conditions=self.condition.conditions, + operator="and", + ) + + return final_result diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py new file mode 100644 index 00000000000000..1c9237d82fbe68 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -0,0 +1,25 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler +from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.run_condition import RunCondition + + +class ConditionManager: + @staticmethod + def get_condition_handler( + init_params: GraphInitParams, graph: Graph, run_condition: RunCondition + ) -> RunConditionHandler: + """ + Get condition handler + + :param init_params: init params + :param graph: graph + 
:param run_condition: run condition + :return: condition handler + """ + if run_condition.type == "branch_identify": + return BranchIdentifyRunConditionHandler(init_params=init_params, graph=graph, condition=run_condition) + else: + return ConditionRunConditionHandlerHandler(init_params=init_params, graph=graph, condition=run_condition) diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/core/workflow/graph_engine/entities/__init__.py new file mode 100644 index 00000000000000..6331a0b723fd50 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/__init__.py @@ -0,0 +1,6 @@ +from .graph import Graph +from .graph_init_params import GraphInitParams +from .graph_runtime_state import GraphRuntimeState +from .runtime_route_state import RuntimeRouteState + +__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py new file mode 100644 index 00000000000000..bacea191dd866c --- /dev/null +++ b/api/core/workflow/graph_engine/entities/event.py @@ -0,0 +1,170 @@ +from datetime import datetime +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNodeData + + +class GraphEngineEvent(BaseModel): + pass + + +########################################### +# Graph Events +########################################### + + +class BaseGraphEvent(GraphEngineEvent): + pass + + +class GraphRunStartedEvent(BaseGraphEvent): + pass + + +class GraphRunSucceededEvent(BaseGraphEvent): + outputs: Optional[dict[str, Any]] = None + """outputs""" + + +class GraphRunFailedEvent(BaseGraphEvent): + error: str = Field(..., description="failed reason") + + +########################################### +# Node Events +########################################### + + +class BaseNodeEvent(GraphEngineEvent): + id: str = Field(..., description="node execution id") + node_id: str = Field(..., description="node id") + node_type: NodeType = Field(..., description="node type") + node_data: BaseNodeData = Field(..., description="node data") + route_node_state: RouteNodeState = Field(..., description="route node state") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class NodeRunStartedEvent(BaseNodeEvent): + predecessor_node_id: Optional[str] = None + """predecessor node id""" + parallel_mode_run_id: Optional[str] = None + + +class NodeRunStreamChunkEvent(BaseNodeEvent): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: Optional[list[str]] = None + """from variable selector""" + + +class NodeRunRetrieverResourceEvent(BaseNodeEvent): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +class NodeRunSucceededEvent(BaseNodeEvent): + pass + + +class NodeRunFailedEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + +class 
NodeInIterationFailedEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + +########################################### +# Parallel Branch Events +########################################### + + +class BaseParallelBranchEvent(GraphEngineEvent): + parallel_id: str = Field(..., description="parallel id") + """parallel id""" + parallel_start_node_id: str = Field(..., description="parallel start node id") + """parallel start node id""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + + +class ParallelBranchRunStartedEvent(BaseParallelBranchEvent): + pass + + +class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent): + pass + + +class ParallelBranchRunFailedEvent(BaseParallelBranchEvent): + error: str = Field(..., description="failed reason") + + +########################################### +# Iteration Events +########################################### + + +class BaseIterationEvent(GraphEngineEvent): + iteration_id: str = Field(..., description="iteration node execution id") + iteration_node_id: str = Field(..., description="iteration node id") + iteration_node_type: NodeType = Field(..., description="node type, iteration or loop") + iteration_node_data: BaseNodeData = Field(..., description="node data") + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + parallel_mode_run_id: Optional[str] = None + """iteration run in parallel mode run id""" + + +class IterationRunStartedEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + predecessor_node_id: Optional[str] = None + + +class IterationRunNextEvent(BaseIterationEvent): + index: int = Field(..., description="index") + pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") + + +class IterationRunSucceededEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + + +class IterationRunFailedEvent(BaseIterationEvent): + start_at: datetime = Field(..., description="start at") + inputs: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + metadata: Optional[dict[str, Any]] = None + steps: int = 0 + error: str = Field(..., description="failed reason") + + +InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
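Consumers of the graph engine can branch on this event hierarchy directly. A small hypothetical handler over the classes defined above (module path taken from this diff):

from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    IterationRunNextEvent,
    NodeRunStreamChunkEvent,
)


def handle(event: GraphEngineEvent) -> None:
    if isinstance(event, NodeRunStreamChunkEvent):
        print(event.chunk_content, end="")  # incremental node output
    elif isinstance(event, IterationRunNextEvent):
        print(f"iteration {event.iteration_node_id} -> step {event.index}")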
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py new file mode 100644 index 00000000000000..d87c039409d62e --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -0,0 +1,715 @@ +import uuid +from collections.abc import Mapping +from typing import Any, Optional, cast + +from pydantic import BaseModel, Field + +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes import NodeType +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter +from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute +from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter +from core.workflow.nodes.end.entities import EndStreamParam + + +class GraphEdge(BaseModel): + source_node_id: str = Field(..., description="source node id") + target_node_id: str = Field(..., description="target node id") + run_condition: Optional[RunCondition] = None + """run condition""" + + +class GraphParallel(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id") + start_from_node_id: str = Field(..., description="start from node id") + parent_parallel_id: Optional[str] = None + """parent parallel id""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id""" + end_to_node_id: Optional[str] = None + """end to node id""" + + +class Graph(BaseModel): + root_node_id: str = Field(..., description="root node id of the graph") + node_ids: list[str] = Field(default_factory=list, description="graph node ids") + node_id_config_mapping: dict[str, dict] = Field( + default_factory=dict, description="node configs mapping (node id: node config)" + ) + edge_mapping: dict[str, list[GraphEdge]] = Field( + default_factory=dict, description="graph edge mapping (source node id: edges)" + ) + reverse_edge_mapping: dict[str, list[GraphEdge]] = Field( + default_factory=dict, description="reverse graph edge mapping (target node id: edges)" + ) + parallel_mapping: dict[str, GraphParallel] = Field( + default_factory=dict, description="graph parallel mapping (parallel id: parallel)" + ) + node_parallel_mapping: dict[str, str] = Field( + default_factory=dict, description="graph node parallel mapping (node id: parallel id)" + ) + answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(..., description="answer stream generate routes") + end_stream_param: EndStreamParam = Field(..., description="end stream param") + + @classmethod + def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = None) -> "Graph": + """ + Init graph + + :param graph_config: graph config + :param root_node_id: root node id + :return: graph + """ + # edge configs + edge_configs = graph_config.get("edges") + if edge_configs is None: + edge_configs = [] + + edge_configs = cast(list, edge_configs) + + # reorganize edges mapping + edge_mapping: dict[str, list[GraphEdge]] = {} + reverse_edge_mapping: dict[str, list[GraphEdge]] = {} + target_edge_ids = set() + for edge_config in edge_configs: + source_node_id = edge_config.get("source") + if not source_node_id: + continue + + if source_node_id not in edge_mapping: + edge_mapping[source_node_id] = [] + + target_node_id = edge_config.get("target") + if not target_node_id: + continue + + if target_node_id not in reverse_edge_mapping: + reverse_edge_mapping[target_node_id] = [] + + target_edge_ids.add(target_node_id) + + # parse run condition + run_condition = None + if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": + run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) + + graph_edge = GraphEdge( + source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition + ) + + edge_mapping[source_node_id].append(graph_edge) + reverse_edge_mapping[target_node_id].append(graph_edge) + + # 
node configs + node_configs = graph_config.get("nodes") + if not node_configs: + raise ValueError("Graph must have at least one node") + + node_configs = cast(list, node_configs) + + # fetch nodes that have no predecessor node + root_node_configs = [] + all_node_id_config_mapping: dict[str, dict] = {} + for node_config in node_configs: + node_id = node_config.get("id") + if not node_id: + continue + + if node_id not in target_edge_ids: + root_node_configs.append(node_config) + + all_node_id_config_mapping[node_id] = node_config + + root_node_ids = [node_config.get("id") for node_config in root_node_configs] + + # fetch root node + if not root_node_id: + # if no root node id, use the START type node as root node + root_node_id = next( + ( + node_config.get("id") + for node_config in root_node_configs + if node_config.get("data", {}).get("type", "") == NodeType.START.value + ), + None, + ) + + if not root_node_id or root_node_id not in root_node_ids: + raise ValueError(f"Root node id {root_node_id} not found in the graph") + + # Check whether it is connected to the previous node + cls._check_connected_to_previous_node(route=[root_node_id], edge_mapping=edge_mapping) + + # fetch all node ids from root node + node_ids = [root_node_id] + cls._recursively_add_node_ids(node_ids=node_ids, edge_mapping=edge_mapping, node_id=root_node_id) + + node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} + + # init parallel mapping + parallel_mapping: dict[str, GraphParallel] = {} + node_parallel_mapping: dict[str, str] = {} + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=root_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + ) + + # Check if it exceeds N layers of parallel + for parallel in parallel_mapping.values(): + if parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id + ) + + # init answer stream generate routes + answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( + node_id_config_mapping=node_id_config_mapping, reverse_edge_mapping=reverse_edge_mapping + ) + + # init end stream param + end_stream_param = EndStreamGeneratorRouter.init( + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + node_parallel_mapping=node_parallel_mapping, + ) + + # init graph + graph = cls( + root_node_id=root_node_id, + node_ids=node_ids, + node_id_config_mapping=node_id_config_mapping, + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + answer_stream_generate_routes=answer_stream_generate_routes, + end_stream_param=end_stream_param, + ) + + return graph + + def add_extra_edge( + self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None + ) -> None: + """ + Add extra edge to the graph + + :param source_node_id: source node id + :param target_node_id: target node id + :param run_condition: run condition + """ + if source_node_id not in self.node_ids or target_node_id not in self.node_ids: + return + + if source_node_id not in self.edge_mapping: + self.edge_mapping[source_node_id] = [] + + if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]: + return + + graph_edge = GraphEdge( + source_node_id=source_node_id, 
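Because Graph.init drives everything from a plain mapping, a two-node config is enough to exercise root detection and traversal. A hypothetical sketch (the node "data" payloads are simplified guesses, and it assumes the answer/end stream routers tolerate a graph without answer or end nodes):

from core.workflow.graph_engine.entities.graph import Graph

graph_config = {
    "nodes": [
        {"id": "start", "data": {"type": "start", "title": "Start"}},
        {"id": "llm", "data": {"type": "llm", "title": "LLM"}},
    ],
    "edges": [
        {"source": "start", "target": "llm"},
    ],
}

graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"         # START-typed root is auto-detected
assert graph.get_leaf_node_ids() == ["llm"]  # "llm" has no outgoing edges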
target_node_id=target_node_id, run_condition=run_condition + ) + + self.edge_mapping[source_node_id].append(graph_edge) + + def get_leaf_node_ids(self) -> list[str]: + """ + Get leaf node ids of the graph + + :return: leaf node ids + """ + leaf_node_ids = [] + for node_id in self.node_ids: + if node_id not in self.edge_mapping or ( + len(self.edge_mapping[node_id]) == 1 + and self.edge_mapping[node_id][0].target_node_id == self.root_node_id + ): + leaf_node_ids.append(node_id) + + return leaf_node_ids + + @classmethod + def _recursively_add_node_ids( + cls, node_ids: list[str], edge_mapping: dict[str, list[GraphEdge]], node_id: str + ) -> None: + """ + Recursively add node ids + + :param node_ids: node ids + :param edge_mapping: edge mapping + :param node_id: node id + """ + for graph_edge in edge_mapping.get(node_id, []): + if graph_edge.target_node_id in node_ids: + continue + + node_ids.append(graph_edge.target_node_id) + cls._recursively_add_node_ids( + node_ids=node_ids, edge_mapping=edge_mapping, node_id=graph_edge.target_node_id + ) + + @classmethod + def _check_connected_to_previous_node(cls, route: list[str], edge_mapping: dict[str, list[GraphEdge]]) -> None: + """ + Check whether it is connected to the previous node + """ + last_node_id = route[-1] + + for graph_edge in edge_mapping.get(last_node_id, []): + if not graph_edge.target_node_id: + continue + + if graph_edge.target_node_id in route: + raise ValueError( + f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph." + ) + + new_route = route.copy() + new_route.append(graph_edge.target_node_id) + cls._check_connected_to_previous_node( + route=new_route, + edge_mapping=edge_mapping, + ) + + @classmethod + def _recursively_add_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + parallel_mapping: dict[str, GraphParallel], + node_parallel_mapping: dict[str, str], + parent_parallel: Optional[GraphParallel] = None, + ) -> None: + """ + Recursively add parallel ids + + :param edge_mapping: edge mapping + :param start_node_id: start from node id + :param parallel_mapping: parallel mapping + :param node_parallel_mapping: node parallel mapping + :param parent_parallel: parent parallel + """ + target_node_edges = edge_mapping.get(start_node_id, []) + parallel = None + if len(target_node_edges) > 1: + # fetch all node ids in current parallels + parallel_branch_node_ids = {} + condition_edge_mappings = {} + for graph_edge in target_node_edges: + if graph_edge.run_condition is None: + if "default" not in parallel_branch_node_ids: + parallel_branch_node_ids["default"] = [] + + parallel_branch_node_ids["default"].append(graph_edge.target_node_id) + else: + condition_hash = graph_edge.run_condition.hash + if condition_hash not in condition_edge_mappings: + condition_edge_mappings[condition_hash] = [] + + condition_edge_mappings[condition_hash].append(graph_edge) + + for condition_hash, graph_edges in condition_edge_mappings.items(): + if len(graph_edges) > 1: + if condition_hash not in parallel_branch_node_ids: + parallel_branch_node_ids[condition_hash] = [] + + for graph_edge in graph_edges: + parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id) + + condition_parallels = {} + for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items(): + # any target node id in node_parallel_mapping + parallel = None + if condition_parallel_branch_node_ids: + parent_parallel_id = 
parent_parallel.id if parent_parallel else None + + parallel = GraphParallel( + start_from_node_id=start_node_id, + parent_parallel_id=parent_parallel.id if parent_parallel else None, + parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, + ) + parallel_mapping[parallel.id] = parallel + condition_parallels[condition_hash] = parallel + + in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + parallel_branch_node_ids=condition_parallel_branch_node_ids, + ) + + # collect all branches node ids + parallel_node_ids = [] + for _, node_ids in in_branch_node_ids.items(): + for node_id in node_ids: + in_parent_parallel = True + if parent_parallel_id: + in_parent_parallel = False + for parallel_node_id, parallel_id in node_parallel_mapping.items(): + if parallel_id == parent_parallel_id and parallel_node_id == node_id: + in_parent_parallel = True + break + + if in_parent_parallel: + parallel_node_ids.append(node_id) + node_parallel_mapping[node_id] = parallel.id + + outside_parallel_target_node_ids = set() + for node_id in parallel_node_ids: + if node_id == parallel.start_from_node_id: + continue + + node_edges = edge_mapping.get(node_id) + if not node_edges: + continue + + if len(node_edges) > 1: + continue + + target_node_id = node_edges[0].target_node_id + if target_node_id in parallel_node_ids: + continue + + if parent_parallel_id: + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + continue + + if ( + ( + node_parallel_mapping.get(target_node_id) + and node_parallel_mapping.get(target_node_id) == parent_parallel_id + ) + or ( + parent_parallel + and parent_parallel.end_to_node_id + and target_node_id == parent_parallel.end_to_node_id + ) + or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) + ): + outside_parallel_target_node_ids.add(target_node_id) + + if len(outside_parallel_target_node_ids) == 1: + if ( + parent_parallel + and parent_parallel.end_to_node_id + and parallel.end_to_node_id == parent_parallel.end_to_node_id + ): + parallel.end_to_node_id = None + else: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + + if condition_edge_mappings: + for condition_hash, graph_edges in condition_edge_mappings.items(): + for graph_edge in graph_edges: + current_parallel: GraphParallel | None = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=condition_parallels.get(condition_hash), + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + else: + for graph_edge in target_node_edges: + current_parallel = cls._get_current_parallel( + parallel_mapping=parallel_mapping, + graph_edge=graph_edge, + parallel=parallel, + 
parent_parallel=parent_parallel, + ) + + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping, + parent_parallel=current_parallel, + ) + + @classmethod + def _get_current_parallel( + cls, + parallel_mapping: dict[str, GraphParallel], + graph_edge: GraphEdge, + parallel: Optional[GraphParallel] = None, + parent_parallel: Optional[GraphParallel] = None, + ) -> Optional[GraphParallel]: + """ + Get current parallel + """ + current_parallel = None + if parallel: + current_parallel = parallel + elif parent_parallel: + if not parent_parallel.end_to_node_id or ( + parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id + ): + current_parallel = parent_parallel + else: + # fetch parent parallel's parent parallel + parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id + if parent_parallel_parent_parallel_id: + parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id) + if parent_parallel_parent_parallel and ( + not parent_parallel_parent_parallel.end_to_node_id + or ( + parent_parallel_parent_parallel.end_to_node_id + and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id + ) + ): + current_parallel = parent_parallel_parent_parallel + + return current_parallel + + @classmethod + def _check_exceed_parallel_limit( + cls, + parallel_mapping: dict[str, GraphParallel], + level_limit: int, + parent_parallel_id: str, + current_level: int = 1, + ) -> None: + """ + Check if it exceeds N layers of parallel + """ + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + return + + current_level += 1 + if current_level > level_limit: + raise ValueError(f"Exceeds {level_limit} layers of parallel") + + if parent_parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, + level_limit=level_limit, + parent_parallel_id=parent_parallel.parent_parallel_id, + current_level=current_level, + ) + + @classmethod + def _recursively_add_parallel_node_ids( + cls, + branch_node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + merge_node_id: str, + start_node_id: str, + ) -> None: + """ + Recursively add node ids + + :param branch_node_ids: in branch node ids + :param edge_mapping: edge mapping + :param merge_node_id: merge node id + :param start_node_id: start node id + """ + for graph_edge in edge_mapping.get(start_node_id, []): + if graph_edge.target_node_id != merge_node_id and graph_edge.target_node_id not in branch_node_ids: + branch_node_ids.append(graph_edge.target_node_id) + cls._recursively_add_parallel_node_ids( + branch_node_ids=branch_node_ids, + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=graph_edge.target_node_id, + ) + + @classmethod + def _fetch_all_node_ids_in_parallels( + cls, + edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], + parallel_branch_node_ids: list[str], + ) -> dict[str, list[str]]: + """ + Fetch all node ids in parallels + """ + routes_node_ids: dict[str, list[str]] = {} + for parallel_branch_node_id in parallel_branch_node_ids: + routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id] + + # fetch routes node ids + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, + start_node_id=parallel_branch_node_id, + 
routes_node_ids=routes_node_ids[parallel_branch_node_id], + ) + + # fetch leaf node ids from routes node ids + leaf_node_ids: dict[str, list[str]] = {} + merge_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + for node_id in node_ids: + if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0: + if branch_node_id not in leaf_node_ids: + leaf_node_ids[branch_node_id] = [] + + leaf_node_ids[branch_node_id].append(node_id) + + for branch_node_id2, inner_route2 in routes_node_ids.items(): + if ( + branch_node_id != branch_node_id2 + and node_id in inner_route2 + and len(reverse_edge_mapping.get(node_id, [])) > 1 + and cls._is_node_in_routes( + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=node_id, + routes_node_ids=routes_node_ids, + ) + ): + if node_id not in merge_branch_node_ids: + merge_branch_node_ids[node_id] = [] + + if branch_node_id2 not in merge_branch_node_ids[node_id]: + merge_branch_node_ids[node_id].append(branch_node_id2) + + # sorted merge_branch_node_ids by branch_node_ids length desc + merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) + + duplicate_end_node_ids = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): + if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): + if (node_id, node_id2) not in duplicate_end_node_ids and ( + node_id2, + node_id, + ) not in duplicate_end_node_ids: + duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids + + for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): + # check which node is after + if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): + if node_id in merge_branch_node_ids: + del merge_branch_node_ids[node_id2] + elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): + if node_id2 in merge_branch_node_ids: + del merge_branch_node_ids[node_id] + + branches_merge_node_ids: dict[str, str] = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + if len(branch_node_ids) <= 1: + continue + + for branch_node_id in branch_node_ids: + if branch_node_id in branches_merge_node_ids: + continue + + branches_merge_node_ids[branch_node_id] = node_id + + in_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + in_branch_node_ids[branch_node_id] = [] + if branch_node_id not in branches_merge_node_ids: + # all node ids in current branch is in this thread + in_branch_node_ids[branch_node_id].append(branch_node_id) + in_branch_node_ids[branch_node_id].extend(node_ids) + else: + merge_node_id = branches_merge_node_ids[branch_node_id] + if merge_node_id != branch_node_id: + in_branch_node_ids[branch_node_id].append(branch_node_id) + + # fetch all node ids from branch_node_id and merge_node_id + cls._recursively_add_parallel_node_ids( + branch_node_ids=in_branch_node_ids[branch_node_id], + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=branch_node_id, + ) + + return in_branch_node_ids + + @classmethod + def _recursively_fetch_routes( + cls, edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: list[str] + ) -> None: + """ + Recursively fetch route + """ + if start_node_id not in edge_mapping: + return + + for graph_edge in edge_mapping[start_node_id]: + # find next node ids + if 
graph_edge.target_node_id not in routes_node_ids: + routes_node_ids.append(graph_edge.target_node_id) + + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids + ) + + @classmethod + def _is_node_in_routes( + cls, reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, routes_node_ids: dict[str, list[str]] + ) -> bool: + """ + Recursively check if the node is in the routes + """ + if start_node_id not in reverse_edge_mapping: + return False + + all_routes_node_ids = set() + parallel_start_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + all_routes_node_ids.update(node_ids) + + if branch_node_id in reverse_edge_mapping: + for graph_edge in reverse_edge_mapping[branch_node_id]: + if graph_edge.source_node_id not in parallel_start_node_ids: + parallel_start_node_ids[graph_edge.source_node_id] = [] + + parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id) + + for _, branch_node_ids in parallel_start_node_ids.items(): + if set(branch_node_ids) == set(routes_node_ids.keys()): + return True + + return False + + @classmethod + def _is_node2_after_node1(cls, node1_id: str, node2_id: str, edge_mapping: dict[str, list[GraphEdge]]) -> bool: + """ + is node2 after node1 + """ + if node1_id not in edge_mapping: + return False + + for graph_edge in edge_mapping[node1_id]: + if graph_edge.target_node_id == node2_id: + return True + + if cls._is_node2_after_node1( + node1_id=graph_edge.target_node_id, node2_id=node2_id, edge_mapping=edge_mapping + ): + return True + + return False diff --git a/api/core/workflow/graph_engine/entities/graph_init_params.py b/api/core/workflow/graph_engine/entities/graph_init_params.py new file mode 100644 index 00000000000000..a0ecd824f427b9 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_init_params.py @@ -0,0 +1,21 @@ +from collections.abc import Mapping +from typing import Any + +from pydantic import BaseModel, Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from models.enums import UserFrom +from models.workflow import WorkflowType + + +class GraphInitParams(BaseModel): + # init params + tenant_id: str = Field(..., description="tenant / workspace id") + app_id: str = Field(..., description="app id") + workflow_type: WorkflowType = Field(..., description="workflow type") + workflow_id: str = Field(..., description="workflow id") + graph_config: Mapping[str, Any] = Field(..., description="graph config") + user_id: str = Field(..., description="user id") + user_from: UserFrom = Field(..., description="user from, account or end-user") + invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger") + call_depth: int = Field(..., description="call depth") diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py new file mode 100644 index 00000000000000..afc09bfac5b0c1 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -0,0 +1,27 @@ +from typing import Any + +from pydantic import BaseModel, Field + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState + + +class GraphRuntimeState(BaseModel): + variable_pool: VariablePool = Field(..., description="variable pool") + 
"""variable pool""" + + start_at: float = Field(..., description="start time") + """start time""" + total_tokens: int = 0 + """total tokens""" + llm_usage: LLMUsage = LLMUsage.empty_usage() + """llm usage info""" + outputs: dict[str, Any] = {} + """outputs""" + + node_run_steps: int = 0 + """node run steps""" + + node_run_state: RuntimeRouteState = RuntimeRouteState() + """node run state""" diff --git a/api/core/workflow/graph_engine/entities/next_graph_node.py b/api/core/workflow/graph_engine/entities/next_graph_node.py new file mode 100644 index 00000000000000..6aa4341ddfe171 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/next_graph_node.py @@ -0,0 +1,13 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.workflow.graph_engine.entities.graph import GraphParallel + + +class NextGraphNode(BaseModel): + node_id: str + """next node id""" + + parallel: Optional[GraphParallel] = None + """parallel""" diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py new file mode 100644 index 00000000000000..eedce8842b411e --- /dev/null +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -0,0 +1,21 @@ +import hashlib +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.utils.condition.entities import Condition + + +class RunCondition(BaseModel): + type: Literal["branch_identify", "condition"] + """condition type""" + + branch_identify: Optional[str] = None + """branch identify like: sourceHandle, required when type is branch_identify""" + + conditions: Optional[list[Condition]] = None + """conditions to run the node, required when type is condition""" + + @property + def hash(self) -> str: + return hashlib.sha256(self.model_dump_json().encode()).hexdigest() diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py new file mode 100644 index 00000000000000..bb24b511127395 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -0,0 +1,109 @@ +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus + + +class RouteNodeState(BaseModel): + class Status(Enum): + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + PAUSED = "paused" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """node state id""" + + node_id: str + """node id""" + + node_run_result: Optional[NodeRunResult] = None + """node run result""" + + status: Status = Status.RUNNING + """node status""" + + start_at: datetime + """start time""" + + paused_at: Optional[datetime] = None + """paused time""" + + finished_at: Optional[datetime] = None + """finished time""" + + failed_reason: Optional[str] = None + """failed reason""" + + paused_by: Optional[str] = None + """paused by""" + + index: int = 1 + + def set_finished(self, run_result: NodeRunResult) -> None: + """ + Node finished + + :param run_result: run result + """ + if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: + raise Exception(f"Route state {self.id} already finished") + + if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + self.status = RouteNodeState.Status.SUCCESS + elif run_result.status == WorkflowNodeExecutionStatus.FAILED: + 
self.status = RouteNodeState.Status.FAILED + self.failed_reason = run_result.error + else: + raise Exception(f"Invalid route status {run_result.status}") + + self.node_run_result = run_result + self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + + +class RuntimeRouteState(BaseModel): + routes: dict[str, list[str]] = Field( + default_factory=dict, description="graph state routes (source_node_state_id: target_node_state_id)" + ) + + node_state_mapping: dict[str, RouteNodeState] = Field( + default_factory=dict, description="node state mapping (route_node_state_id: route_node_state)" + ) + + def create_node_state(self, node_id: str) -> RouteNodeState: + """ + Create node state + + :param node_id: node id + """ + state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + self.node_state_mapping[state.id] = state + return state + + def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: + """ + Add route to the graph state + + :param source_node_state_id: source node state id + :param target_node_state_id: target node state id + """ + if source_node_state_id not in self.routes: + self.routes[source_node_state_id] = [] + + self.routes[source_node_state_id].append(target_node_state_id) + + def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) -> list[RouteNodeState]: + """ + Get routes with node state by source node id + + :param source_node_state_id: source node state id + :return: routes with node state + """ + return [ + self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, []) + ] diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py new file mode 100644 index 00000000000000..f07ad4de11bdfe --- /dev/null +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -0,0 +1,741 @@ +import logging +import queue +import time +import uuid +from collections.abc import Generator, Mapping +from concurrent.futures import ThreadPoolExecutor, wait +from copy import copy, deepcopy +from typing import Any, Optional + +from flask import Flask, current_app + +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager +from core.workflow.graph_engine.entities.event import ( + BaseIterationEvent, + GraphEngineEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, + ParallelBranchRunFailedEvent, + ParallelBranchRunStartedEvent, + ParallelBranchRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph, GraphEdge +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes import NodeType +from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor 
+from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.node_mapping import node_type_classes_mapping +from extensions.ext_database import db +from models.enums import UserFrom +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType + +logger = logging.getLogger(__name__) + + +class GraphEngineThreadPool(ThreadPoolExecutor): + def __init__( + self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100 + ) -> None: + super().__init__(max_workers, thread_name_prefix, initializer, initargs) + self.max_submit_count = max_submit_count + self.submit_count = 0 + + def submit(self, fn, *args, **kwargs): + self.submit_count += 1 + self.check_is_full() + + return super().submit(fn, *args, **kwargs) + + def task_done_callback(self, future): + self.submit_count -= 1 + + def check_is_full(self) -> None: + if self.submit_count > self.max_submit_count: + raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") + + +class GraphEngine: + workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} + + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_type: WorkflowType, + workflow_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + graph_config: Mapping[str, Any], + variable_pool: VariablePool, + max_execution_steps: int, + max_execution_time: int, + thread_pool_id: Optional[str] = None, + ) -> None: + thread_pool_max_submit_count = 100 + thread_pool_max_workers = 10 + + # init thread pool + if thread_pool_id: + if thread_pool_id not in GraphEngine.workflow_thread_pool_mapping: + raise ValueError(f"Thread pool {thread_pool_id} not found.") + + self.thread_pool_id = thread_pool_id + self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id] + self.is_main_thread_pool = False + else: + self.thread_pool = GraphEngineThreadPool( + max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count + ) + self.thread_pool_id = str(uuid.uuid4()) + self.is_main_thread_pool = True + GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool + + self.graph = graph + self.init_params = GraphInitParams( + tenant_id=tenant_id, + app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, + graph_config=graph_config, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth, + ) + + self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + self.max_execution_steps = max_execution_steps + self.max_execution_time = max_execution_time + + def run(self) -> Generator[GraphEngineEvent, None, None]: + # trigger graph run start event + yield GraphRunStartedEvent() + + try: + if self.init_params.workflow_type == WorkflowType.CHAT: + stream_processor = AnswerStreamProcessor( + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + ) + else: + stream_processor = EndStreamProcessor( + graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool + ) + + # run graph + generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) + + for item in generator: + try: + yield item + if isinstance(item, NodeRunFailedEvent): + yield
GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") + return + elif isinstance(item, NodeRunSucceededEvent): + if item.node_type == NodeType.END: + self.graph_runtime_state.outputs = ( + item.route_node_state.node_run_result.outputs + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else {} + ) + elif item.node_type == NodeType.ANSWER: + if "answer" not in self.graph_runtime_state.outputs: + self.graph_runtime_state.outputs["answer"] = "" + + self.graph_runtime_state.outputs["answer"] += "\n" + ( + item.route_node_state.node_run_result.outputs.get("answer", "") + if item.route_node_state.node_run_result + and item.route_node_state.node_run_result.outputs + else "" + ) + + self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs[ + "answer" + ].strip() + except Exception as e: + logger.exception(f"Graph run failed: {str(e)}") + yield GraphRunFailedEvent(error=str(e)) + return + + # trigger graph run success event + yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) + self._release_thread() + except GraphRunFailedError as e: + yield GraphRunFailedEvent(error=e.error) + self._release_thread() + return + except Exception as e: + logger.exception("Unknown error while running graph") + yield GraphRunFailedEvent(error=str(e)) + self._release_thread() + raise e + + def _release_thread(self): + if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] + + def _run( + self, + start_node_id: str, + in_parallel_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: + parallel_start_node_id = None + if in_parallel_id: + parallel_start_node_id = start_node_id + + next_node_id = start_node_id + previous_route_node_state: Optional[RouteNodeState] = None + while True: + # max steps reached + if self.graph_runtime_state.node_run_steps > self.max_execution_steps: + raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) + + # or max execution time reached + if self._is_timed_out( + start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time + ): + raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) + + # init route node state + route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) + + # get node config + node_id = route_node_state.node_id + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError(f"Node {node_id} config not found.") + + # convert to specific node + node_type = NodeType(node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping[node_type] + + previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None + + # init workflow run state + node_instance = node_cls( # type: ignore + id=route_node_state.id, + config=node_config, + graph_init_params=self.init_params, + graph=self.graph, + graph_runtime_state=self.graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=self.thread_pool_id, + ) + + try: + # run node + generator = self._run_node( + node_instance=node_instance, + route_node_state=route_node_state, + parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id, +
parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + + for item in generator: + if isinstance(item, NodeRunStartedEvent): + self.graph_runtime_state.node_run_steps += 1 + item.route_node_state.index = self.graph_runtime_state.node_run_steps + + yield item + + self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state + + # append route + if previous_route_node_state: + self.graph_runtime_state.node_run_state.add_route( + source_node_state_id=previous_route_node_state.id, target_node_state_id=route_node_state.id + ) + except Exception as e: + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = str(e) + yield NodeRunFailedEvent( + error=str(e), + id=node_instance.id, + node_id=next_node_id, + node_type=node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + raise e + + # stop the loop once an END node has been executed + if ( + self.graph.node_id_config_mapping[next_node_id].get("data", {}).get("type", "").lower() + == NodeType.END.value + ): + break + + previous_route_node_state = route_node_state + + # get next node ids + edge_mappings = self.graph.edge_mapping.get(next_node_id) + if not edge_mappings: + break + + if len(edge_mappings) == 1: + edge = edge_mappings[0] + + if edge.run_condition: + result = ConditionManager.get_condition_handler( + init_params=self.init_params, + graph=self.graph, + run_condition=edge.run_condition, + ).check( + graph_runtime_state=self.graph_runtime_state, + previous_route_node_state=previous_route_node_state, + ) + + if not result: + break + + next_node_id = edge.target_node_id + else: + final_node_id = None + + if any(edge.run_condition for edge in edge_mappings): + # if edges have run conditions, determine which branch to take based on the run condition results + condition_edge_mappings = {} + for edge in edge_mappings: + if edge.run_condition: + run_condition_hash = edge.run_condition.hash + if run_condition_hash not in condition_edge_mappings: + condition_edge_mappings[run_condition_hash] = [] + + condition_edge_mappings[run_condition_hash].append(edge) + + for _, sub_edge_mappings in condition_edge_mappings.items(): + if len(sub_edge_mappings) == 0: + continue + + edge = sub_edge_mappings[0] + + result = ConditionManager.get_condition_handler( + init_params=self.init_params, + graph=self.graph, + run_condition=edge.run_condition, + ).check( + graph_runtime_state=self.graph_runtime_state, + previous_route_node_state=previous_route_node_state, + ) + + if not result: + continue + + if len(sub_edge_mappings) == 1: + final_node_id = edge.target_node_id + else: + parallel_generator = self._run_parallel_branches( + edge_mappings=sub_edge_mappings, + in_parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + for item in parallel_generator: + if isinstance(item, str): + final_node_id = item + else: + yield item + + break + + if not final_node_id: + break + + next_node_id = final_node_id + else: + parallel_generator = self._run_parallel_branches( + edge_mappings=edge_mappings, + in_parallel_id=in_parallel_id, + parallel_start_node_id=parallel_start_node_id, + ) + + for item in parallel_generator: + if isinstance(item, str): + final_node_id = item + else: + yield item + + if not
final_node_id: + break + + next_node_id = final_node_id + + if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, "") != in_parallel_id: + break + + def _run_parallel_branches( + self, + edge_mappings: list[GraphEdge], + in_parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent | str, None, None]: + # if edges have no run conditions, run all target nodes in parallel + parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) + if not parallel_id: + node_id = edge_mappings[0].target_node_id + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError( + f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches." + ) + + node_title = node_config.get("data", {}).get("title") + raise GraphRunFailedError( + f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches." + ) + + parallel = self.graph.parallel_mapping.get(parallel_id) + if not parallel: + raise GraphRunFailedError(f"Parallel {parallel_id} not found.") + + # run parallel branches in new threads and collect their results via a queue + q: queue.Queue = queue.Queue() + + # collect the futures of the submitted branch runs + futures = [] + + # submit each branch of the parallel to the thread pool + for edge in edge_mappings: + if ( + edge.target_node_id not in self.graph.node_parallel_mapping + or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id + ): + continue + + future = self.thread_pool.submit( + self._run_parallel_node, + **{ + "flask_app": current_app._get_current_object(), # type: ignore[attr-defined] + "q": q, + "parallel_id": parallel_id, + "parallel_start_node_id": edge.target_node_id, + "parent_parallel_id": in_parallel_id, + "parent_parallel_start_node_id": parallel_start_node_id, + }, + ) + + future.add_done_callback(self.thread_pool.task_done_callback) + + futures.append(future) + + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break + + yield event + if event.parallel_id == parallel_id: + if isinstance(event, ParallelBranchRunSucceededEvent): + succeeded_count += 1 + if succeeded_count == len(futures): + q.put(None) + + continue + elif isinstance(event, ParallelBranchRunFailedEvent): + raise GraphRunFailedError(event.error) + except queue.Empty: + continue + + # wait for all branch threads to finish + wait(futures) + + # get final node id + final_node_id = parallel.end_to_node_id + if final_node_id: + yield final_node_id + + def _run_parallel_node( + self, + flask_app: Flask, + q: queue.Queue, + parallel_id: str, + parallel_start_node_id: str, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> None: + """ + Run the nodes of one parallel branch in a separate thread + """ + with flask_app.app_context(): + try: + q.put( + ParallelBranchRunStartedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) + + # run node + generator = self._run( + start_node_id=parallel_start_node_id, + in_parallel_id=parallel_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + + for item in generator: + q.put(item) + + # trigger graph run success event + q.put( + ParallelBranchRunSucceededEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + ) + except GraphRunFailedError as e: + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=e.error, + ) + ) + except Exception as e: + logger.exception("Unknown error while running parallel branch") + q.put( + ParallelBranchRunFailedEvent( + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + error=str(e), + ) + ) + finally: + db.session.remove() + + def _run_node( + self, + node_instance: BaseNode, + route_node_state: RouteNodeState, + parallel_id: Optional[str] = None, + parallel_start_node_id: Optional[str] = None, + parent_parallel_id: Optional[str] = None, + parent_parallel_start_node_id: Optional[str] = None, + ) -> Generator[GraphEngineEvent, None, None]: + """ + Run node + """ + # trigger node run start event + yield NodeRunStartedEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + predecessor_node_id=node_instance.previous_node_id, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + + db.session.close() + + try: + # run node + generator = node_instance.run() + for item in generator: + if isinstance(item, GraphEngineEvent): + if isinstance(item, BaseIterationEvent): + # add parallel info to iteration event + item.parallel_id = parallel_id + item.parallel_start_node_id = parallel_start_node_id + item.parent_parallel_id = parent_parallel_id + item.parent_parallel_start_node_id = parent_parallel_start_node_id + + yield item + else: + if isinstance(item, RunCompletedEvent): + run_result = item.run_result + route_node_state.set_finished(run_result=run_result) + + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or "Unknown error.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # accumulate into the state's total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + ) + + if run_result.llm_usage: + # accumulate the llm usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + + # append node output variables to variable pool + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value, + ) + + # add parallel info to run result metadata + if parallel_id and parallel_start_node_id: + if not run_result.metadata: + run_result.metadata = {} + +
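+ # tag the run result with its parallel branch so downstream consumers can attribute events to it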
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id + if parent_parallel_id and parent_parallel_start_node_id: + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + parent_parallel_start_node_id + ) + + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + retriever_resources=item.retriever_resources, + context=item.context, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + except GenerateTaskStoppedError: + # trigger node run failed event + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = "Workflow stopped." 
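+ # surface the stop as a node failure so stream processors can finalize their state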
+ yield NodeRunFailedEvent( + error="Workflow stopped.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + return + except Exception as e: + logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}") + raise e + finally: + db.session.close() + + def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + node_id=node_id, variable_key_list=new_key_list, variable_value=value + ) + + def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: + """ + Check timeout + :param start_at: start time + :param max_execution_time: max execution time + :return: + """ + return time.perf_counter() - start_at > max_execution_time + + def create_copy(self): + """ + create a graph engine copy + :return: with a new variable pool instance of graph engine + """ + new_instance = copy(self) + new_instance.graph_runtime_state = copy(self.graph_runtime_state) + new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) + return new_instance + + +class GraphRunFailedError(Exception): + def __init__(self, error: str): + self.error = error diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py index e69de29bb2d1d6..6101fcf9afd982 100644 --- a/api/core/workflow/nodes/__init__.py +++ b/api/core/workflow/nodes/__init__.py @@ -0,0 +1,3 @@ +from .enums import NodeType + +__all__ = ["NodeType"] diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/core/workflow/nodes/answer/__init__.py index e69de29bb2d1d6..7a10f47eed2d9c 100644 --- a/api/core/workflow/nodes/answer/__init__.py +++ b/api/core/workflow/nodes/answer/__init__.py @@ -0,0 +1,4 @@ +from .answer_node import AnswerNode +from .entities import AnswerStreamGenerateRoute + +__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"] diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 5bae27092f920d..520cbdbb605115 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,131 +1,72 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any, cast -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool +from core.variables import ArrayFileSegment, FileSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from 
core.workflow.nodes.answer.entities import ( AnswerNodeData, GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser from models.workflow import WorkflowNodeExecutionStatus -class AnswerNode(BaseNode): +class AnswerNode(BaseNode[AnswerNodeData]): _node_data_cls = AnswerNodeData _node_type: NodeType = NodeType.ANSWER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ - node_data = self.node_data - node_data = cast(AnswerNodeData, node_data) - # generate routes - generate_routes = self.extract_generate_route_from_node_data(node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) - answer = '' + answer = "" + files = [] for part in generate_routes: - if part.type == "var": + if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = variable_pool.get(value_selector) - if value: - answer += value.markdown + variable = self.graph_runtime_state.variable_pool.get(value_selector) + if variable: + if isinstance(variable, FileSegment): + files.append(variable.value) + elif isinstance(variable, ArrayFileSegment): + files.extend(variable.value) + answer += variable.markdown else: part = cast(TextGenerateRouteChunk, part) answer += part.text - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "answer": answer - } - ) - - @classmethod - def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: - """ - Extract generate route selectors - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(AnswerNodeData, node_data) - - return cls.extract_generate_route_from_node_data(node_data) - - @classmethod - def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: - """ - Extract generate route from node data - :param node_data: node data object - :return: - """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - value_selector_mapping = { - variable_selector.variable: variable_selector.value_selector - for variable_selector in variable_selectors - } - - variable_keys = list(value_selector_mapping.keys()) - - # format answer template - template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) - template_variable_keys = template_parser.variable_keys - - # Take the intersection of variable_keys and template_variable_keys - variable_keys = list(set(variable_keys) & set(template_variable_keys)) - - template = node_data.answer - for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') - - generate_routes = [] - for part in template.split('Ω'): - if part: - if cls._is_variable(part, variable_keys): - var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') - value_selector = value_selector_mapping[var_key] - generate_routes.append(VarGenerateRouteChunk( - value_selector=value_selector - )) - else: - generate_routes.append(TextGenerateRouteChunk( - text=part - )) - - return 
generate_routes + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) @classmethod - def _is_variable(cls, part, variable_keys): - cleaned_part = part.replace('{{', '').replace('}}', '') - return part.startswith('{{') and cleaned_part in variable_keys - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: AnswerNodeData, + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = node_data - node_data = cast(AnswerNodeData, node_data) - variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector return variable_mapping diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py new file mode 100644 index 00000000000000..96e24a7db3725e --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -0,0 +1,166 @@ +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.nodes.answer.entities import ( + AnswerNodeData, + AnswerStreamGenerateRoute, + GenerateRouteChunk, + TextGenerateRouteChunk, + VarGenerateRouteChunk, +) +from core.workflow.nodes.enums import NodeType +from core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class AnswerStreamGeneratorRouter: + @classmethod + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + ) -> AnswerStreamGenerateRoute: + """ + Get stream generate routes. 
+ :return: + """ + # parse stream output node value selectors of answer nodes + answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} + for answer_node_id, node_config in node_id_config_mapping.items(): + if node_config.get("data", {}).get("type") != NodeType.ANSWER.value: + continue + + # get generate route for stream output + generate_route = cls._extract_generate_route_selectors(node_config) + answer_generate_route[answer_node_id] = generate_route + + # fetch answer dependencies + answer_node_ids = list(answer_generate_route.keys()) + answer_dependencies = cls._fetch_answers_dependencies( + answer_node_ids=answer_node_ids, + reverse_edge_mapping=reverse_edge_mapping, + node_id_config_mapping=node_id_config_mapping, + ) + + return AnswerStreamGenerateRoute( + answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies + ) + + @classmethod + def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: + """ + Extract generate route from node data + :param node_data: node data object + :return: + """ + variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_selectors = variable_template_parser.extract_variable_selectors() + + value_selector_mapping = { + variable_selector.variable: variable_selector.value_selector for variable_selector in variable_selectors + } + + variable_keys = list(value_selector_mapping.keys()) + + # format answer template + template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) + template_variable_keys = template_parser.variable_keys + + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f"{{{{{var}}}}}", f"Ω{{{{{var}}}}}Ω") + + generate_routes: list[GenerateRouteChunk] = [] + for part in template.split("Ω"): + if part: + if cls._is_variable(part, variable_keys): + var_key = part.replace("Ω", "").replace("{{", "").replace("}}", "") + value_selector = value_selector_mapping[var_key] + generate_routes.append(VarGenerateRouteChunk(value_selector=value_selector)) + else: + generate_routes.append(TextGenerateRouteChunk(text=part)) + + return generate_routes + + @classmethod + def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: + """ + Extract generate route selectors + :param config: node config + :return: + """ + node_data = AnswerNodeData(**config.get("data", {})) + return cls.extract_generate_route_from_node_data(node_data) + + @classmethod + def _is_variable(cls, part, variable_keys): + cleaned_part = part.replace("{{", "").replace("}}", "") + return part.startswith("{{") and cleaned_part in variable_keys + + @classmethod + def _fetch_answers_dependencies( + cls, + answer_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: + """ + Fetch answer dependencies + :param answer_node_ids: answer node ids + :param reverse_edge_mapping: reverse edge mapping + :param node_id_config_mapping: node id config mapping + :return: + """ + answer_dependencies: dict[str, list[str]] = {} + for answer_node_id in answer_node_ids: + if answer_dependencies.get(answer_node_id) is None: + answer_dependencies[answer_node_id] = [] + + cls._recursive_fetch_answer_dependencies( + current_node_id=answer_node_id, + 
answer_node_id=answer_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + answer_dependencies=answer_dependencies, + ) + + return answer_dependencies + + @classmethod + def _recursive_fetch_answer_dependencies( + cls, + current_node_id: str, + answer_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + answer_dependencies: dict[str, list[str]], + ) -> None: + """ + Recursive fetch answer dependencies + :param current_node_id: current node id + :param answer_node_id: answer node id + :param node_id_config_mapping: node id config mapping + :param reverse_edge_mapping: reverse edge mapping + :param answer_dependencies: answer dependencies + :return: + """ + reverse_edges = reverse_edge_mapping.get(current_node_id, []) + for edge in reverse_edges: + source_node_id = edge.source_node_id + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") + if source_node_type in { + NodeType.ANSWER, + NodeType.IF_ELSE, + NodeType.QUESTION_CLASSIFIER, + NodeType.ITERATION, + NodeType.CONVERSATION_VARIABLE_ASSIGNER, + }: + answer_dependencies[answer_node_id].append(source_node_id) + else: + cls._recursive_fetch_answer_dependencies( + current_node_id=source_node_id, + answer_node_id=answer_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + answer_dependencies=answer_dependencies, + ) diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py new file mode 100644 index 00000000000000..8a768088da660e --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -0,0 +1,221 @@ +import logging +from collections.abc import Generator +from typing import cast + +from core.file import FILE_MODEL_IDENTITY, File +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor +from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk + +logger = logging.getLogger(__name__) + + +class AnswerStreamProcessor(StreamProcessor): + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + super().__init__(graph, variable_pool) + self.generate_routes = graph.answer_stream_generate_routes + self.route_position = {} + for answer_node_id in self.generate_routes.answer_generate_route: + self.route_position[answer_node_id] = 0 + self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} + + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: + for event in generator: + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): + if event.in_iteration_id: + yield event + continue + + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] + else: + stream_out_answer_node_ids = 
self._get_stream_out_answer_node_ids(event) + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_answer_node_ids + ) + + for _ in stream_out_answer_node_ids: + yield event + elif isinstance(event, NodeRunSucceededEvent): + yield event + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + # update self.route_position after all stream events have finished + for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: + self.route_position[answer_node_id] += 1 + + del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] + + # remove unreachable nodes + self._remove_unreachable_nodes(event) + + # generate stream outputs + yield from self._generate_stream_outputs_when_node_finished(event) + else: + yield event + + def reset(self) -> None: + self.route_position = {} + for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): + self.route_position[answer_node_id] = 0 + self.rest_node_ids = self.graph.node_ids.copy() + self.current_stream_chunk_generating_node_ids = {} + + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: + """ + Generate stream outputs. + :param event: node run succeeded event + :return: + """ + for answer_node_id, position in self.route_position.items(): + # stream only when the answer node's dependencies have all finished (left rest_node_ids) + if event.route_node_state.node_id != answer_node_id and ( + answer_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids + for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + ) + ): + continue + + route_position = self.route_position[answer_node_id] + route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:] + + for route_chunk in route_chunks: + if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT: + route_chunk = cast(TextGenerateRouteChunk, route_chunk) + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=route_chunk.text, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + from_variable_selector=[answer_node_id, "answer"], + ) + else: + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + if not value_selector: + break + + value = self.variable_pool.get(value_selector) + + if value is None: + break + + text = value.markdown + + if text: + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=text, + from_variable_selector=value_selector, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + + self.route_position[answer_node_id] += 1 + + def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: + """ + Get the answer node ids that the given chunk should be streamed out for + :param event: node run stream chunk event + :return: + """ + if not event.from_variable_selector: + return [] + + stream_output_value_selector = event.from_variable_selector + if not stream_output_value_selector: + return [] + + stream_out_answer_node_ids = [] + for answer_node_id, route_position in self.route_position.items(): + if answer_node_id not in self.rest_node_ids: + continue + + # stream out only when every dependency of the answer node has finished + if all( + dep_id not in self.rest_node_ids for dep_id in self.generate_routes.answer_dependencies[answer_node_id] + ): + if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]): + continue + + route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position] + + if route_chunk.type != GenerateRouteChunk.ChunkType.VAR: + continue + + route_chunk = cast(VarGenerateRouteChunk, route_chunk) + value_selector = route_chunk.value_selector + + # only stream when the chunk's variable selector matches this route chunk + if value_selector != stream_output_value_selector: + continue + + stream_out_answer_node_ids.append(answer_node_id) + + return stream_out_answer_node_ids + + @classmethod + def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]: + """ + Fetch files from variable value + :param value: variable value + :return: + """ + if not value: + return [] + + files = [] + if isinstance(value, list): + for item in value: + file_var = cls._get_file_var_from_value(item) + if file_var: + files.append(file_var) + elif isinstance(value, dict): + file_var = cls._get_file_var_from_value(value) + if file_var: + files.append(file_var) + + return files + + @classmethod + def _get_file_var_from_value(cls, value: dict | list): + """ + Get file var from value + :param value: variable value + :return: + """ + if not value: + return None + + if isinstance(value, dict): + if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY: + return value + elif isinstance(value, File): + return value.to_dict() + + return None diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py new file mode 100644 index 00000000000000..36c3fe180a9cb2 --- /dev/null +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -0,0 +1,70 @@ +from abc import ABC, abstractmethod +from collections.abc import Generator + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent +from core.workflow.graph_engine.entities.graph import Graph + + +class StreamProcessor(ABC): + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + self.graph = graph + self.variable_pool = variable_pool + self.rest_node_ids = graph.node_ids.copy() + + @abstractmethod + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: + raise NotImplementedError + + def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None: + finished_node_id = event.route_node_state.node_id + if finished_node_id not in self.rest_node_ids: + return + + # remove finished node id + self.rest_node_ids.remove(finished_node_id) + + run_result = event.route_node_state.node_run_result + if not run_result: + return + + if run_result.edge_source_handle: + reachable_node_ids = [] + unreachable_first_node_ids = [] + for edge in self.graph.edge_mapping[finished_node_id]: + if ( + edge.run_condition + and edge.run_condition.branch_identify + and run_result.edge_source_handle == edge.run_condition.branch_identify + ): + reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + continue + else: + unreachable_first_node_ids.append(edge.target_node_id) + + for node_id in unreachable_first_node_ids: + self._remove_node_ids_in_unreachable_branch(node_id,
reachable_node_ids) + + def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: + node_ids = [] + for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id == self.graph.root_node_id: + continue + + node_ids.append(edge.target_node_id) + node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) + return node_ids + + def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: + """ + remove target node ids until merge + """ + if node_id not in self.rest_node_ids: + return + + self.rest_node_ids.remove(node_id) + for edge in self.graph.edge_mapping.get(node_id, []): + if edge.target_node_id in reachable_node_ids: + continue + + self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids) diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 9effbbbe671420..a05cc44c99428e 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,34 +1,65 @@ +from collections.abc import Sequence +from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class AnswerNodeData(BaseNodeData): """ Answer Node Data. """ - answer: str + + answer: str = Field(..., description="answer template string") class GenerateRouteChunk(BaseModel): """ Generate Route Chunk. """ - type: str + + class ChunkType(Enum): + VAR = "var" + TEXT = "text" + + type: ChunkType = Field(..., description="generate route chunk type") class VarGenerateRouteChunk(GenerateRouteChunk): """ Var Generate Route Chunk. """ - type: str = "var" - value_selector: list[str] + + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR + """generate route chunk type""" + value_selector: Sequence[str] = Field(..., description="value selector") class TextGenerateRouteChunk(GenerateRouteChunk): """ Text Generate Route Chunk. 
""" - type: str = "text" - text: str + + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT + """generate route chunk type""" + text: str = Field(..., description="text") + + +class AnswerNodeDoubleLink(BaseModel): + node_id: str = Field(..., description="node id") + source_node_ids: list[str] = Field(..., description="source node ids") + target_node_ids: list[str] = Field(..., description="target node ids") + + +class AnswerStreamGenerateRoute(BaseModel): + """ + AnswerStreamGenerateRoute entity + """ + + answer_dependencies: dict[str, list[str]] = Field( + ..., description="answer dependencies (answer node id -> dependent answer node ids)" + ) + answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( + ..., description="answer generate route (answer node id -> generate route chunks)" + ) diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py new file mode 100644 index 00000000000000..61f727740c87db --- /dev/null +++ b/api/core/workflow/nodes/base/__init__.py @@ -0,0 +1,4 @@ +from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData +from .node import BaseNode + +__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"] diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py new file mode 100644 index 00000000000000..2a864dd7a84c8b --- /dev/null +++ b/api/core/workflow/nodes/base/entities.py @@ -0,0 +1,24 @@ +from abc import ABC +from typing import Optional + +from pydantic import BaseModel + + +class BaseNodeData(ABC, BaseModel): + title: str + desc: Optional[str] = None + + +class BaseIterationNodeData(BaseNodeData): + start_node_id: Optional[str] = None + + +class BaseIterationState(BaseModel): + iteration_node_id: str + index: int + inputs: dict + + class MetaData(BaseModel): + pass + + metadata: MetaData diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py new file mode 100644 index 00000000000000..1433c8eaed6d4d --- /dev/null +++ b/api/core/workflow/nodes/base/node.py @@ -0,0 +1,137 @@ +import logging +from abc import abstractmethod +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import NodeEvent, RunCompletedEvent +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import BaseNodeData + +if TYPE_CHECKING: + from core.workflow.graph_engine.entities.event import InNodeEvent + from core.workflow.graph_engine.entities.graph import Graph + from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState + +logger = logging.getLogger(__name__) + +GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) + + +class BaseNode(Generic[GenericNodeData]): + _node_data_cls: type[BaseNodeData] + _node_type: NodeType + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + ) -> None: + self.id = id + self.tenant_id = graph_init_params.tenant_id + self.app_id = graph_init_params.app_id + self.workflow_type = graph_init_params.workflow_type + 
self.workflow_id = graph_init_params.workflow_id + self.graph_config = graph_init_params.graph_config + self.user_id = graph_init_params.user_id + self.user_from = graph_init_params.user_from + self.invoke_from = graph_init_params.invoke_from + self.workflow_call_depth = graph_init_params.call_depth + self.graph = graph + self.graph_runtime_state = graph_runtime_state + self.previous_node_id = previous_node_id + self.thread_pool_id = thread_pool_id + + node_id = config.get("id") + if not node_id: + raise ValueError("Node ID is required.") + + self.node_id = node_id + self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {}))) + + @abstractmethod + def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: + """ + Run node + :return: + """ + raise NotImplementedError + + def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: + try: + result = self._run() + except Exception as e: + logger.exception(f"Node {self.node_id} failed to run: {e}") + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + + if isinstance(result, NodeRunResult): + yield RunCompletedEvent(run_result=result) + else: + yield from result + + @classmethod + def extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + config: Mapping[str, Any], + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param config: node config + :return: + """ + node_id = config.get("id") + if not node_id: + raise ValueError("Node ID is required when extracting variable selector to variable mapping.") + + node_data = cls._node_data_cls(**config.get("data", {})) + return cls._extract_variable_selector_to_variable_mapping( + graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: GenericNodeData, + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + return {} + + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + """ + Get default config of node. + :param filters: filter by node config parameters. 
+ :return: + """ + return {} + + @property + def node_type(self) -> NodeType: + """ + Get node type + :return: + """ + return self._node_type diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py deleted file mode 100644 index 3d9cf52771e146..00000000000000 --- a/api/core/workflow/nodes/base_node.py +++ /dev/null @@ -1,195 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence -from enum import Enum -from typing import Any, Optional - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from models import WorkflowNodeExecutionStatus - - -class UserFrom(Enum): - """ - User from - """ - ACCOUNT = "account" - END_USER = "end-user" - - @classmethod - def value_of(cls, value: str) -> "UserFrom": - """ - Value of - :param value: value - :return: - """ - for item in cls: - if item.value == value: - return item - raise ValueError(f"Invalid value: {value}") - - -class BaseNode(ABC): - _node_data_cls: type[BaseNodeData] - _node_type: NodeType - - tenant_id: str - app_id: str - workflow_id: str - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - - workflow_call_depth: int - - node_id: str - node_data: BaseNodeData - node_run_result: Optional[NodeRunResult] = None - - callbacks: Sequence[WorkflowCallback] - - is_answer_previous_node: bool = False - - def __init__(self, tenant_id: str, - app_id: str, - workflow_id: str, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - config: Mapping[str, Any], - callbacks: Sequence[WorkflowCallback] | None = None, - workflow_call_depth: int = 0) -> None: - self.tenant_id = tenant_id - self.app_id = app_id - self.workflow_id = workflow_id - self.user_id = user_id - self.user_from = user_from - self.invoke_from = invoke_from - self.workflow_call_depth = workflow_call_depth - - # TODO: May need to check if key exists. 
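# Note: the replacement BaseNode above already resolves this TODO; it reads
# config.get("id") and raises ValueError("Node ID is required.") when the key
# is missing, instead of indexing config["id"] directly as the old code below does.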
- self.node_id = config["id"] - if not self.node_id: - raise ValueError("Node ID is required.") - - self.node_data = self._node_data_cls(**config.get("data", {})) - self.callbacks = callbacks or [] - - @abstractmethod - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - """ - Run node - :param variable_pool: variable pool - :return: - """ - raise NotImplementedError - - def run(self, variable_pool: VariablePool) -> NodeRunResult: - """ - Run node entry - :param variable_pool: variable pool - :return: - """ - try: - result = self._run( - variable_pool=variable_pool - ) - self.node_run_result = result - return result - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) - - def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None: - """ - Publish text chunk - :param text: chunk text - :param value_selector: value selector - :return: - """ - if self.callbacks: - for callback in self.callbacks: - callback.on_node_text_chunk( - node_id=self.node_id, - text=text, - metadata={ - "node_type": self.node_type, - "is_answer_previous_node": self.is_answer_previous_node, - "value_selector": value_selector - } - ) - - @classmethod - def extract_variable_selector_to_variable_mapping(cls, config: dict): - """ - Extract variable selector to variable mapping - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - return cls._extract_variable_selector_to_variable_mapping(node_data) - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param node_data: node data - :return: - """ - return {} - - @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ - return {} - - @property - def node_type(self) -> NodeType: - """ - Get node type - :return: - """ - return self._node_type - -class BaseIterationNode(BaseNode): - @abstractmethod - def _run(self, variable_pool: VariablePool) -> BaseIterationState: - """ - Run node - :param variable_pool: variable pool - :return: - """ - raise NotImplementedError - - def run(self, variable_pool: VariablePool) -> BaseIterationState: - """ - Run node entry - :param variable_pool: variable pool - :return: - """ - return self._run(variable_pool=variable_pool) - - def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - return self._get_next_iteration(variable_pool, state) - - @abstractmethod - def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. 
- :param graph: graph - :return: next node id - """ - raise NotImplementedError diff --git a/api/core/workflow/nodes/code/__init__.py b/api/core/workflow/nodes/code/__init__.py index e69de29bb2d1d6..8c6dcc7fccbf2e 100644 --- a/api/core/workflow/nodes/code/__init__.py +++ b/api/core/workflow/nodes/code/__init__.py @@ -0,0 +1,3 @@ +from .code_node import CodeNode + +__all__ = ["CodeNode"] diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 60678bc2ba8c16..ce283e38ec9b12 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,29 +1,27 @@ -from typing import Optional, Union, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union from configs import dify_config -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus -MAX_NUMBER = dify_config.CODE_MAX_NUMBER -MIN_NUMBER = dify_config.CODE_MIN_NUMBER -MAX_PRECISION = 20 -MAX_DEPTH = 5 -MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH -MAX_STRING_ARRAY_LENGTH = dify_config.CODE_MAX_STRING_ARRAY_LENGTH -MAX_OBJECT_ARRAY_LENGTH = dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH -MAX_NUMBER_ARRAY_LENGTH = dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH +from .exc import ( + CodeNodeError, + DepthLimitError, + OutputValidationError, +) -class CodeNode(BaseNode): +class CodeNode(BaseNode[CodeNodeData]): _node_data_cls = CodeNodeData - node_type = NodeType.CODE + _node_type = NodeType.CODE @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: @@ -34,57 +32,38 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ code_language = CodeLanguage.PYTHON3 if filters: - code_language = (filters.get("code_language", CodeLanguage.PYTHON3)) + code_language = filters.get("code_language", CodeLanguage.PYTHON3) providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] - code_provider: type[CodeNodeProvider] = next(p for p in providers - if p.is_accept_language(code_language)) + code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language)) return code_provider.get_default_config() - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - """ - Run code - :param variable_pool: variable pool - :return: - """ - node_data = self.node_data - node_data: CodeNodeData = cast(self._node_data_cls, node_data) - + def _run(self) -> NodeRunResult: # Get code language - code_language = node_data.code_language - code = node_data.code + code_language = self.node_data.code_language + code = self.node_data.code # Get variables variables = {} - for variable_selector in 
node_data.variables: - variable = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) - - variables[variable] = value + for variable_selector in self.node_data.variables: + variable_name = variable_selector.variable + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + variables[variable_name] = variable.to_object() if variable else None # Run code try: result = CodeExecutor.execute_workflow_code_template( language=code_language, code=code, inputs=variables, - dependencies=node_data.dependencies ) # Transform result - result = self._transform_result(result, node_data.outputs) - except (CodeExecutionException, ValueError) as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) + result = self._transform_result(result, self.node_data.outputs) + except (CodeExecutionError, CodeNodeError) as e: + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs=result - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) def _check_string(self, value: str, variable: str) -> str: """ @@ -94,15 +73,18 @@ def _check_string(self, value: str, variable: str) -> str: :return: """ if not isinstance(value, str): - if isinstance(value, type(None)): + if value is None: return None else: - raise ValueError(f"Output variable `{variable}` must be a string") - - if len(value) > MAX_STRING_LENGTH: - raise ValueError(f'The length of output variable `{variable}` must be less than {MAX_STRING_LENGTH} characters') + raise OutputValidationError(f"Output variable `{variable}` must be a string") - return value.replace('\x00', '') + if len(value) > dify_config.CODE_MAX_STRING_LENGTH: + raise OutputValidationError( + f"The length of output variable `{variable}` must be" + f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" + ) + + return value.replace("\x00", "") def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: """ @@ -112,32 +94,38 @@ def _check_number(self, value: Union[int, float], variable: str) -> Union[int, f :return: """ if not isinstance(value, int | float): - if isinstance(value, type(None)): + if value is None: return None else: - raise ValueError(f"Output variable `{variable}` must be a number") + raise OutputValidationError(f"Output variable `{variable}` must be a number") - if value > MAX_NUMBER or value < MIN_NUMBER: - raise ValueError(f'Output variable `{variable}` is out of range, it must be between {MIN_NUMBER} and {MAX_NUMBER}.') + if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: + raise OutputValidationError( + f"Output variable `{variable}` is out of range," + f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." + ) if isinstance(value, float): # raise error if precision is too high - if len(str(value).split('.')[1]) > MAX_PRECISION: - raise ValueError(f'Output variable `{variable}` has too high precision, it must be less than {MAX_PRECISION} digits.') + if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: + raise OutputValidationError( + f"Output variable `{variable}` has too high precision," + f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." 
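# Illustration (editorial comment, hypothetical value): assuming the default
# CODE_MAX_PRECISION of 20 (matching the old MAX_PRECISION constant above),
# 3.14159 passes, while a float whose decimal part stringifies to more than
# 20 digits trips the len(str(value).split(".")[1]) check and raises here.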
+ ) return value - def _transform_result(self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], - prefix: str = '', - depth: int = 1) -> dict: + def _transform_result( + self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 + ) -> dict: """ Transform result :param result: result :param output_schema: output schema :return: """ - if depth > MAX_DEPTH: - raise ValueError("Depth limit reached, object too deep.") + if depth > dify_config.CODE_MAX_DEPTH: + raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") transformed_result = {} if output_schema is None: @@ -147,182 +135,200 @@ def _transform_result(self, result: dict, output_schema: Optional[dict[str, Code self._transform_result( result=output_value, output_schema=None, - prefix=f'{prefix}.{output_name}' if prefix else output_name, - depth=depth + 1 + prefix=f"{prefix}.{output_name}" if prefix else output_name, + depth=depth + 1, ) elif isinstance(output_value, int | float): self._check_number( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, str): self._check_string( - value=output_value, - variable=f'{prefix}.{output_name}' if prefix else output_name + value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name ) elif isinstance(output_value, list): first_element = output_value[0] if len(output_value) > 0 else None if first_element is not None: - if isinstance(first_element, int | float) and all(value is None or isinstance(value, int | float) for value in output_value): + if isinstance(first_element, int | float) and all( + value is None or isinstance(value, int | float) for value in output_value + ): for i, value in enumerate(output_value): self._check_number( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, str) and all(value is None or isinstance(value, str) for value in output_value): + elif isinstance(first_element, str) and all( + value is None or isinstance(value, str) for value in output_value + ): for i, value in enumerate(output_value): self._check_string( value=value, - variable=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]' + variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, dict) and all(value is None or isinstance(value, dict) for value in output_value): + elif isinstance(first_element, dict) and all( + value is None or isinstance(value, dict) for value in output_value + ): for i, value in enumerate(output_value): if value is not None: self._transform_result( result=value, output_schema=None, - prefix=f'{prefix}.{output_name}[{i}]' if prefix else f'{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", + depth=depth + 1, ) else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid array. make sure all elements are of the same type.') - elif isinstance(output_value, type(None)): + raise OutputValidationError( + f"Output {prefix}.{output_name} is not a valid array." + f" Make sure all elements are of the same type."
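# Illustration of the homogeneity rule checked above: [1, "a"] mixes element
# types, fails every all(...) branch, and lands in this raise, while
# [1, None, 2] is accepted because None entries may sit alongside a single
# element type.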
+ ) + elif output_value is None: pass else: - raise ValueError(f'Output {prefix}.{output_name} is not a valid type.') - + raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.") + return result parameters_validated = {} for output_name, output_config in output_schema.items(): - dot = '.' if prefix else '' + dot = "." if prefix else "" if output_name not in result: - raise ValueError(f'Output {prefix}{dot}{output_name} is missing.') - - if output_config.type == 'object': + raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.") + + if output_config.type == "object": # check if output is object if not isinstance(result.get(output_name), dict): if isinstance(result.get(output_name), type(None)): transformed_result[output_name] = None else: - raise ValueError( - f'Output {prefix}{dot}{output_name} is not an object, got {type(result.get(output_name))} instead.' + raise OutputValidationError( + f"Output {prefix}{dot}{output_name} is not an object," + f" got {type(result.get(output_name))} instead." ) else: transformed_result[output_name] = self._transform_result( result=result[output_name], output_schema=output_config.children, - prefix=f'{prefix}.{output_name}', - depth=depth + 1 + prefix=f"{prefix}.{output_name}", + depth=depth + 1, ) - elif output_config.type == 'number': + elif output_config.type == "number": # check if number available transformed_result[output_name] = self._check_number( - value=result[output_name], - variable=f'{prefix}{dot}{output_name}' + value=result[output_name], variable=f"{prefix}{dot}{output_name}" ) - elif output_config.type == 'string': + elif output_config.type == "string": # check if string available transformed_result[output_name] = self._check_string( value=result[output_name], - variable=f'{prefix}{dot}{output_name}', + variable=f"{prefix}{dot}{output_name}", ) - elif output_config.type == 'array[number]': + elif output_config.type == "array[number]": # check if array of number available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + raise OutputValidationError( + f"Output {prefix}{dot}{output_name} is not an array," + f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > MAX_NUMBER_ARRAY_LENGTH: - raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_NUMBER_ARRAY_LENGTH} elements.' + if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: + raise OutputValidationError( + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_number( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[string]': + elif output_config.type == "array[string]": # check if array of string available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' 
+ raise OutputValidationError( + f"Output {prefix}{dot}{output_name} is not an array," + f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > MAX_STRING_ARRAY_LENGTH: - raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_STRING_ARRAY_LENGTH} elements.' + if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: + raise OutputValidationError( + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." ) transformed_result[output_name] = [ - self._check_string( - value=value, - variable=f'{prefix}{dot}{output_name}[{i}]' - ) + self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") for i, value in enumerate(result[output_name]) ] - elif output_config.type == 'array[object]': + elif output_config.type == "array[object]": # check if array of object available if not isinstance(result[output_name], list): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( - f'Output {prefix}{dot}{output_name} is not an array, got {type(result.get(output_name))} instead.' + raise OutputValidationError( + f"Output {prefix}{dot}{output_name} is not an array," + f" got {type(result.get(output_name))} instead." ) else: - if len(result[output_name]) > MAX_OBJECT_ARRAY_LENGTH: - raise ValueError( - f'The length of output variable `{prefix}{dot}{output_name}` must be less than {MAX_OBJECT_ARRAY_LENGTH} elements.' + if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: + raise OutputValidationError( + f"The length of output variable `{prefix}{dot}{output_name}` must be" + f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." ) - + for i, value in enumerate(result[output_name]): if not isinstance(value, dict): - if isinstance(value, type(None)): + if value is None: pass else: - raise ValueError( - f'Output {prefix}{dot}{output_name}[{i}] is not an object, got {type(value)} instead at index {i}.' + raise OutputValidationError( + f"Output {prefix}{dot}{output_name}[{i}] is not an object," + f" got {type(value)} instead at index {i}." 
) transformed_result[output_name] = [ - None if value is None else self._transform_result( + None + if value is None + else self._transform_result( result=value, output_schema=output_config.children, - prefix=f'{prefix}{dot}{output_name}[{i}]', - depth=depth + 1 + prefix=f"{prefix}{dot}{output_name}[{i}]", + depth=depth + 1, ) for i, value in enumerate(result[output_name]) ] else: - raise ValueError(f'Output type {output_config.type} is not supported.') - + raise OutputValidationError(f"Output type {output_config.type} is not supported.") + parameters_validated[output_name] = True # check if all output parameters are validated if len(parameters_validated) != len(result): - raise ValueError('Not all output parameters are validated.') + raise CodeNodeError("Not all output parameters are validated.") return transformed_result @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: CodeNodeData, + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables + node_id + "." + variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/code/entities.py b/api/core/workflow/nodes/code/entities.py index 83a5416d57ce6a..e78183baf12389 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/core/workflow/nodes/code/entities.py @@ -3,21 +3,25 @@ from pydantic import BaseModel from core.helper.code_executor.code_executor import CodeLanguage -from core.helper.code_executor.entities import CodeDependency -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class CodeNodeData(BaseNodeData): """ Code Node Data. 
""" + class Output(BaseModel): - type: Literal['string', 'number', 'object', 'array[string]', 'array[number]', 'array[object]'] - children: Optional[dict[str, 'Output']] = None + type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + children: Optional[dict[str, "Output"]] = None + + class Dependency(BaseModel): + name: str + version: str variables: list[VariableSelector] code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] code: str outputs: dict[str, Output] - dependencies: Optional[list[CodeDependency]] = None \ No newline at end of file + dependencies: Optional[list[Dependency]] = None diff --git a/api/core/workflow/nodes/code/exc.py b/api/core/workflow/nodes/code/exc.py new file mode 100644 index 00000000000000..d6334fd554cde5 --- /dev/null +++ b/api/core/workflow/nodes/code/exc.py @@ -0,0 +1,16 @@ +class CodeNodeError(ValueError): + """Base class for code node errors.""" + + pass + + +class OutputValidationError(CodeNodeError): + """Raised when there is an output validation error.""" + + pass + + +class DepthLimitError(CodeNodeError): + """Raised when the depth limit is reached.""" + + pass diff --git a/api/core/workflow/nodes/document_extractor/__init__.py b/api/core/workflow/nodes/document_extractor/__init__.py new file mode 100644 index 00000000000000..3cc5fae18745f9 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/__init__.py @@ -0,0 +1,4 @@ +from .entities import DocumentExtractorNodeData +from .node import DocumentExtractorNode + +__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData"] diff --git a/api/core/workflow/nodes/document_extractor/entities.py b/api/core/workflow/nodes/document_extractor/entities.py new file mode 100644 index 00000000000000..7e9ffaa889b988 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/entities.py @@ -0,0 +1,7 @@ +from collections.abc import Sequence + +from core.workflow.nodes.base import BaseNodeData + + +class DocumentExtractorNodeData(BaseNodeData): + variable_selector: Sequence[str] diff --git a/api/core/workflow/nodes/document_extractor/exc.py b/api/core/workflow/nodes/document_extractor/exc.py new file mode 100644 index 00000000000000..5caf00ebc5f1c6 --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/exc.py @@ -0,0 +1,14 @@ +class DocumentExtractorError(ValueError): + """Base exception for errors related to the DocumentExtractorNode.""" + + +class FileDownloadError(DocumentExtractorError): + """Exception raised when there's an error downloading a file.""" + + +class UnsupportedFileTypeError(DocumentExtractorError): + """Exception raised when trying to extract text from an unsupported file type.""" + + +class TextExtractionError(DocumentExtractorError): + """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py new file mode 100644 index 00000000000000..c90017d5e15cec --- /dev/null +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -0,0 +1,303 @@ +import csv +import io +import json + +import docx +import pandas as pd +import pypdfium2 +import yaml +from unstructured.partition.api import partition_via_api +from unstructured.partition.email import partition_email +from unstructured.partition.epub import partition_epub +from unstructured.partition.msg import partition_msg +from unstructured.partition.ppt import partition_ppt +from unstructured.partition.pptx import partition_pptx + +from 
configs import dify_config +from core.file import File, FileTransferMethod, file_manager +from core.helper import ssrf_proxy +from core.variables import ArrayFileSegment +from core.variables.segments import FileSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import DocumentExtractorNodeData +from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError + + +class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): + """ + Extracts text content from various file types. + Supports plain text, PDF, DOC/DOCX, spreadsheet, presentation, email, EPUB, JSON, and YAML files. + """ + + _node_data_cls = DocumentExtractorNodeData + _node_type = NodeType.DOCUMENT_EXTRACTOR + + def _run(self): + variable_selector = self.node_data.variable_selector + variable = self.graph_runtime_state.variable_pool.get(variable_selector) + + if variable is None: + error_message = f"File variable not found for selector: {variable_selector}" + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) + if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): + error_message = f"Variable {variable_selector} is not a FileSegment or ArrayFileSegment" + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) + + value = variable.value + inputs = {"variable_selector": variable_selector} + process_data = {"documents": value if isinstance(value, list) else [value]} + + try: + if isinstance(value, list): + extracted_text_list = list(map(_extract_text_from_file, value)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": extracted_text_list}, + ) + elif isinstance(value, File): + extracted_text = _extract_text_from_file(value) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": extracted_text}, + ) + else: + raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") + except DocumentExtractorError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=inputs, + process_data=process_data, + ) + + +def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: + """Extract text from a file based on its MIME type.""" + match mime_type: + case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": + return _extract_text_from_plain_text(file_content) + case "application/pdf": + return _extract_text_from_pdf(file_content) + case "application/vnd.openxmlformats-officedocument.wordprocessingml.document" | "application/msword": + return _extract_text_from_doc(file_content) + case "text/csv": + return _extract_text_from_csv(file_content) + case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": + return _extract_text_from_excel(file_content) + case "application/vnd.ms-powerpoint": + return _extract_text_from_ppt(file_content) + case "application/vnd.openxmlformats-officedocument.presentationml.presentation": + return _extract_text_from_pptx(file_content) + case "application/epub+zip": + return _extract_text_from_epub(file_content) + case "message/rfc822": + return _extract_text_from_eml(file_content) + case "application/vnd.ms-outlook": + return _extract_text_from_msg(file_content)
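# For illustration: a .docx upload typically carries the MIME type
# "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
# and is routed to _extract_text_from_doc above; any MIME type without a case
# arm falls through to the UnsupportedFileTypeError raised at the end of this
# match.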
+ case "application/json": + return _extract_text_from_json(file_content) + case "application/x-yaml" | "text/yaml": + return _extract_text_from_yaml(file_content) + case _: + raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") + + +def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: + """Extract text from a file based on its file extension.""" + match file_extension: + case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml": + return _extract_text_from_plain_text(file_content) + case ".json": + return _extract_text_from_json(file_content) + case ".yaml" | ".yml": + return _extract_text_from_yaml(file_content) + case ".pdf": + return _extract_text_from_pdf(file_content) + case ".doc" | ".docx": + return _extract_text_from_doc(file_content) + case ".csv": + return _extract_text_from_csv(file_content) + case ".xls" | ".xlsx": + return _extract_text_from_excel(file_content) + case ".ppt": + return _extract_text_from_ppt(file_content) + case ".pptx": + return _extract_text_from_pptx(file_content) + case ".epub": + return _extract_text_from_epub(file_content) + case ".eml": + return _extract_text_from_eml(file_content) + case ".msg": + return _extract_text_from_msg(file_content) + case _: + raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") + + +def _extract_text_from_plain_text(file_content: bytes) -> str: + try: + return file_content.decode("utf-8") + except UnicodeDecodeError as e: + raise TextExtractionError("Failed to decode plain text file") from e + + +def _extract_text_from_json(file_content: bytes) -> str: + try: + json_data = json.loads(file_content.decode("utf-8")) + return json.dumps(json_data, indent=2, ensure_ascii=False) + except (UnicodeDecodeError, json.JSONDecodeError) as e: + raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e + + +def _extract_text_from_yaml(file_content: bytes) -> str: + """Extract the content from yaml file""" + try: + yaml_data = yaml.safe_load_all(file_content.decode("utf-8")) + return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) + except (UnicodeDecodeError, yaml.YAMLError) as e: + raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e + + +def _extract_text_from_pdf(file_content: bytes) -> str: + try: + pdf_file = io.BytesIO(file_content) + pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) + text = "" + for page in pdf_document: + text_page = page.get_textpage() + text += text_page.get_text_range() + text_page.close() + page.close() + return text + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e + + +def _extract_text_from_doc(file_content: bytes) -> str: + try: + doc_file = io.BytesIO(file_content) + doc = docx.Document(doc_file) + return "\n".join([paragraph.text for paragraph in doc.paragraphs]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e + + +def _download_file_content(file: File) -> bytes: + """Download the content of a file based on its transfer method.""" + try: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + if file.remote_url is None: + raise FileDownloadError("Missing URL for remote file") + response = ssrf_proxy.get(file.remote_url) + response.raise_for_status() + return response.content + else: + return file_manager.download(file) + except Exception as e: + raise FileDownloadError(f"Error downloading file: {str(e)}") from e + + +def 
_extract_text_from_file(file: File): + file_content = _download_file_content(file) + if file.extension: + extracted_text = _extract_text_by_file_extension(file_content=file_content, file_extension=file.extension) + elif file.mime_type: + extracted_text = _extract_text_by_mime_type(file_content=file_content, mime_type=file.mime_type) + else: + raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") + return extracted_text + + +def _extract_text_from_csv(file_content: bytes) -> str: + try: + csv_file = io.StringIO(file_content.decode("utf-8")) + csv_reader = csv.reader(csv_file) + rows = list(csv_reader) + + if not rows: + return "" + + # Create Markdown table + markdown_table = "| " + " | ".join(rows[0]) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n" + for row in rows[1:]: + markdown_table += "| " + " | ".join(row) + " |\n" + + return markdown_table.strip() + except Exception as e: + raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e + + +def _extract_text_from_excel(file_content: bytes) -> str: + """Extract text from an Excel file using pandas.""" + + try: + df = pd.read_excel(io.BytesIO(file_content)) + + # Drop rows where all elements are NaN + df.dropna(how="all", inplace=True) + + # Convert DataFrame to Markdown table + markdown_table = df.to_markdown(index=False) + return markdown_table + except Exception as e: + raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e + + +def _extract_text_from_ppt(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_ppt(file=file) + return "\n".join([getattr(element, "text", "") for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e + + +def _extract_text_from_pptx(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY: + elements = partition_via_api( + file=file, + api_url=dify_config.UNSTRUCTURED_API_URL, + api_key=dify_config.UNSTRUCTURED_API_KEY, + ) + else: + elements = partition_pptx(file=file) + return "\n".join([getattr(element, "text", "") for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e + + +def _extract_text_from_epub(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_epub(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e + + +def _extract_text_from_eml(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_email(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e + + +def _extract_text_from_msg(file_content: bytes) -> str: + try: + with io.BytesIO(file_content) as file: + elements = partition_msg(file=file) + return "\n".join([str(element) for element in elements]) + except Exception as e: + raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e diff --git a/api/core/workflow/nodes/end/__init__.py b/api/core/workflow/nodes/end/__init__.py index e69de29bb2d1d6..adb381701cecb1 100644 --- a/api/core/workflow/nodes/end/__init__.py +++ 
b/api/core/workflow/nodes/end/__init__.py @@ -0,0 +1,4 @@ +from .end_node import EndNode +from .entities import EndStreamParam + +__all__ = ["EndStreamParam", "EndNode"] diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 440dfa2f2710f3..2398e4e89d59fa 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,85 +1,48 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.entities import EndNodeData +from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus -class EndNode(BaseNode): +class EndNode(BaseNode[EndNodeData]): _node_data_cls = EndNodeData _node_type = NodeType.END - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ - node_data = self.node_data - node_data = cast(EndNodeData, node_data) - output_variables = node_data.outputs + output_variables = self.node_data.outputs outputs = {} for variable_selector in output_variables: - value = variable_pool.get_any(variable_selector.value_selector) + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + value = variable.to_object() if variable is not None else None outputs[variable_selector.variable] = value return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=outputs, - outputs=outputs + outputs=outputs, ) @classmethod - def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]: - """ - Extract generate nodes - :param graph: graph - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(EndNodeData, node_data) - - return cls.extract_generate_nodes_from_node_data(graph, node_data) - - @classmethod - def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]: - """ - Extract generate nodes from node data - :param graph: graph - :param node_data: node data object - :return: - """ - nodes = graph.get('nodes', []) - node_mapping = {node.get('id'): node for node in nodes} - - variable_selectors = node_data.outputs - - generate_nodes = [] - for variable_selector in variable_selectors: - if not variable_selector.value_selector: - continue - - node_id = variable_selector.value_selector[0] - if node_id != 'sys' and node_id in node_mapping: - node = node_mapping[node_id] - node_type = node.get('data', {}).get('type') - if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text': - generate_nodes.append(node_id) - - # remove duplicates - generate_nodes = list(set(generate_nodes)) - - return generate_nodes - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: EndNodeData, + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param 
graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py new file mode 100644 index 00000000000000..ea8b6b50420c99 --- /dev/null +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -0,0 +1,151 @@ +from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam +from core.workflow.nodes.enums import NodeType + + +class EndStreamGeneratorRouter: + @classmethod + def init( + cls, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_parallel_mapping: dict[str, str], + ) -> EndStreamParam: + """ + Get stream generate routes. + :return: + """ + # parse stream output node value selector of end nodes + end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {} + for end_node_id, node_config in node_id_config_mapping.items(): + if node_config.get("data", {}).get("type") != NodeType.END.value: + continue + + # skip end node in parallel + if end_node_id in node_parallel_mapping: + continue + + # get generate route for stream output + stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config) + end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors + + # fetch end dependencies + end_node_ids = list(end_stream_variable_selectors_mapping.keys()) + end_dependencies = cls._fetch_ends_dependencies( + end_node_ids=end_node_ids, + reverse_edge_mapping=reverse_edge_mapping, + node_id_config_mapping=node_id_config_mapping, + ) + + return EndStreamParam( + end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping, + end_dependencies=end_dependencies, + ) + + @classmethod + def extract_stream_variable_selector_from_node_data( + cls, node_id_config_mapping: dict[str, dict], node_data: EndNodeData + ) -> list[list[str]]: + """ + Extract stream variable selector from node data + :param node_id_config_mapping: node id config mapping + :param node_data: node data object + :return: + """ + variable_selectors = node_data.outputs + + value_selectors = [] + for variable_selector in variable_selectors: + if not variable_selector.value_selector: + continue + + node_id = variable_selector.value_selector[0] + if node_id != "sys" and node_id in node_id_config_mapping: + node = node_id_config_mapping[node_id] + node_type = node.get("data", {}).get("type") + if ( + variable_selector.value_selector not in value_selectors + and node_type == NodeType.LLM.value + and variable_selector.value_selector[1] == "text" + ): + value_selectors.append(variable_selector.value_selector) + + return value_selectors + + @classmethod + def _extract_stream_variable_selector( + cls, node_id_config_mapping: dict[str, dict], config: dict + ) -> list[list[str]]: + """ + Extract stream variable selector from node config + :param node_id_config_mapping: node id config mapping + :param config: node config + :return: + """ + node_data = EndNodeData(**config.get("data", {})) + return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data) + + @classmethod + def _fetch_ends_dependencies( + cls, + end_node_ids: list[str], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + node_id_config_mapping: dict[str, dict], + ) -> dict[str, list[str]]: + """ + Fetch end dependencies + :param end_node_ids: end node ids + :param 
reverse_edge_mapping: reverse edge mapping + :param node_id_config_mapping: node id config mapping + :return: + """ + end_dependencies: dict[str, list[str]] = {} + for end_node_id in end_node_ids: + if end_dependencies.get(end_node_id) is None: + end_dependencies[end_node_id] = [] + + cls._recursive_fetch_end_dependencies( + current_node_id=end_node_id, + end_node_id=end_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + end_dependencies=end_dependencies, + ) + + return end_dependencies + + @classmethod + def _recursive_fetch_end_dependencies( + cls, + current_node_id: str, + end_node_id: str, + node_id_config_mapping: dict[str, dict], + reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] + end_dependencies: dict[str, list[str]], + ) -> None: + """ + Recursively fetch end dependencies + :param current_node_id: current node id + :param end_node_id: end node id + :param node_id_config_mapping: node id config mapping + :param reverse_edge_mapping: reverse edge mapping + :param end_dependencies: end dependencies + :return: + """ + reverse_edges = reverse_edge_mapping.get(current_node_id, []) + for edge in reverse_edges: + source_node_id = edge.source_node_id + source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") + if source_node_type in { + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER.value, + }: + end_dependencies[end_node_id].append(source_node_id) + else: + cls._recursive_fetch_end_dependencies( + current_node_id=source_node_id, + end_node_id=end_node_id, + node_id_config_mapping=node_id_config_mapping, + reverse_edge_mapping=reverse_edge_mapping, + end_dependencies=end_dependencies, + ) diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py new file mode 100644 index 00000000000000..1aecf863ac5fb9 --- /dev/null +++ b/api/core/workflow/nodes/end/end_stream_processor.py @@ -0,0 +1,187 @@ +import logging +from collections.abc import Generator + +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + GraphEngineEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.answer.base_stream_processor import StreamProcessor + +logger = logging.getLogger(__name__) + + +class EndStreamProcessor(StreamProcessor): + def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: + super().__init__(graph, variable_pool) + self.end_stream_param = graph.end_stream_param + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 + self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} + self.has_output = False + self.output_node_ids = set() + + def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generator[GraphEngineEvent, None, None]: + for event in generator: + if isinstance(event, NodeRunStartedEvent): + if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids: + self.reset() + + yield event + elif isinstance(event, NodeRunStreamChunkEvent): + if event.in_iteration_id: + if self.has_output and event.node_id not in self.output_node_ids: + event.chunk_content = "\n" + event.chunk_content + + self.output_node_ids.add(event.node_id) + self.has_output = True + yield event + 
continue + + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[ + event.route_node_state.node_id + ] + else: + stream_out_end_node_ids = self._get_stream_out_end_node_ids(event) + self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] = ( + stream_out_end_node_ids + ) + + if stream_out_end_node_ids: + if self.has_output and event.node_id not in self.output_node_ids: + event.chunk_content = "\n" + event.chunk_content + + self.output_node_ids.add(event.node_id) + self.has_output = True + yield event + elif isinstance(event, NodeRunSucceededEvent): + yield event + if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: + # update self.route_position after all stream event finished + for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]: + self.route_position[end_node_id] += 1 + + del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id] + + # remove unreachable nodes + self._remove_unreachable_nodes(event) + + # generate stream outputs + yield from self._generate_stream_outputs_when_node_finished(event) + else: + yield event + + def reset(self) -> None: + self.route_position = {} + for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items(): + self.route_position[end_node_id] = 0 + self.rest_node_ids = self.graph.node_ids.copy() + self.current_stream_chunk_generating_node_ids = {} + + def _generate_stream_outputs_when_node_finished( + self, event: NodeRunSucceededEvent + ) -> Generator[GraphEngineEvent, None, None]: + """ + Generate stream outputs. + :param event: node run succeeded event + :return: + """ + for end_node_id, position in self.route_position.items(): + # all depends on end node id not in rest node ids + if event.route_node_state.node_id != end_node_id and ( + end_node_id not in self.rest_node_ids + or not all( + dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id] + ) + ): + continue + + route_position = self.route_position[end_node_id] + + position = 0 + value_selectors = [] + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position >= route_position: + value_selectors.append(current_value_selectors) + + position += 1 + + for value_selector in value_selectors: + if not value_selector: + continue + + value = self.variable_pool.get(value_selector) + + if value is None: + break + + text = value.markdown + + if text: + current_node_id = value_selector[0] + if self.has_output and current_node_id not in self.output_node_ids: + text = "\n" + text + + self.output_node_ids.add(current_node_id) + self.has_output = True + yield NodeRunStreamChunkEvent( + id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + chunk_content=text, + from_variable_selector=value_selector, + route_node_state=event.route_node_state, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ) + + self.route_position[end_node_id] += 1 + + def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]: + """ + Is stream out support + :param event: queue text chunk event + :return: + """ + if not event.from_variable_selector: + return [] + + stream_output_value_selector = event.from_variable_selector + if not stream_output_value_selector: + return [] + + 
stream_out_end_node_ids = [] + for end_node_id, route_position in self.route_position.items(): + if end_node_id not in self.rest_node_ids: + continue + + # all depends on end node id not in rest node ids + if all(dep_id not in self.rest_node_ids for dep_id in self.end_stream_param.end_dependencies[end_node_id]): + if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]): + continue + + position = 0 + value_selector = None + for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]: + if position == route_position: + value_selector = current_value_selectors + break + + position += 1 + + if not value_selector: + continue + + # check chunk node id is before current node id or equal to current node id + if value_selector != stream_output_value_selector: + continue + + stream_out_end_node_ids.append(end_node_id) + + return stream_out_end_node_ids diff --git a/api/core/workflow/nodes/end/entities.py b/api/core/workflow/nodes/end/entities.py index ad4fc8f04fd43c..c16e85b0eb2a86 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/core/workflow/nodes/end/entities.py @@ -1,9 +1,25 @@ -from core.workflow.entities.base_node_data_entities import BaseNodeData +from pydantic import BaseModel, Field + from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class EndNodeData(BaseNodeData): """ END Node Data. """ + outputs: list[VariableSelector] + + +class EndStreamParam(BaseModel): + """ + EndStreamParam entity + """ + + end_dependencies: dict[str, list[str]] = Field( + ..., description="end dependencies (end node id -> dependent node ids)" + ) + end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( + ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" + ) diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py new file mode 100644 index 00000000000000..208144655b5a59 --- /dev/null +++ b/api/core/workflow/nodes/enums.py @@ -0,0 +1,24 @@ +from enum import Enum + + +class NodeType(str, Enum): + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TEMPLATE_TRANSFORM = "template-transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + VARIABLE_AGGREGATOR = "variable-aggregator" + VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. + LOOP = "loop" + ITERATION = "iteration" + ITERATION_START = "iteration-start" # Fake start node for iteration. 
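# Because NodeType mixes in str, members compare equal to their raw values
# (NodeType.END == "end" is True), which is what lets graph configs store
# plain strings in node_config["data"]["type"] and match them against this
# enum elsewhere in the patch.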
+ PARAMETER_EXTRACTOR = "parameter-extractor" + CONVERSATION_VARIABLE_ASSIGNER = "assigner" + DOCUMENT_EXTRACTOR = "document-extractor" + LIST_OPERATOR = "list-operator" diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py new file mode 100644 index 00000000000000..581def95533544 --- /dev/null +++ b/api/core/workflow/nodes/event/__init__.py @@ -0,0 +1,10 @@ +from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from .types import NodeEvent + +__all__ = [ + "RunCompletedEvent", + "RunRetrieverResourceEvent", + "RunStreamChunkEvent", + "NodeEvent", + "ModelInvokeCompletedEvent", +] diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py new file mode 100644 index 00000000000000..b7034561bf6713 --- /dev/null +++ b/api/core/workflow/nodes/event/event.py @@ -0,0 +1,28 @@ +from pydantic import BaseModel, Field + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.node_entities import NodeRunResult + + +class RunCompletedEvent(BaseModel): + run_result: NodeRunResult = Field(..., description="run result") + + +class RunStreamChunkEvent(BaseModel): + chunk_content: str = Field(..., description="chunk content") + from_variable_selector: list[str] = Field(..., description="from variable selector") + + +class RunRetrieverResourceEvent(BaseModel): + retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + + +class ModelInvokeCompletedEvent(BaseModel): + """ + Model invoke completed + """ + + text: str + usage: LLMUsage + finish_reason: str | None = None diff --git a/api/core/workflow/nodes/event/types.py b/api/core/workflow/nodes/event/types.py new file mode 100644 index 00000000000000..b19a91022df2e1 --- /dev/null +++ b/api/core/workflow/nodes/event/types.py @@ -0,0 +1,3 @@ +from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent + +NodeEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent | ModelInvokeCompletedEvent diff --git a/api/core/workflow/nodes/http_request/__init__.py b/api/core/workflow/nodes/http_request/__init__.py index e69de29bb2d1d6..9408c2dde0c0e9 100644 --- a/api/core/workflow/nodes/http_request/__init__.py +++ b/api/core/workflow/nodes/http_request/__init__.py @@ -0,0 +1,4 @@ +from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData +from .node import HttpRequestNode + +__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"] diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 90d644e0e22d3c..36ded104c16a66 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,49 +1,72 @@ -from typing import Literal, Optional, Union +from collections.abc import Sequence +from typing import Any, Literal, Optional -from pydantic import BaseModel, ValidationInfo, field_validator +import httpx +from pydantic import BaseModel, Field, ValidationInfo, field_validator from configs import dify_config -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData -MAX_CONNECT_TIMEOUT = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT -MAX_READ_TIMEOUT = 
dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT -MAX_WRITE_TIMEOUT = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT +NON_FILE_CONTENT_TYPES = ( + "application/json", + "application/xml", + "text/html", + "text/plain", + "application/x-www-form-urlencoded", +) class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal[None, 'basic', 'bearer', 'custom'] - api_key: Union[None, str] = None - header: Union[None, str] = None + type: Literal["basic", "bearer", "custom"] + api_key: str + header: str = "" class HttpRequestNodeAuthorization(BaseModel): - type: Literal['no-auth', 'api-key'] + type: Literal["no-auth", "api-key"] config: Optional[HttpRequestNodeAuthorizationConfig] = None - @field_validator('config', mode='before') + @field_validator("config", mode="before") @classmethod def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): """ Check config, if type is no-auth, config should be None, otherwise it should be a dict. """ - if values.data['type'] == 'no-auth': + if values.data["type"] == "no-auth": return None else: if not v or not isinstance(v, dict): - raise ValueError('config should be a dict') + raise ValueError("config should be a dict") return v +class BodyData(BaseModel): + key: str = "" + type: Literal["file", "text"] + value: str = "" + file: Sequence[str] = Field(default_factory=list) + + class HttpRequestNodeBody(BaseModel): - type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json'] - data: Union[None, str] = None + type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"] + data: Sequence[BodyData] = Field(default_factory=list) + + @field_validator("data", mode="before") + @classmethod + def check_data(cls, v: Any): + """For compatibility, if body is not set, return empty list.""" + if not v: + return [] + if isinstance(v, str): + return [BodyData(key="", type="text", value=v)] + return v class HttpRequestNodeTimeout(BaseModel): - connect: int = MAX_CONNECT_TIMEOUT - read: int = MAX_READ_TIMEOUT - write: int = MAX_WRITE_TIMEOUT + connect: int = dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT + read: int = dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT + write: int = dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT class HttpRequestNodeData(BaseNodeData): @@ -51,10 +74,58 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. 
""" - method: Literal['get', 'post', 'put', 'patch', 'delete', 'head'] + method: Literal["get", "post", "put", "patch", "delete", "head"] url: str authorization: HttpRequestNodeAuthorization headers: str params: str body: Optional[HttpRequestNodeBody] = None timeout: Optional[HttpRequestNodeTimeout] = None + + +class Response: + headers: dict[str, str] + response: httpx.Response + + def __init__(self, response: httpx.Response): + self.response = response + self.headers = dict(response.headers) + + @property + def is_file(self): + content_type = self.content_type + content_disposition = self.response.headers.get("content-disposition", "") + + return "attachment" in content_disposition or ( + not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES) + and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/")) + ) + + @property + def content_type(self) -> str: + return self.headers.get("content-type", "") + + @property + def text(self) -> str: + return self.response.text + + @property + def content(self) -> bytes: + return self.response.content + + @property + def status_code(self) -> int: + return self.response.status_code + + @property + def size(self) -> int: + return len(self.content) + + @property + def readable_size(self) -> str: + if self.size < 1024: + return f"{self.size} bytes" + elif self.size < 1024 * 1024: + return f"{(self.size / 1024):.2f} KB" + else: + return f"{(self.size / 1024 / 1024):.2f} MB" diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py new file mode 100644 index 00000000000000..7a5ab7dbc1c1fa --- /dev/null +++ b/api/core/workflow/nodes/http_request/exc.py @@ -0,0 +1,18 @@ +class HttpRequestNodeError(ValueError): + """Custom error for HTTP request node.""" + + +class AuthorizationConfigError(HttpRequestNodeError): + """Raised when authorization config is missing or invalid.""" + + +class FileFetchError(HttpRequestNodeError): + """Raised when a file cannot be fetched.""" + + +class InvalidHttpMethodError(HttpRequestNodeError): + """Raised when an invalid HTTP method is used.""" + + +class ResponseSizeError(HttpRequestNodeError): + """Raised when the response size exceeds the allowed threshold.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py new file mode 100644 index 00000000000000..80b322b068ec50 --- /dev/null +++ b/api/core/workflow/nodes/http_request/executor.py @@ -0,0 +1,329 @@ +import json +from collections.abc import Mapping +from copy import deepcopy +from random import randint +from typing import Any, Literal +from urllib.parse import urlencode, urlparse + +import httpx + +from configs import dify_config +from core.file import file_manager +from core.helper import ssrf_proxy +from core.workflow.entities.variable_pool import VariablePool + +from .entities import ( + HttpRequestNodeAuthorization, + HttpRequestNodeData, + HttpRequestNodeTimeout, + Response, +) +from .exc import ( + AuthorizationConfigError, + FileFetchError, + InvalidHttpMethodError, + ResponseSizeError, +) + +BODY_TYPE_TO_CONTENT_TYPE = { + "json": "application/json", + "x-www-form-urlencoded": "application/x-www-form-urlencoded", + "form-data": "multipart/form-data", + "raw-text": "text/plain", +} + + +class Executor: + method: Literal["get", "head", "post", "put", "delete", "patch"] + url: str + params: Mapping[str, str] | None + content: str | bytes | None + data: Mapping[str, Any] | None + files: Mapping[str, tuple[str 
| None, bytes, str]] | None + json: Any + headers: dict[str, str] + auth: HttpRequestNodeAuthorization + timeout: HttpRequestNodeTimeout + + boundary: str + + def __init__( + self, + *, + node_data: HttpRequestNodeData, + timeout: HttpRequestNodeTimeout, + variable_pool: VariablePool, + ): + # If authorization API key is present, convert the API key using the variable pool + if node_data.authorization.type == "api-key": + if node_data.authorization.config is None: + raise AuthorizationConfigError("authorization config is required") + node_data.authorization.config.api_key = variable_pool.convert_template( + node_data.authorization.config.api_key + ).text + + self.url: str = node_data.url + self.method = node_data.method + self.auth = node_data.authorization + self.timeout = timeout + self.params = {} + self.headers = {} + self.content = None + self.files = None + self.data = None + self.json = None + + # init template + self.variable_pool = variable_pool + self.node_data = node_data + self._initialize() + + def _initialize(self): + self._init_url() + self._init_params() + self._init_headers() + self._init_body() + + def _init_url(self): + self.url = self.variable_pool.convert_template(self.node_data.url).text + + def _init_params(self): + params = _plain_text_to_dict(self.node_data.params) + for key in params: + params[key] = self.variable_pool.convert_template(params[key]).text + self.params = params + + def _init_headers(self): + headers = self.variable_pool.convert_template(self.node_data.headers).text + self.headers = _plain_text_to_dict(headers) + + def _init_body(self): + body = self.node_data.body + if body is not None: + data = body.data + match body.type: + case "none": + self.content = "" + case "raw-text": + self.content = self.variable_pool.convert_template(data[0].value).text + case "json": + json_string = self.variable_pool.convert_template(data[0].value).text + json_object = json.loads(json_string) + self.json = json_object + # self.json = self._parse_object_contains_variables(json_object) + case "binary": + file_selector = data[0].file + file_variable = self.variable_pool.get_file(file_selector) + if file_variable is None: + raise FileFetchError(f"cannot fetch file with selector {file_selector}") + file = file_variable.value + self.content = file_manager.download(file) + case "x-www-form-urlencoded": + form_data = { + self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( + item.value + ).text + for item in data + } + self.data = form_data + case "form-data": + form_data = { + self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( + item.value + ).text + for item in filter(lambda item: item.type == "text", data) + } + file_selectors = { + self.variable_pool.convert_template(item.key).text: item.file + for item in filter(lambda item: item.type == "file", data) + } + files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()} + files = {k: v for k, v in files.items() if v is not None} + files = {k: variable.value for k, variable in files.items()} + files = { + k: (v.filename, file_manager.download(v), v.mime_type or "application/octet-stream") + for k, v in files.items() + if v.related_id is not None + } + self.data = form_data + self.files = files or None + + def _assembling_headers(self) -> dict[str, Any]: + authorization = deepcopy(self.auth) + headers = deepcopy(self.headers) or {} + if self.auth.type == "api-key": + if self.auth.config is None: + raise 
AuthorizationConfigError("self.authorization config is required") + if authorization.config is None: + raise AuthorizationConfigError("authorization config is required") + + if self.auth.config.api_key is None: + raise AuthorizationConfigError("api_key is required") + + if not authorization.config.header: + authorization.config.header = "Authorization" + + if self.auth.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif self.auth.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif self.auth.config.type == "custom": + headers[authorization.config.header] = authorization.config.api_key or "" + + return headers + + def _validate_and_parse_response(self, response: httpx.Response) -> Response: + executor_response = Response(response) + + threshold_size = ( + dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE + if executor_response.is_file + else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE + ) + if executor_response.size > threshold_size: + raise ResponseSizeError( + f'{"File" if executor_response.is_file else "Text"} size is too large,' + f' max size is {threshold_size / 1024 / 1024:.2f} MB,' + f' but current size is {executor_response.readable_size}.' + ) + + return executor_response + + def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: + """ + do http request depending on api bundle + """ + if self.method not in {"get", "head", "post", "put", "delete", "patch"}: + raise InvalidHttpMethodError(f"Invalid http method {self.method}") + + request_args = { + "url": self.url, + "data": self.data, + "files": self.files, + "json": self.json, + "content": self.content, + "headers": headers, + "params": self.params, + "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), + "follow_redirects": True, + } + # request_args = {k: v for k, v in request_args.items() if v is not None} + + response = getattr(ssrf_proxy, self.method)(**request_args) + return response + + def invoke(self) -> Response: + # assemble headers + headers = self._assembling_headers() + # do http request + response = self._do_http_request(headers) + # validate response + return self._validate_and_parse_response(response) + + def to_log(self): + url_parts = urlparse(self.url) + path = url_parts.path or "/" + + # Add query parameters + if self.params: + query_string = urlencode(self.params) + path += f"?{query_string}" + elif url_parts.query: + path += f"?{url_parts.query}" + + raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" + raw += f"Host: {url_parts.netloc}\r\n" + + headers = self._assembling_headers() + body = self.node_data.body + boundary = f"----WebKitFormBoundary{_generate_random_string(16)}" + if body: + if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: + headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] + if body.type == "form-data": + headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" + for k, v in headers.items(): + if self.auth.type == "api-key": + authorization_header = "Authorization" + if self.auth.config and self.auth.config.header: + authorization_header = self.auth.config.header + if k.lower() == authorization_header.lower(): + raw += f'{k}: {"*" * len(v)}\r\n' + continue + raw += f"{k}: {v}\r\n" + + body = "" + if self.files: + for k, v in self.files.items(): + body += f"--{boundary}\r\n" + body += f'Content-Disposition: form-data; name="{k}"\r\n\r\n' + body += f"{v[1]}\r\n" + 
body += f"--{boundary}--\r\n" + elif self.node_data.body: + if self.content: + if isinstance(self.content, str): + body = self.content + elif isinstance(self.content, bytes): + body = self.content.decode("utf-8", errors="replace") + elif self.data and self.node_data.body.type == "x-www-form-urlencoded": + body = urlencode(self.data) + elif self.data and self.node_data.body.type == "form-data": + for key, value in self.data.items(): + body += f"--{boundary}\r\n" + body += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' + body += f"{value}\r\n" + body += f"--{boundary}--\r\n" + elif self.json: + body = json.dumps(self.json) + elif self.node_data.body.type == "raw-text": + body = self.node_data.body.data[0].value + if body: + raw += f"Content-Length: {len(body)}\r\n" + raw += "\r\n" # Empty line between headers and body + raw += body + + return raw + + +def _plain_text_to_dict(text: str, /) -> dict[str, str]: + """ + Convert a string of key-value pairs to a dictionary. + + Each line in the input string represents a key-value pair. + Keys and values are separated by ':'. + Empty values are allowed. + + Examples: + 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} + 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} + 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} + + Args: + convert_text (str): The input string to convert. + + Returns: + dict[str, str]: A dictionary of key-value pairs. + """ + return { + key.strip(): (value[0].strip() if value else "") + for line in text.splitlines() + if line.strip() + for key, *value in [line.split(":", 1)] + } + + +def _generate_random_string(n: int) -> str: + """ + Generate a random string of lowercase ASCII letters. + + Args: + n (int): The length of the random string to generate. + + Returns: + str: A random string of lowercase ASCII letters with length n. 
+ + Example: + >>> _generate_random_string(5) + 'abcde' + """ + return "".join([chr(randint(97, 122)) for _ in range(n)]) diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py deleted file mode 100644 index db18bd00b2d9e1..00000000000000 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ /dev/null @@ -1,347 +0,0 @@ -import json -from copy import deepcopy -from random import randint -from typing import Any, Optional, Union -from urllib.parse import urlencode - -import httpx - -import core.helper.ssrf_proxy as ssrf_proxy -from configs import dify_config -from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.http_request.entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeData, - HttpRequestNodeTimeout, -) -from core.workflow.utils.variable_template_parser import VariableTemplateParser - -MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE -READABLE_MAX_BINARY_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE -MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE -READABLE_MAX_TEXT_SIZE = dify_config.HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE - - -class HttpExecutorResponse: - headers: dict[str, str] - response: httpx.Response - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) if isinstance(self.response, httpx.Response) else {} - - @property - def is_file(self) -> bool: - """ - check if response is file - """ - content_type = self.get_content_type() - file_content_types = ['image', 'audio', 'video'] - - return any(v in content_type for v in file_content_types) - - def get_content_type(self) -> str: - return self.headers.get('content-type', '') - - def extract_file(self) -> tuple[str, bytes]: - """ - extract file from response if content type is file related - """ - if self.is_file: - return self.get_content_type(), self.body - - return '', b'' - - @property - def content(self) -> str: - if isinstance(self.response, httpx.Response): - return self.response.text - else: - raise ValueError(f'Invalid response type {type(self.response)}') - - @property - def body(self) -> bytes: - if isinstance(self.response, httpx.Response): - return self.response.content - else: - raise ValueError(f'Invalid response type {type(self.response)}') - - @property - def status_code(self) -> int: - if isinstance(self.response, httpx.Response): - return self.response.status_code - else: - raise ValueError(f'Invalid response type {type(self.response)}') - - @property - def size(self) -> int: - return len(self.body) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f'{self.size} bytes' - elif self.size < 1024 * 1024: - return f'{(self.size / 1024):.2f} KB' - else: - return f'{(self.size / 1024 / 1024):.2f} MB' - - -class HttpExecutor: - server_url: str - method: str - authorization: HttpRequestNodeAuthorization - params: dict[str, Any] - headers: dict[str, Any] - body: Union[None, str] - files: Union[None, dict[str, Any]] - boundary: str - variable_selectors: list[VariableSelector] - timeout: HttpRequestNodeTimeout - - def __init__( - self, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: Optional[VariablePool] = None, - ): - self.server_url = node_data.url - self.method = node_data.method - self.authorization = node_data.authorization - 
self.timeout = timeout - self.params = {} - self.headers = {} - self.body = None - self.files = None - - # init template - self.variable_selectors = [] - self._init_template(node_data, variable_pool) - - @staticmethod - def _is_json_body(body: HttpRequestNodeBody): - """ - check if body is json - """ - if body and body.type == 'json' and body.data: - try: - json.loads(body.data) - return True - except: - return False - - return False - - @staticmethod - def _to_dict(convert_text: str): - """ - Convert the string like `aa:bb\n cc:dd` to dict `{aa:bb, cc:dd}` - """ - kv_paris = convert_text.split('\n') - result = {} - for kv in kv_paris: - if not kv.strip(): - continue - - kv = kv.split(':', maxsplit=1) - if len(kv) == 1: - k, v = kv[0], '' - else: - k, v = kv - result[k.strip()] = v - return result - - def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional[VariablePool] = None): - # extract all template in url - self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool) - - # extract all template in params - params, params_variable_selectors = self._format_template(node_data.params, variable_pool) - self.params = self._to_dict(params) - - # extract all template in headers - headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) - self.headers = self._to_dict(headers) - - # extract all template in body - body_data_variable_selectors = [] - if node_data.body: - # check if it's a valid JSON - is_valid_json = self._is_json_body(node_data.body) - - body_data = node_data.body.data or '' - if body_data: - body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) - - content_type_is_set = any(key.lower() == 'content-type' for key in self.headers) - if node_data.body.type == 'json' and not content_type_is_set: - self.headers['Content-Type'] = 'application/json' - elif node_data.body.type == 'x-www-form-urlencoded' and not content_type_is_set: - self.headers['Content-Type'] = 'application/x-www-form-urlencoded' - - if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: - body = self._to_dict(body_data) - - if node_data.body.type == 'form-data': - self.files = {k: ('', v) for k, v in body.items()} - random_str = lambda n: ''.join([chr(randint(97, 122)) for _ in range(n)]) - self.boundary = f'----WebKitFormBoundary{random_str(16)}' - - self.headers['Content-Type'] = f'multipart/form-data; boundary={self.boundary}' - else: - self.body = urlencode(body) - elif node_data.body.type in ['json', 'raw-text']: - self.body = body_data - elif node_data.body.type == 'none': - self.body = '' - - self.variable_selectors = ( - server_url_variable_selectors - + params_variable_selectors - + headers_variable_selectors - + body_data_variable_selectors - ) - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.authorization) - headers = deepcopy(self.headers) or {} - if self.authorization.type == 'api-key': - if self.authorization.config is None: - raise ValueError('self.authorization config is required') - if authorization.config is None: - raise ValueError('authorization config is required') - - if self.authorization.config.api_key is None: - raise ValueError('api_key is required') - - if not authorization.config.header: - authorization.config.header = 'Authorization' - - if self.authorization.config.type == 'bearer': - headers[authorization.config.header] = f'Bearer {authorization.config.api_key}' - elif 
self.authorization.config.type == 'basic': - headers[authorization.config.header] = f'Basic {authorization.config.api_key}' - elif self.authorization.config.type == 'custom': - headers[authorization.config.header] = authorization.config.api_key - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> HttpExecutorResponse: - """ - validate the response - """ - if isinstance(response, httpx.Response): - executor_response = HttpExecutorResponse(response) - else: - raise ValueError(f'Invalid response type {type(response)}') - - if executor_response.is_file: - if executor_response.size > MAX_BINARY_SIZE: - raise ValueError( - f'File size is too large, max size is {READABLE_MAX_BINARY_SIZE}, but current size is {executor_response.readable_size}.' - ) - else: - if executor_response.size > MAX_TEXT_SIZE: - raise ValueError( - f'Text size is too large, max size is {READABLE_MAX_TEXT_SIZE}, but current size is {executor_response.readable_size}.' - ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - kwargs = { - 'url': self.server_url, - 'headers': headers, - 'params': self.params, - 'timeout': (self.timeout.connect, self.timeout.read, self.timeout.write), - 'follow_redirects': True, - } - - if self.method in ('get', 'head', 'post', 'put', 'delete', 'patch'): - response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) - else: - raise ValueError(f'Invalid http method {self.method}') - return response - - def invoke(self) -> HttpExecutorResponse: - """ - invoke http request - """ - # assemble headers - headers = self._assembling_headers() - - # do http request - response = self._do_http_request(headers) - - # validate response - return self._validate_and_parse_response(response) - - def to_raw_request(self) -> str: - """ - convert to raw request - """ - server_url = self.server_url - if self.params: - server_url += f'?{urlencode(self.params)}' - - raw_request = f'{self.method.upper()} {server_url} HTTP/1.1\n' - - headers = self._assembling_headers() - for k, v in headers.items(): - # get authorization header - if self.authorization.type == 'api-key': - authorization_header = 'Authorization' - if self.authorization.config and self.authorization.config.header: - authorization_header = self.authorization.config.header - - if k.lower() == authorization_header.lower(): - raw_request += f'{k}: {"*" * len(v)}\n' - continue - - raw_request += f'{k}: {v}\n' - - raw_request += '\n' - - # if files, use multipart/form-data with boundary - if self.files: - boundary = self.boundary - raw_request += f'--{boundary}' - for k, v in self.files.items(): - raw_request += f'\nContent-Disposition: form-data; name="{k}"\n\n' - raw_request += f'{v[1]}\n' - raw_request += f'--{boundary}' - raw_request += '--' - else: - raw_request += self.body or '' - - return raw_request - - def _format_template( - self, template: str, variable_pool: Optional[VariablePool], escape_quotes: bool = False - ) -> tuple[str, list[VariableSelector]]: - """ - format template - """ - variable_template_parser = VariableTemplateParser(template=template) - variable_selectors = variable_template_parser.extract_variable_selectors() - - if variable_pool: - variable_value_mapping = {} - for variable_selector in variable_selectors: - variable = variable_pool.get_any(variable_selector.value_selector) - if variable is None: - raise ValueError(f'Variable {variable_selector.variable} not 
found') - if escape_quotes and isinstance(variable, str): - value = variable.replace('"', '\\"').replace('\n', '\\n') - else: - value = variable - variable_value_mapping[variable_selector.variable] = value - - return variable_template_parser.format(variable_value_mapping), variable_selectors - else: - return template, variable_selectors diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py deleted file mode 100644 index 1facf8a4f4a4b5..00000000000000 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ /dev/null @@ -1,163 +0,0 @@ -import logging -from mimetypes import guess_extension -from os import path -from typing import cast - -from core.app.segments import parser -from core.file.file_obj import FileTransferMethod, FileType, FileVar -from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.http_request.entities import ( - MAX_CONNECT_TIMEOUT, - MAX_READ_TIMEOUT, - MAX_WRITE_TIMEOUT, - HttpRequestNodeData, - HttpRequestNodeTimeout, -) -from core.workflow.nodes.http_request.http_executor import HttpExecutor, HttpExecutorResponse -from models.workflow import WorkflowNodeExecutionStatus - -HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( - connect=min(10, MAX_CONNECT_TIMEOUT), - read=min(60, MAX_READ_TIMEOUT), - write=min(20, MAX_WRITE_TIMEOUT), -) - - -class HttpRequestNode(BaseNode): - _node_data_cls = HttpRequestNodeData - _node_type = NodeType.HTTP_REQUEST - - @classmethod - def get_default_config(cls, filters: dict | None = None) -> dict: - return { - 'type': 'http-request', - 'config': { - 'method': 'get', - 'authorization': { - 'type': 'no-auth', - }, - 'body': {'type': 'none'}, - 'timeout': { - **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), - 'max_connect_timeout': MAX_CONNECT_TIMEOUT, - 'max_read_timeout': MAX_READ_TIMEOUT, - 'max_write_timeout': MAX_WRITE_TIMEOUT, - }, - }, - } - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) - # TODO: Switch to use segment directly - if node_data.authorization.config and node_data.authorization.config.api_key: - node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text - - # init http executor - http_executor = None - try: - http_executor = HttpExecutor( - node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool - ) - - # invoke http executor - response = http_executor.invoke() - except Exception as e: - process_data = {} - if http_executor: - process_data = { - 'request': http_executor.to_raw_request(), - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - ) - - files = self.extract_files(http_executor.server_url, response) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'status_code': response.status_code, - 'body': response.content if not files else '', - 'headers': response.headers, - 'files': files, - }, - process_data={ - 'request': http_executor.to_raw_request(), - }, - ) - - def _get_request_timeout(self, node_data: HttpRequestNodeData) -> 
HttpRequestNodeTimeout: - timeout = node_data.timeout - if timeout is None: - return HTTP_REQUEST_DEFAULT_TIMEOUT - - timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect - timeout.connect = min(timeout.connect, MAX_CONNECT_TIMEOUT) - timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read - timeout.read = min(timeout.read, MAX_READ_TIMEOUT) - timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write - timeout.write = min(timeout.write, MAX_WRITE_TIMEOUT) - return timeout - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - """ - Extract variable selector to variable mapping - :param node_data: node data - :return: - """ - node_data = cast(HttpRequestNodeData, node_data) - try: - http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT) - - variable_selectors = http_executor.variable_selectors - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - except Exception as e: - logging.exception(f'Failed to extract variable selector to variable mapping: {e}') - return {} - - def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: - """ - Extract files from response - """ - files = [] - mimetype, file_binary = response.extract_file() - - if mimetype: - # extract filename from url - filename = path.basename(url) - # extract extension if possible - extension = guess_extension(mimetype) or '.bin' - - tool_file = ToolFileManager.create_file_by_raw( - user_id=self.user_id, - tenant_id=self.tenant_id, - conversation_id=None, - file_binary=file_binary, - mimetype=mimetype, - ) - - files.append( - FileVar( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=filename, - extension=extension, - mime_type=mimetype, - ) - ) - - return files diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py new file mode 100644 index 00000000000000..5b399bed63df97 --- /dev/null +++ b/api/core/workflow/nodes/http_request/node.py @@ -0,0 +1,176 @@ +import logging +from collections.abc import Mapping, Sequence +from mimetypes import guess_extension +from os import path +from typing import Any + +from configs import dify_config +from core.file import File, FileTransferMethod, FileType +from core.tools.tool_file_manager import ToolFileManager +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.http_request.executor import Executor +from core.workflow.utils import variable_template_parser +from factories import file_factory +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ( + HttpRequestNodeData, + HttpRequestNodeTimeout, + Response, +) +from .exc import HttpRequestNodeError + +HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( + connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + read=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + write=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, +) + +logger = logging.getLogger(__name__) + + +class HttpRequestNode(BaseNode[HttpRequestNodeData]): + _node_data_cls = HttpRequestNodeData + _node_type = NodeType.HTTP_REQUEST + + 
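+    # Note: HTTP_REQUEST_DEFAULT_TIMEOUT above now defaults straight to the
+    # dify_config caps, whereas the removed http_request_node.py clamped the
+    # defaults with min(10/60/20, cap).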
@classmethod + def get_default_config(cls, filters: dict | None = None) -> dict: + return { + "type": "http-request", + "config": { + "method": "get", + "authorization": { + "type": "no-auth", + }, + "body": {"type": "none"}, + "timeout": { + **HTTP_REQUEST_DEFAULT_TIMEOUT.model_dump(), + "max_connect_timeout": dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + "max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + }, + }, + } + + def _run(self) -> NodeRunResult: + process_data = {} + try: + http_executor = Executor( + node_data=self.node_data, + timeout=self._get_request_timeout(self.node_data), + variable_pool=self.graph_runtime_state.variable_pool, + ) + process_data["request"] = http_executor.to_log() + + response = http_executor.invoke() + files = self.extract_files(url=http_executor.url, response=response) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "status_code": response.status_code, + "body": response.text if not files else "", + "headers": response.headers, + "files": files, + }, + process_data={ + "request": http_executor.to_log(), + }, + ) + except HttpRequestNodeError as e: + logger.warning(f"http request node {self.node_id} failed to run: {e}") + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + process_data=process_data, + ) + + @staticmethod + def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: + timeout = node_data.timeout + if timeout is None: + return HTTP_REQUEST_DEFAULT_TIMEOUT + + timeout.connect = timeout.connect or HTTP_REQUEST_DEFAULT_TIMEOUT.connect + timeout.read = timeout.read or HTTP_REQUEST_DEFAULT_TIMEOUT.read + timeout.write = timeout.write or HTTP_REQUEST_DEFAULT_TIMEOUT.write + return timeout + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: HttpRequestNodeData, + ) -> Mapping[str, Sequence[str]]: + selectors: list[VariableSelector] = [] + selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(node_data.params) + if node_data.body: + body_type = node_data.body.type + data = node_data.body.data + match body_type: + case "binary": + selector = data[0].file + selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) + case "json" | "raw-text": + selectors += variable_template_parser.extract_selectors_from_template(data[0].key) + selectors += variable_template_parser.extract_selectors_from_template(data[0].value) + case "x-www-form-urlencoded": + for item in data: + selectors += variable_template_parser.extract_selectors_from_template(item.key) + selectors += variable_template_parser.extract_selectors_from_template(item.value) + case "form-data": + for item in data: + selectors += variable_template_parser.extract_selectors_from_template(item.key) + if item.type == "text": + selectors += variable_template_parser.extract_selectors_from_template(item.value) + elif item.type == "file": + selectors.append( + VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) + ) + + mapping = {} + for selector in selectors: + mapping[node_id + "." 
+ selector.variable] = selector.value_selector + + return mapping + + def extract_files(self, url: str, response: Response) -> list[File]: + """ + Extract files from response + """ + files = [] + is_file = response.is_file + content_type = response.content_type + content = response.content + + if is_file and content_type: + # extract filename from url + filename = path.basename(url) + # extract extension if possible + extension = guess_extension(content_type) or ".bin" + + tool_file = ToolFileManager.create_file_by_raw( + user_id=self.user_id, + tenant_id=self.tenant_id, + conversation_id=None, + file_binary=content, + mimetype=content_type, + ) + + mapping = { + "tool_file_id": tool_file.id, + "type": FileType.IMAGE.value, + "transfer_method": FileTransferMethod.TOOL_FILE.value, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + files.append(file) + + return files diff --git a/api/core/workflow/nodes/if_else/__init__.py b/api/core/workflow/nodes/if_else/__init__.py index e69de29bb2d1d6..afa0e8112c5b17 100644 --- a/api/core/workflow/nodes/if_else/__init__.py +++ b/api/core/workflow/nodes/if_else/__init__.py @@ -0,0 +1,3 @@ +from .if_else_node import IfElseNode + +__all__ = ["IfElseNode"] diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/core/workflow/nodes/if_else/entities.py index bc6dce0d3bd37a..23f5d2cc317f78 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/core/workflow/nodes/if_else/entities.py @@ -1,22 +1,9 @@ from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field -from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class Condition(BaseModel): - """ - Condition entity - """ - variable_selector: list[str] - comparison_operator: Literal[ - # for string or array - "contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", - # for number - "=", "≠", ">", "<", "≥", "≤", "null", "not null" - ] - value: Optional[str] = None +from core.workflow.nodes.base import BaseNodeData +from core.workflow.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): @@ -28,11 +15,12 @@ class Case(BaseModel): """ Case entity representing a single logical condition group """ + case_id: str logical_operator: Literal["and", "or"] conditions: list[Condition] logical_operator: Optional[Literal["and", "or"]] = "and" - conditions: Optional[list[Condition]] = None + conditions: Optional[list[Condition]] = Field(default=None, deprecated=True) cases: Optional[list[Case]] = None diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index c6d235627f04b0..6960fc045a5efc 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,46 +1,44 @@ -from collections.abc import Sequence -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Literal -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from typing_extensions import deprecated + +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData -from core.workflow.utils.variable_template_parser import 
VariableTemplateParser +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.utils.condition.entities import Condition +from core.workflow.utils.condition.processor import ConditionProcessor from models.workflow import WorkflowNodeExecutionStatus -class IfElseNode(BaseNode): +class IfElseNode(BaseNode[IfElseNodeData]): _node_data_cls = IfElseNodeData _node_type = NodeType.IF_ELSE - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ - node_data = self.node_data - node_data = cast(IfElseNodeData, node_data) - - node_inputs = { - "conditions": [] - } + node_inputs: dict[str, list] = {"conditions": []} - process_datas = { - "condition_results": [] - } + process_datas: dict[str, list] = {"condition_results": []} input_conditions = [] final_result = False selected_case_id = None + condition_processor = ConditionProcessor() try: # Check if the new cases structure is used - if node_data.cases: - for case in node_data.cases: - input_conditions, group_result = self.process_conditions(variable_pool, case.conditions) - # Apply the logical operator for the current case - final_result = all(group_result) if case.logical_operator == "and" else any(group_result) + if self.node_data.cases: + for case in self.node_data.cases: + input_conditions, group_result, final_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=case.conditions, + operator=case.logical_operator, + ) process_datas["condition_results"].append( { @@ -56,29 +54,26 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: break else: + # TODO: Update database then remove this # Fallback to old structure if cases are not defined - input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions) - - final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) + input_conditions, group_result, final_result = _should_not_use_old_function( + condition_processor=condition_processor, + variable_pool=self.graph_runtime_state.variable_pool, + conditions=self.node_data.conditions or [], + operator=self.node_data.logical_operator or "and", + ) selected_case_id = "true" if final_result else "false" process_datas["condition_results"].append( - { - "group": "default", - "results": group_result, - "final_result": final_result - } + {"group": "default", "results": group_result, "final_result": final_result} ) node_inputs["conditions"] = input_conditions except Exception as e: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=node_inputs, - process_data=process_datas, - error=str(e) + status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_datas, error=str(e) ) outputs = {"result": final_result, "selected_case_id": selected_case_id} @@ -87,371 +82,40 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, process_data=process_datas, - edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default' - outputs=outputs + edge_source_handle=selected_case_id or "false", # Use case ID or 'default' + outputs=outputs, ) return data - def evaluate_condition( - self, actual_value: Optional[str | list], expected_value: str, comparison_operator: 
str - ) -> bool: - """ - Evaluate condition - :param actual_value: actual value - :param expected_value: expected value - :param comparison_operator: comparison operator - - :return: bool - """ - if comparison_operator == "contains": - return self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - return self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - return self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - return self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - return self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - return self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - return self._assert_empty(actual_value) - elif comparison_operator == "not empty": - return self._assert_not_empty(actual_value) - elif comparison_operator == "=": - return self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - return self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - return self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - return self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - return self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - return self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - return self._assert_null(actual_value) - elif comparison_operator == "not null": - return self._assert_not_null(actual_value) - else: - raise ValueError(f"Invalid comparison operator: {comparison_operator}") - - def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]): - input_conditions = [] - group_result = [] - - for condition in conditions: - actual_variable = variable_pool.get_any(condition.variable_selector) - - if condition.value is not None: - variable_template_parser = VariableTemplateParser(template=condition.value) - expected_value = variable_template_parser.extract_variable_selectors() - variable_selectors = variable_template_parser.extract_variable_selectors() - if variable_selectors: - for variable_selector in variable_selectors: - value = variable_pool.get_any(variable_selector.value_selector) - expected_value = variable_template_parser.format({variable_selector.variable: value}) - else: - expected_value = condition.value - else: - expected_value = None - - comparison_operator = condition.comparison_operator - input_conditions.append( - { - "actual_value": actual_variable, - "expected_value": expected_value, - "comparison_operator": comparison_operator - } - ) - - result = self.evaluate_condition(actual_variable, expected_value, comparison_operator) - group_result.append(result) - - return input_conditions, group_result - - def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ - Assert contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value not in actual_value: - return False - return True - - def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: - """ 
- Assert not contains - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return True - - if not isinstance(actual_value, str | list): - raise ValueError('Invalid actual value type: string or array') - - if expected_value in actual_value: - return False - return True - - def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert start with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.startswith(expected_value): - return False - return True - - def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert end with - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if not actual_value: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if not actual_value.endswith(expected_value): - return False - return True - - def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value != expected_value: - return False - return True - - def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool: - """ - Assert is not - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, str): - raise ValueError('Invalid actual value type: string') - - if actual_value == expected_value: - return False - return True - - def _assert_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert empty - :param actual_value: actual value - :return: - """ - if not actual_value: - return True - return False - - def _assert_not_empty(self, actual_value: Optional[str]) -> bool: - """ - Assert not empty - :param actual_value: actual value - :return: - """ - if actual_value: - return True - return False - - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value != expected_value: - return False - return True - - def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert not equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value == expected_value: - return False - return True - - def _assert_greater_than(self, actual_value: 
Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value <= expected_value: - return False - return True - - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value >= expected_value: - return False - return True - - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert greater than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value < expected_value: - return False - return True - - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: - """ - Assert less than or equal - :param actual_value: actual value - :param expected_value: expected value - :return: - """ - if actual_value is None: - return False - - if not isinstance(actual_value, int | float): - raise ValueError('Invalid actual value type: number') - - if isinstance(actual_value, int): - expected_value = int(expected_value) - else: - expected_value = float(expected_value) - - if actual_value > expected_value: - return False - return True - - def _assert_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert null - :param actual_value: actual value - :return: - """ - if actual_value is None: - return True - return False - - def _assert_not_null(self, actual_value: Optional[int | float]) -> bool: - """ - Assert not null - :param actual_value: actual value - :return: - """ - if actual_value is not None: - return True - return False - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IfElseNodeData, + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ return {} + + +@deprecated("This function is deprecated. 
You should use the new cases structure.")
+def _should_not_use_old_function(
+    *,
+    condition_processor: ConditionProcessor,
+    variable_pool: VariablePool,
+    conditions: list[Condition],
+    operator: Literal["and", "or"],
+):
+    return condition_processor.process_conditions(
+        variable_pool=variable_pool,
+        conditions=conditions,
+        operator=operator,
+    )
diff --git a/api/core/workflow/nodes/iteration/__init__.py b/api/core/workflow/nodes/iteration/__init__.py
index e69de29bb2d1d6..5bb87aaffa92b4 100644
--- a/api/core/workflow/nodes/iteration/__init__.py
+++ b/api/core/workflow/nodes/iteration/__init__.py
@@ -0,0 +1,5 @@
+from .entities import IterationNodeData
+from .iteration_node import IterationNode
+from .iteration_start_node import IterationStartNode
+
+__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"]
diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py
index 177b47b9518e00..ebcb6f82fbc397 100644
--- a/api/core/workflow/nodes/iteration/entities.py
+++ b/api/core/workflow/nodes/iteration/entities.py
@@ -1,27 +1,51 @@
+from enum import Enum
 from typing import Any, Optional
 
-from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
+from pydantic import Field
+
+from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
+
+
+class ErrorHandleMode(str, Enum):
+    TERMINATED = "terminated"
+    CONTINUE_ON_ERROR = "continue-on-error"
+    REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
 
 
 class IterationNodeData(BaseIterationNodeData):
     """
     Iteration Node Data.
     """
-    parent_loop_id: Optional[str] = None  # redundant field, not used currently
-    iterator_selector: list[str]  # variable selector
-    output_selector: list[str]  # output selector
+
+    parent_loop_id: Optional[str] = None  # redundant field, not used currently
+    iterator_selector: list[str]  # variable selector
+    output_selector: list[str]  # output selector
+    is_parallel: bool = False  # whether to run the iterations in parallel
+    parallel_nums: int = 10  # the maximum number of parallel runs
+    error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED  # how to handle errors raised inside the iteration
+
+
+class IterationStartNodeData(BaseNodeData):
+    """
+    Iteration Start Node Data.
+    """
+
+    pass
+
 
 class IterationState(BaseIterationState):
     """
     Iteration State.
     """
-    outputs: list[Any] = None
+
+    outputs: list[Any] = Field(default_factory=list)
     current_output: Optional[Any] = None
 
     class MetaData(BaseIterationState.MetaData):
         """
         Data.
         """
+
         iterator_length: int
 
     def get_last_output(self) -> Optional[Any]:
@@ -31,9 +55,9 @@ def get_last_output(self) -> Optional[Any]:
         if self.outputs:
             return self.outputs[-1]
         return None
-
+
     def get_current_output(self) -> Optional[Any]:
         """
         Get current output.
""" - return self.current_output \ No newline at end of file + return self.current_output diff --git a/api/core/workflow/nodes/iteration/exc.py b/api/core/workflow/nodes/iteration/exc.py new file mode 100644 index 00000000000000..d9947e09bc10c8 --- /dev/null +++ b/api/core/workflow/nodes/iteration/exc.py @@ -0,0 +1,22 @@ +class IterationNodeError(ValueError): + """Base class for iteration node errors.""" + + +class IteratorVariableNotFoundError(IterationNodeError): + """Raised when the iterator variable is not found.""" + + +class InvalidIteratorValueError(IterationNodeError): + """Raised when the iterator value is invalid.""" + + +class StartNodeIdNotFoundError(IterationNodeError): + """Raised when the start node ID is not found.""" + + +class IterationGraphNotFoundError(IterationNodeError): + """Raised when the iteration graph is not found.""" + + +class IterationIndexNotFoundError(IterationNodeError): + """Raised when the iteration index is not found.""" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 54dfe8b7f4e40e..e5863d771b0431 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,124 +1,551 @@ -from typing import cast +import logging +import uuid +from collections.abc import Generator, Mapping, Sequence +from concurrent.futures import Future, wait +from datetime import datetime, timezone +from queue import Empty, Queue +from typing import TYPE_CHECKING, Any, Optional, cast +from flask import Flask, current_app + +from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.base_node_data_entities import BaseIterationState -from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.node_entities import ( + NodeRunMetadataKey, + NodeRunResult, +) from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode -from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState +from core.workflow.graph_engine.entities.event import ( + BaseGraphEvent, + BaseNodeEvent, + BaseParallelBranchEvent, + GraphRunFailedEvent, + InNodeEvent, + IterationRunFailedEvent, + IterationRunNextEvent, + IterationRunStartedEvent, + IterationRunSucceededEvent, + NodeInIterationFailedEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import NodeEvent, RunCompletedEvent +from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from models.workflow import WorkflowNodeExecutionStatus +from .exc import ( + InvalidIteratorValueError, + IterationGraphNotFoundError, + IterationIndexNotFoundError, + IterationNodeError, + IteratorVariableNotFoundError, + StartNodeIdNotFoundError, +) + +if TYPE_CHECKING: + from core.workflow.graph_engine.graph_engine import GraphEngine +logger = logging.getLogger(__name__) + -class IterationNode(BaseIterationNode): +class IterationNode(BaseNode[IterationNodeData]): """ Iteration Node. 
""" + _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION - def _run(self, variable_pool: VariablePool) -> BaseIterationState: + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + return { + "type": "iteration", + "config": { + "is_parallel": False, + "parallel_nums": 10, + "error_handle_mode": ErrorHandleMode.TERMINATED.value, + }, + } + + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. """ - self.node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(self.node_data.iterator_selector) - - if not isinstance(iterator, list): - raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.") - - state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={ - 'iterator_selector': iterator - }, outputs=[], metadata=IterationState.MetaData( - iterator_length=len(iterator) if iterator is not None else 0 - )) - - self._set_current_iteration_variable(variable_pool, state) - return state - - def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: - """ - Get next iteration start node id based on the graph. - :param graph: graph - :return: next node id - """ - # resolve current output - self._resolve_current_output(variable_pool, state) - # move to next iteration - self._next_iteration(variable_pool, state) - - node_data = cast(IterationNodeData, self.node_data) - if self._reached_iteration_limit(variable_pool, state): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - 'output': jsonable_encoder(state.outputs) - } - ) - - return node_data.start_node_id - - def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState): - """ - Set current iteration variable. 
- :variable_pool: variable pool - """ - node_data = cast(IterationNodeData, self.node_data) + iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - variable_pool.add((self.node_id, 'index'), state.index) - # get the iterator value - iterator = variable_pool.get_any(node_data.iterator_selector) + if not iterator_list_segment: + raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found") - if iterator is None or not isinstance(iterator, list): + if len(iterator_list_segment.value) == 0: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"output": []}, + ) + ) return - - if state.index < len(iterator): - variable_pool.add((self.node_id, 'item'), iterator[state.index]) - def _next_iteration(self, variable_pool: VariablePool, state: IterationState): + iterator_list_value = iterator_list_segment.to_object() + + if not isinstance(iterator_list_value, list): + raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") + + inputs = {"iterator_selector": iterator_list_value} + + graph_config = self.graph_config + + if not self.node_data.start_node_id: + raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") + + root_node_id = self.node_data.start_node_id + + # init graph + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) + + if not iteration_graph: + raise IterationGraphNotFoundError("iteration graph not found") + + variable_pool = self.graph_runtime_state.variable_pool + + # append iteration variable (item, index) to variable pool + variable_pool.add([self.node_id, "index"], 0) + variable_pool.add([self.node_id, "item"], iterator_list_value[0]) + + # init graph engine + from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool + + graph_engine = GraphEngine( + tenant_id=self.tenant_id, + app_id=self.app_id, + workflow_type=self.workflow_type, + workflow_id=self.workflow_id, + user_id=self.user_id, + user_from=self.user_from, + invoke_from=self.invoke_from, + call_depth=self.workflow_call_depth, + graph=iteration_graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + thread_pool_id=self.thread_pool_id, + ) + + start_at = datetime.now(timezone.utc).replace(tzinfo=None) + + yield IterationRunStartedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + metadata={"iterator_length": len(iterator_list_value)}, + predecessor_node_id=self.previous_node_id, + ) + + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=0, + pre_iteration_output=None, + ) + outputs: list[Any] = [None] * len(iterator_list_value) + try: + if self.node_data.is_parallel: + futures: list[Future] = [] + q = Queue() + thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) + for index, item in enumerate(iterator_list_value): + future: Future = thread_pool.submit( + self._run_single_iter_parallel, + current_app._get_current_object(), + q, + iterator_list_value, + inputs, + outputs, + start_at, + graph_engine, + 
iteration_graph, + index, + item, + ) + future.add_done_callback(thread_pool.task_done_callback) + futures.append(future) + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break + if isinstance(event, IterationRunNextEvent): + succeeded_count += 1 + if succeeded_count == len(futures): + q.put(None) + yield event + if isinstance(event, RunCompletedEvent): + q.put(None) + for f in futures: + if not f.done(): + f.cancel() + yield event + if isinstance(event, IterationRunFailedEvent): + q.put(None) + yield event + except Empty: + continue + + # wait all threads + wait(futures) + else: + for _ in range(len(iterator_list_value)): + yield from self._run_single_iter( + iterator_list_value, + variable_pool, + inputs, + outputs, + start_at, + graph_engine, + iteration_graph, + ) + if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + outputs = [output for output in outputs if output is not None] + yield IterationRunSucceededEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)} + ) + ) + except IterationNodeError as e: + # iteration run failed + logger.warning("Iteration run failed") + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=str(e), + ) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + finally: + # remove iteration variable (item, index) from variable pool after iteration run completed + variable_pool.remove([self.node_id, "index"]) + variable_pool.remove([self.node_id, "item"]) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData, + ) -> Mapping[str, Sequence[str]]: """ - Move to next iteration. 
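# The parallel branch above fans iterations out to worker threads and drains a
# shared Queue on the main thread, with None as the stop sentinel. A compact,
# self-contained sketch of that producer/consumer shape (error handling and
# event types deliberately simplified; not the actual GraphEngine wiring):
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, wait
from queue import Empty, Queue


def run_parallel(items: list, run_one, max_workers: int = 10) -> Generator[dict, None, None]:
    if not items:
        return
    q: Queue = Queue()

    def worker(index: int, item) -> None:
        # each worker reports exactly one event back to the consumer
        q.put({"event": "next", "index": index, "output": run_one(item)})

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(worker, i, item) for i, item in enumerate(items)]
        seen = 0
        while True:
            try:
                event = q.get(timeout=1)
            except Empty:
                continue  # workers still running; poll again
            if event is None:
                break
            yield event
            seen += 1
            if seen == len(futures):
                q.put(None)  # every worker reported; unblock the loop
        wait(futures)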
- :param variable_pool: variable pool + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: """ - state.index += 1 - self._set_current_iteration_variable(variable_pool, state) + variable_mapping = { + f"{node_id}.input_selector": node_data.iterator_selector, + } + + # init graph + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + + if not iteration_graph: + raise IterationGraphNotFoundError("iteration graph not found") + + for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): + if sub_node_config.get("data", {}).get("iteration_id") != node_id: + continue - def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState): + # variable selector to variable mapping + try: + # Get node class + from core.workflow.nodes.node_mapping import node_type_classes_mapping + + node_type = NodeType(sub_node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping.get(node_type) + if not node_cls: + continue + + sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=graph_config, config=sub_node_config + ) + sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping) + except NotImplementedError: + sub_node_variable_mapping = {} + + # remove iteration variables + sub_node_variable_mapping = { + sub_node_id + "." + key: value + for key, value in sub_node_variable_mapping.items() + if value[0] != node_id + } + + variable_mapping.update(sub_node_variable_mapping) + + # remove variable out from iteration + variable_mapping = { + key: value for key, value in variable_mapping.items() if value[0] not in iteration_graph.node_ids + } + + return variable_mapping + + def _handle_event_metadata( + self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str + ) -> NodeRunStartedEvent | BaseNodeEvent: """ - Check if iteration limit is reached. - :return: True if iteration limit is reached, False otherwise + add iteration metadata to event. """ - node_data = cast(IterationNodeData, self.node_data) - iterator = variable_pool.get_any(node_data.iterator_selector) + if not isinstance(event, BaseNodeEvent): + return event + if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): + event.parallel_mode_run_id = parallel_mode_run_id + return event + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} - if iterator is None or not isinstance(iterator, list): - return True + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + if self.node_data.is_parallel: + metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id + else: + metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + event.route_node_state.node_run_result.metadata = metadata + return event - return state.index >= len(iterator) - - def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState): + def _run_single_iter( + self, + iterator_list_value: list[str], + variable_pool: VariablePool, + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + parallel_mode_run_id: Optional[str] = None, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: """ - Resolve current output. 
- :param variable_pool: variable pool + run single iteration """ - output_selector = cast(IterationNodeData, self.node_data).output_selector - output = variable_pool.get_any(output_selector) - # clear the output for this iteration - variable_pool.remove([self.node_id] + output_selector[1:]) - state.current_output = output - if output is not None: - # NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration). - if isinstance(output, list): - state.outputs.extend(output) - else: - state.outputs.append(output) + try: + rst = graph_engine.run() + # get current iteration index + current_index = variable_pool.get([self.node_id, "index"]).value + next_index = int(current_index) + 1 - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]: + if current_index is None: + raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id + + if ( + isinstance(event, BaseNodeEvent) + and event.node_type == NodeType.ITERATION_START + and not isinstance(event, NodeRunStreamChunkEvent) + ): + continue + + if isinstance(event, NodeRunSucceededEvent): + yield self._handle_event_metadata(event, current_index, parallel_mode_run_id) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + if self.node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return + else: + event = cast(InNodeEvent, event) + metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) + if isinstance(event, NodeRunFailedEvent): + if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + outputs[current_index] = None + variable_pool.add([self.node_id, "index"], next_index) + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + variable_pool.add([self.node_id, "index"], 
next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield metadata_event + + current_iteration_output = variable_pool.get(self.node_data.output_selector).value + outputs[current_index] = current_iteration_output + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # move to next iteration + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, + ) + + except IterationNodeError as e: + logger.warning(f"Iteration run failed:{str(e)}") + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=str(e), + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + + def _run_single_iter_parallel( + self, + flask_app: Flask, + q: Queue, + iterator_list_value: list[str], + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + index: int, + item: Any, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: """ - Extract variable selector to variable mapping - :param node_data: node data - :return: + run single iteration in parallel mode """ - return { - 'input_selector': node_data.iterator_selector, - } \ No newline at end of file + with flask_app.app_context(): + parallel_mode_run_id = uuid.uuid4().hex + graph_engine_copy = graph_engine.create_copy() + variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool + variable_pool_copy.add([self.node_id, "index"], index) + variable_pool_copy.add([self.node_id, "item"], item) + for event in self._run_single_iter( + iterator_list_value=iterator_list_value, + variable_pool=variable_pool_copy, + inputs=inputs, + outputs=outputs, + start_at=start_at, + graph_engine=graph_engine_copy, + iteration_graph=iteration_graph, + parallel_mode_run_id=parallel_mode_run_id, + ): + q.put(event) diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py new 
file mode 100644 index 00000000000000..6ab7c301066d93 --- /dev/null +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -0,0 +1,36 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class IterationStartNode(BaseNode): + """ + Iteration Start Node. + """ + + _node_data_cls = IterationStartNodeData + _node_type = NodeType.ITERATION_START + + def _run(self) -> NodeRunResult: + """ + Run the node. + """ + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/core/workflow/nodes/knowledge_retrieval/__init__.py index e69de29bb2d1d6..4d4a4cbd9f1342 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/__init__.py +++ b/api/core/workflow/nodes/knowledge_retrieval/__init__.py @@ -0,0 +1,3 @@ +from .knowledge_retrieval_node import KnowledgeRetrievalNode + +__all__ = ["KnowledgeRetrievalNode"] diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 7cf392277caad6..e8972d1381d3ce 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -2,13 +2,14 @@ from pydantic import BaseModel -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class RerankingModelConfig(BaseModel): """ Reranking Model Config. """ + provider: str model: str @@ -17,6 +18,7 @@ class VectorSetting(BaseModel): """ Vector Setting. """ + vector_weight: float embedding_provider_name: str embedding_model_name: str @@ -26,6 +28,7 @@ class KeywordSetting(BaseModel): """ Keyword Setting. """ + keyword_weight: float @@ -33,6 +36,7 @@ class WeightedScoreConfig(BaseModel): """ Weighted score Config. """ + vector_setting: VectorSetting keyword_setting: KeywordSetting @@ -41,17 +45,20 @@ class MultipleRetrievalConfig(BaseModel): """ Multiple Retrieval Config. """ + top_k: int score_threshold: Optional[float] = None - reranking_mode: str = 'reranking_model' + reranking_mode: str = "reranking_model" reranking_enable: bool = True reranking_model: Optional[RerankingModelConfig] = None weights: Optional[WeightedScoreConfig] = None + class ModelConfig(BaseModel): """ - Model Config. + Model Config. """ + provider: str name: str mode: str @@ -62,6 +69,7 @@ class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. """ + model: ModelConfig @@ -69,9 +77,10 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ Knowledge retrieval Node Data. 
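# The knowledge-retrieval entities above are plain pydantic models: Literal
# pins retrieval_mode to "single" or "multiple", and the per-mode config stays
# Optional so only the active mode needs to be populated. A small usage sketch
# with made-up dataset ids and selectors:
from typing import Literal, Optional

from pydantic import BaseModel


class MultipleRetrievalConfig(BaseModel):
    top_k: int
    score_threshold: Optional[float] = None
    reranking_mode: str = "reranking_model"


class KnowledgeRetrievalNodeData(BaseModel):
    query_variable_selector: list[str]
    dataset_ids: list[str]
    retrieval_mode: Literal["single", "multiple"]
    multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None


node_data = KnowledgeRetrievalNodeData(
    query_variable_selector=["start", "query"],
    dataset_ids=["dataset-1"],
    retrieval_mode="multiple",
    multiple_retrieval_config=MultipleRetrievalConfig(top_k=4),
)
# retrieval_mode="hybrid" would fail validation with a ValidationError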
""" - type: str = 'knowledge-retrieval' + + type: str = "knowledge-retrieval" query_variable_selector: list[str] dataset_ids: list[str] - retrieval_mode: Literal['single', 'multiple'] + retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/exc.py b/api/core/workflow/nodes/knowledge_retrieval/exc.py new file mode 100644 index 00000000000000..0c3b6e86fa37be --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/exc.py @@ -0,0 +1,18 @@ +class KnowledgeRetrievalNodeError(ValueError): + """Base class for KnowledgeRetrievalNode errors.""" + + +class ModelNotExistError(KnowledgeRetrievalNodeError): + """Raised when the model does not exist.""" + + +class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError): + """Raised when the model credentials are not initialized.""" + + +class ModelNotSupportedError(KnowledgeRetrievalNodeError): + """Raised when the model is not supported.""" + + +class ModelQuotaExceededError(KnowledgeRetrievalNodeError): + """Raised when the model provider quota is exceeded.""" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 01bf6e16e6ebcd..8c5a9b5ecb8708 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,3 +1,5 @@ +import logging +from collections.abc import Mapping, Sequence from typing import Any, cast from sqlalchemy import func @@ -6,99 +8,95 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.rag.retrieval.retrival_methods import RetrievalMethod -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.variables import StringSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus +from .entities import KnowledgeRetrievalNodeData +from .exc import ( + KnowledgeRetrievalNodeError, + ModelCredentialsNotInitializedError, + ModelNotExistError, + ModelNotSupportedError, + ModelQuotaExceededError, +) + +logger = logging.getLogger(__name__) + default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 
'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } -class KnowledgeRetrievalNode(BaseNode): +class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): _node_data_cls = KnowledgeRetrievalNodeData - node_type = NodeType.KNOWLEDGE_RETRIEVAL - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + _node_type = NodeType.KNOWLEDGE_RETRIEVAL + def _run(self) -> NodeRunResult: # extract variables - variable = variable_pool.get_any(node_data.query_variable_selector) - query = variable - variables = { - 'query': query - } - if not query: + variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) + if not isinstance(variable, StringSegment): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Query is required." + inputs={}, + error="Query variable is not string type.", + ) + query = variable.value + variables = {"query": query} + if not query: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." ) # retrieve knowledge try: - results = self._fetch_dataset_retriever( - node_data=node_data, query=query - ) - outputs = { - 'result': results - } + results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) + outputs = {"result": results} return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=None, - outputs=outputs + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs ) - except Exception as e: + except KnowledgeRetrievalNodeError as e: + logger.warning("Error when running knowledge retrieval node") + return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e) - ) - - def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[ - dict[str, Any]]: + def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] dataset_ids = node_data.dataset_ids # Subquery: Count the number of available documents for each dataset - subquery = db.session.query( - Document.dataset_id, - func.count(Document.id).label('available_document_count') - ).filter( - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False, - Document.dataset_id.in_(dataset_ids) - ).group_by(Document.dataset_id).having( - func.count(Document.id) > 0 - ).subquery() + subquery = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.dataset_id.in_(dataset_ids), + ) + .group_by(Document.dataset_id) + .having(func.count(Document.id) > 0) + .subquery() + ) - results = db.session.query(Dataset).join( - subquery, Dataset.id == subquery.c.dataset_id - ).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id.in_(dataset_ids) - ).all() + results = ( + 
db.session.query(Dataset) + .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) + .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) + .all() + ) for dataset in results: # pass if dataset is not available @@ -115,16 +113,14 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: model_type_instance = cast(LargeLanguageModel, model_type_instance) # get model schema model_schema = model_type_instance.get_model_schema( - model=model_config.model, - credentials=model_config.credentials + model=model_config.model, credentials=model_config.credentials ) if model_schema: planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: + if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER all_documents = dataset_retrieval.single_retrieve( available_datasets=available_datasets, @@ -135,110 +131,152 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: query=query, model_config=model_config, model_instance=model_instance, - planning_strategy=planning_strategy + planning_strategy=planning_strategy, ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: - if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model': - reranking_model = { - 'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model.provider, - 'reranking_model_name': node_data.multiple_retrieval_config.reranking_model.model - } + if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": + if node_data.multiple_retrieval_config.reranking_model: + reranking_model = { + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, + } + else: + reranking_model = None weights = None - elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score': + elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": reranking_model = None + vector_setting = node_data.multiple_retrieval_config.weights.vector_setting weights = { - 'vector_setting': { - "vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight, - "embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name, - "embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name, + "vector_setting": { + "vector_weight": vector_setting.vector_weight, + "embedding_provider_name": vector_setting.embedding_provider_name, + "embedding_model_name": vector_setting.embedding_model_name, }, - 'keyword_setting': { + "keyword_setting": { "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight - } + }, } else: reranking_model = None weights = None - all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id, - self.user_from.value, - available_datasets, query, - node_data.multiple_retrieval_config.top_k, - node_data.multiple_retrieval_config.score_threshold, - node_data.multiple_retrieval_config.reranking_mode, - reranking_model, - weights, - node_data.multiple_retrieval_config.reranking_enable, - ) - - context_list = 
[] - if all_documents: + all_documents = dataset_retrieval.multiple_retrieve( + self.app_id, + self.tenant_id, + self.user_id, + self.user_from.value, + available_datasets, + query, + node_data.multiple_retrieval_config.top_k, + node_data.multiple_retrieval_config.score_threshold, + node_data.multiple_retrieval_config.reranking_mode, + reranking_model, + weights, + node_data.multiple_retrieval_config.reranking_enable, + ) + dify_documents = [item for item in all_documents if item.provider == "dify"] + external_documents = [item for item in all_documents if item.provider == "external"] + retrieval_resource_list = [] + # deal with external documents + for item in external_documents: + source = { + "metadata": { + "_source": "knowledge", + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": "workflow", + "score": item.metadata.get("score"), + }, + "title": item.metadata.get("title"), + "content": item.page_content, + } + retrieval_resource_list.append(source) + document_score_list = {} + # deal with dify documents + if dify_documents: document_score_list = {} - for item in all_documents: - if item.metadata.get('score'): - document_score_list[item.metadata['doc_id']] = item.metadata['score'] + for item in dify_documents: + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - index_node_ids = [document.metadata['doc_id'] for document in all_documents] + index_node_ids = [document.metadata["doc_id"] for document in dify_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', + DocumentSegment.status == "completed", DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) + DocumentSegment.index_node_id.in_(index_node_ids), ).all() if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted(segments, - key=lambda segment: index_node_id_to_position.get(segment.index_node_id, - float('inf'))) + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - dataset = Dataset.query.filter_by( - id=segment.dataset_id + dataset = Dataset.query.filter_by(id=segment.dataset_id).first() + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, ).first() - document = Document.query.filter(Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() - resource_number = 1 if dataset and document: - source = { - 'metadata': { - '_source': 'knowledge', - 'position': resource_number, - 'dataset_id': dataset.id, - 'dataset_name': dataset.name, - 'document_id': document.id, - 'document_name': document.name, - 'document_data_source_type': document.data_source_type, - 'segment_id': segment.id, - 'retriever_from': 'workflow', - 'score': document_score_list.get(segment.index_node_id, None), - 'segment_hit_count': segment.hit_count, - 'segment_word_count': segment.word_count, - 'segment_position': segment.position, - 'segment_index_node_hash': segment.index_node_hash, + "metadata": { + "_source": "knowledge", + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": 
document.id, + "document_name": document.name, + "document_data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": "workflow", + "score": document_score_list.get(segment.index_node_id, None), + "segment_hit_count": segment.hit_count, + "segment_word_count": segment.word_count, + "segment_position": segment.position, + "segment_index_node_hash": segment.index_node_hash, }, - 'title': document.name + "title": document.name, } if segment.answer: - source['content'] = f'question:{segment.get_sign_content()} \nanswer:{segment.answer}' + source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" else: - source['content'] = segment.get_sign_content() - context_list.append(source) - resource_number += 1 - return context_list + source["content"] = segment.get_sign_content() + retrieval_resource_list.append(source) + if retrieval_resource_list: + retrieval_resource_list = sorted( + retrieval_resource_list, key=lambda x: x.get("metadata").get("score") or 0.0, reverse=True + ) + position = 1 + for item in retrieval_resource_list: + item["metadata"]["position"] = position + position += 1 + return retrieval_resource_list @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: KnowledgeRetrievalNodeData, + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ variable_mapping = {} - variable_mapping['query'] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = node_data.query_variable_selector return variable_mapping - def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data: KnowledgeRetrievalNodeData + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config :param node_data: node data @@ -249,10 +287,7 @@ def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ model_manager = ModelManager() model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) provider_model_bundle = model_instance.provider_model_bundle @@ -263,39 +298,35 @@ def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[ # check model provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM + model=model_name, model_type=ModelType.LLM ) if provider_model is None: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.") elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + raise ModelNotSupportedError(f"Dify Hosted 
OpenAI {model_name} currently not support.") elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.") # model config completion_params = node_data.single_retrieval_config.model.completion_params stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] # get model mode model_mode = node_data.single_retrieval_config.model.mode if not model_mode: - raise ValueError("LLM mode is required.") + raise ModelNotExistError("LLM mode is required.") - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, diff --git a/api/core/workflow/nodes/list_operator/__init__.py b/api/core/workflow/nodes/list_operator/__init__.py new file mode 100644 index 00000000000000..1877586ef41145 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/__init__.py @@ -0,0 +1,3 @@ +from .node import ListOperatorNode + +__all__ = ["ListOperatorNode"] diff --git a/api/core/workflow/nodes/list_operator/entities.py b/api/core/workflow/nodes/list_operator/entities.py new file mode 100644 index 00000000000000..79cef1c27ab718 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/entities.py @@ -0,0 +1,56 @@ +from collections.abc import Sequence +from typing import Literal + +from pydantic import BaseModel, Field + +from core.workflow.nodes.base import BaseNodeData + +_Condition = Literal[ + # string conditions + "contains", + "start with", + "end with", + "is", + "in", + "empty", + "not contains", + "is not", + "not in", + "not empty", + # number conditions + "=", + "≠", + "<", + ">", + "≥", + "≤", +] + + +class FilterCondition(BaseModel): + key: str = "" + comparison_operator: _Condition = "contains" + value: str | Sequence[str] = "" + + +class FilterBy(BaseModel): + enabled: bool = False + conditions: Sequence[FilterCondition] = Field(default_factory=list) + + +class OrderBy(BaseModel): + enabled: bool = False + key: str = "" + value: Literal["asc", "desc"] = "asc" + + +class Limit(BaseModel): + enabled: bool = False + size: int = -1 + + +class ListOperatorNodeData(BaseNodeData): + variable: Sequence[str] = Field(default_factory=list) + filter_by: FilterBy + order_by: OrderBy + limit: Limit diff --git a/api/core/workflow/nodes/list_operator/exc.py b/api/core/workflow/nodes/list_operator/exc.py new file mode 100644 index 00000000000000..f88aa0be29c92a --- /dev/null +++ b/api/core/workflow/nodes/list_operator/exc.py @@ -0,0 +1,16 @@ +class ListOperatorError(ValueError): + """Base class for all ListOperator errors.""" + + pass + + +class InvalidFilterValueError(ListOperatorError): + pass + + +class InvalidKeyError(ListOperatorError): + pass + + +class InvalidConditionError(ListOperatorError): + pass diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py new file mode 100644 index 00000000000000..49e7ca85fd5fc8 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/node.py @@ -0,0 +1,298 @@ +from collections.abc import Callable, 
Sequence +from typing import Literal, Union + +from core.file import File +from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ListOperatorNodeData +from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError + + +class ListOperatorNode(BaseNode[ListOperatorNodeData]): + _node_data_cls = ListOperatorNodeData + _node_type = NodeType.LIST_OPERATOR + + def _run(self): + inputs = {} + process_data = {} + outputs = {} + + variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) + if variable is None: + error_message = f"Variable not found for selector: {self.node_data.variable}" + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs + ) + if not variable.value: + inputs = {"variable": []} + process_data = {"variable": []} + outputs = {"result": [], "first_record": None, "last_record": None} + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): + error_message = ( + f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " + "or ArrayStringSegment" + ) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs + ) + + if isinstance(variable, ArrayFileSegment): + inputs = {"variable": [item.to_dict() for item in variable.value]} + process_data["variable"] = [item.to_dict() for item in variable.value] + else: + inputs = {"variable": variable.value} + process_data["variable"] = variable.value + + try: + # Filter + if self.node_data.filter_by.enabled: + variable = self._apply_filter(variable) + + # Order + if self.node_data.order_by.enabled: + variable = self._apply_order(variable) + + # Slice + if self.node_data.limit.enabled: + variable = self._apply_slice(variable) + + outputs = { + "result": variable.value, + "first_record": variable.value[0] if variable.value else None, + "last_record": variable.value[-1] if variable.value else None, + } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + except ListOperatorError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + + def _apply_filter( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + for condition in self.node_data.filter_by.conditions: + if isinstance(variable, ArrayStringSegment): + if not isinstance(condition.value, str): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + if not isinstance(condition.value, str): + raise 
InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + if isinstance(condition.value, str): + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + else: + value = condition.value + filter_func = _get_file_filter_func( + key=condition.key, + condition=condition.comparison_operator, + value=value, + ) + result = list(filter(filter_func, variable.value)) + variable = variable.model_copy(update={"value": result}) + return variable + + def _apply_order( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + if isinstance(variable, ArrayStringSegment): + result = _order_string(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + result = _order_number(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + result = _order_file( + order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + ) + variable = variable.model_copy(update={"value": result}) + return variable + + def _apply_slice( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + result = variable.value[: self.node_data.limit.size] + return variable.model_copy(update={"value": result}) + + +def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: + match key: + case "size": + return lambda x: x.size + case _: + raise InvalidKeyError(f"Invalid key: {key}") + + +def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: + match key: + case "name": + return lambda x: x.filename or "" + case "type": + return lambda x: x.type + case "extension": + return lambda x: x.extension or "" + case "mime_type": + return lambda x: x.mime_type or "" + case "transfer_method": + return lambda x: x.transfer_method + case "url": + return lambda x: x.remote_url or "" + case _: + raise InvalidKeyError(f"Invalid key: {key}") + + +def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: + match condition: + case "contains": + return _contains(value) + case "start with": + return _startswith(value) + case "end with": + return _endswith(value) + case "is": + return _is(value) + case "in": + return _in(value) + case "empty": + return lambda x: x == "" + case "not contains": + return lambda x: not _contains(value)(x) + case "is not": + return lambda x: not _is(value)(x) + case "not in": + return lambda x: not _in(value)(x) + case "not empty": + return lambda x: x != "" + case _: + raise InvalidConditionError(f"Invalid condition: {condition}") + + +def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: + match condition: + case "in": + return _in(value) + case "not in": + return lambda x: not _in(value)(x) + case _: + raise InvalidConditionError(f"Invalid condition: {condition}") + + +def _get_number_filter_func(*, 
condition: str, value: int | float) -> Callable[[int | float], bool]: + match condition: + case "=": + return _eq(value) + case "≠": + return _ne(value) + case "<": + return _lt(value) + case "≤": + return _le(value) + case ">": + return _gt(value) + case "≥": + return _ge(value) + case _: + raise InvalidConditionError(f"Invalid condition: {condition}") + + +def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: + if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str): + extract_func = _get_file_extract_string_func(key=key) + return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) + if key in {"type", "transfer_method"} and isinstance(value, Sequence): + extract_func = _get_file_extract_string_func(key=key) + return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) + elif key == "size" and isinstance(value, str): + extract_func = _get_file_extract_number_func(key=key) + return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) + else: + raise InvalidKeyError(f"Invalid key: {key}") + + +def _contains(value: str): + return lambda x: value in x + + +def _startswith(value: str): + return lambda x: x.startswith(value) + + +def _endswith(value: str): + return lambda x: x.endswith(value) + + +def _is(value: str): + return lambda x: x is value + + +def _in(value: str | Sequence[str]): + return lambda x: x in value + + +def _eq(value: int | float): + return lambda x: x == value + + +def _ne(value: int | float): + return lambda x: x != value + + +def _lt(value: int | float): + return lambda x: x < value + + +def _le(value: int | float): + return lambda x: x <= value + + +def _gt(value: int | float): + return lambda x: x > value + + +def _ge(value: int | float): + return lambda x: x >= value + + +def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]): + return sorted(array, key=lambda x: x, reverse=order == "desc") + + +def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]): + return sorted(array, key=lambda x: x, reverse=order == "desc") + + +def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]): + if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}: + extract_func = _get_file_extract_string_func(key=order_by) + return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + elif order_by == "size": + extract_func = _get_file_extract_number_func(key=order_by) + return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") + else: + raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/core/workflow/nodes/llm/__init__.py index e69de29bb2d1d6..f7bc713f63174e 100644 --- a/api/core/workflow/nodes/llm/__init__.py +++ b/api/core/workflow/nodes/llm/__init__.py @@ -0,0 +1,17 @@ +from .entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, + VisionConfig, +) +from .node import LLMNode + +__all__ = [ + "LLMNode", + "LLMNodeChatModelMessage", + "LLMNodeCompletionModelPromptTemplate", + "LLMNodeData", + "ModelConfig", + "VisionConfig", +] diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 1e48a10bc77012..a25d563fe0b809 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ 
b/api/core/workflow/nodes/llm/entities.py @@ -1,16 +1,15 @@ -from typing import Any, Literal, Optional, Union +from collections.abc import Sequence +from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator +from core.model_runtime.entities import ImagePromptMessageContent from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class ModelConfig(BaseModel): - """ - Model Config. - """ provider: str name: str mode: str @@ -18,51 +17,43 @@ class ModelConfig(BaseModel): class ContextConfig(BaseModel): - """ - Context Config. - """ enabled: bool variable_selector: Optional[list[str]] = None +class VisionConfigOptions(BaseModel): + variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) + detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH + + class VisionConfig(BaseModel): - """ - Vision Config. - """ - class Configs(BaseModel): - """ - Configs. - """ - detail: Literal['low', 'high'] + enabled: bool = False + configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) + + @field_validator("configs", mode="before") + @classmethod + def convert_none_configs(cls, v: Any): + if v is None: + return VisionConfigOptions() + return v - enabled: bool - configs: Optional[Configs] = None class PromptConfig(BaseModel): - """ - Prompt Config. - """ jinja2_variables: Optional[list[VariableSelector]] = None + class LLMNodeChatModelMessage(ChatModelMessage): - """ - LLM Node Chat Model Message. - """ jinja2_text: Optional[str] = None + class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - """ - LLM Node Chat Model Prompt Template. - """ jinja2_text: Optional[str] = None + class LLMNodeData(BaseNodeData): - """ - LLM Node Data. 
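In the reworked `entities.py`, `VisionConfig.configs` pairs a `default_factory` with a `mode="before"` validator so workflows persisted with `"configs": null` still deserialize to usable defaults instead of failing validation. A reduced sketch of the idiom (pydantic v2, simplified field types):

```python
from typing import Any

from pydantic import BaseModel, Field, field_validator


class VisionOptions(BaseModel):
    detail: str = "high"


class VisionSettings(BaseModel):
    enabled: bool = False
    configs: VisionOptions = Field(default_factory=VisionOptions)

    @field_validator("configs", mode="before")
    @classmethod
    def convert_none_configs(cls, v: Any):
        # A bare null in stored JSON coerces to the default options object.
        return VisionOptions() if v is None else v


assert VisionSettings.model_validate({"enabled": True, "configs": None}).configs.detail == "high"
```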
- """ model: ModelConfig - prompt_template: Union[list[LLMNodeChatModelMessage], LLMNodeCompletionModelPromptTemplate] + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: Optional[PromptConfig] = None memory: Optional[MemoryConfig] = None context: ContextConfig - vision: VisionConfig + vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py new file mode 100644 index 00000000000000..f858be25156951 --- /dev/null +++ b/api/core/workflow/nodes/llm/exc.py @@ -0,0 +1,26 @@ +class LLMNodeError(ValueError): + """Base class for LLM Node errors.""" + + +class VariableNotFoundError(LLMNodeError): + """Raised when a required variable is not found.""" + + +class InvalidContextStructureError(LLMNodeError): + """Raised when the context structure is invalid.""" + + +class InvalidVariableTypeError(LLMNodeError): + """Raised when the variable type is invalid.""" + + +class ModelNotExistError(LLMNodeError): + """Raised when the specified model does not exist.""" + + +class LLMModeRequiredError(LLMNodeError): + """Raised when LLM mode is required but not provided.""" + + +class NoPromptFoundError(LLMNodeError): + """Raised when no prompt is found in the LLM configuration.""" diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py deleted file mode 100644 index c20e0d45062f51..00000000000000 --- a/api/core/workflow/nodes/llm/llm_node.py +++ /dev/null @@ -1,733 +0,0 @@ -import json -from collections.abc import Generator -from copy import deepcopy -from typing import TYPE_CHECKING, Optional, cast - -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.app.entities.queue_entities import QueueRetrieverResourcesEvent -from core.entities.model_entities import ModelStatus -from core.entities.provider_entities import QuotaUnit -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, -) -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.llm.entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, - ModelConfig, -) -from core.workflow.utils.variable_template_parser import VariableTemplateParser -from extensions.ext_database import db -from models.model import Conversation -from models.provider import Provider, ProviderType -from models.workflow import WorkflowNodeExecutionStatus - -if TYPE_CHECKING: - 
from core.file.file_obj import FileVar - - - -class LLMNode(BaseNode): - _node_data_cls = LLMNodeData - _node_type = NodeType.LLM - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - """ - Run node - :param variable_pool: variable pool - :return: - """ - node_data = cast(LLMNodeData, deepcopy(self.node_data)) - - node_inputs = None - process_data = None - - try: - # init messages template - node_data.prompt_template = self._transform_chat_messages(node_data.prompt_template) - - # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data, variable_pool) - - # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data, variable_pool) - - # merge inputs - inputs.update(jinja_inputs) - - node_inputs = {} - - # fetch files - files = self._fetch_files(node_data, variable_pool) - - if files: - node_inputs['#files#'] = [file.to_dict() for file in files] - - # fetch context value - context = self._fetch_context(node_data, variable_pool) - - if context: - node_inputs['#context#'] = context - - # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.model) - - # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) - - # fetch prompt messages - prompt_messages, stop = self._fetch_prompt_messages( - node_data=node_data, - query=variable_pool.get_any(['sys', SystemVariable.QUERY.value]) - if node_data.memory else None, - query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None, - inputs=inputs, - files=files, - context=context, - memory=memory, - model_config=model_config - ) - - process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages - ) - } - - # handle invoke result - result_text, usage = self._invoke_llm( - node_data_model=node_data.model, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop - ) - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data - ) - - outputs = { - 'text': result_text, - 'usage': jsonable_encoder(usage) - } - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } - ) - - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: list[str]) -> tuple[str, LLMUsage]: - """ - Invoke large language model - :param node_data_model: node data model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: - """ - db.session.close() - - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=node_data_model.completion_params, - stop=stop, - stream=True, - user=self.user_id, - ) - - # handle invoke result - text, usage = self._handle_invoke_result( - invoke_result=invoke_result - ) - - # deduct quota - self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - - return text, usage - - def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: - """ - Handle invoke result - :param 
invoke_result: invoke result - :return: - """ - model = None - prompt_messages = [] - full_text = '' - usage = None - for result in invoke_result: - text = result.delta.message.content - full_text += text - - self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text']) - - if not model: - model = result.model - - if not prompt_messages: - prompt_messages = result.prompt_messages - - if not usage and result.delta.usage: - usage = result.delta.usage - - if not usage: - usage = LLMUsage.empty_usage() - - return full_text, usage - - def _transform_chat_messages(self, - messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - """ - Transform chat messages - - :param messages: chat messages - :return: - """ - - if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == 'jinja2': - messages.text = messages.jinja2_text - - return messages - - for message in messages: - if message.edition_type == 'jinja2': - message.text = message.jinja2_text - - return messages - - def _fetch_jinja_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: - """ - Fetch jinja inputs - :param node_data: node data - :param variable_pool: variable pool - :return: - """ - variables = {} - - if not node_data.prompt_config: - return variables - - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable = variable_selector.variable - value = variable_pool.get_any( - variable_selector.value_selector - ) - - def parse_dict(d: dict) -> str: - """ - Parse dict into string - """ - # check if it's a context structure - if 'metadata' in d and '_source' in d['metadata'] and 'content' in d: - return d['content'] - - # else, parse the dict - try: - return json.dumps(d, ensure_ascii=False) - except Exception: - return str(d) - - if isinstance(value, str): - value = value - elif isinstance(value, list): - result = '' - for item in value: - if isinstance(item, dict): - result += parse_dict(item) - elif isinstance(item, str): - result += item - elif isinstance(item, int | float): - result += str(item) - else: - result += str(item) - result += '\n' - value = result.strip() - elif isinstance(value, dict): - value = parse_dict(value) - elif isinstance(value, int | float): - value = str(value) - else: - value = str(value) - - variables[variable] = value - - return variables - - def _fetch_inputs(self, node_data: LLMNodeData, variable_pool: VariablePool) -> dict[str, str]: - """ - Fetch inputs - :param node_data: node data - :param variable_pool: variable pool - :return: - """ - inputs = {} - prompt_template = node_data.prompt_template - - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, CompletionModelPromptTemplate): - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - - for variable_selector in variable_selectors: - variable_value = variable_pool.get_any(variable_selector.value_selector) - if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') - - inputs[variable_selector.variable] = variable_value - - memory = node_data.memory - if memory and 
memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) - for variable_selector in query_variable_selectors: - variable_value = variable_pool.get_any(variable_selector.value_selector) - if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') - - inputs[variable_selector.variable] = variable_value - - return inputs - - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: - """ - Fetch files - :param node_data: node data - :param variable_pool: variable pool - :return: - """ - if not node_data.vision.enabled: - return [] - - files = variable_pool.get_any(['sys', SystemVariable.FILES.value]) - if not files: - return [] - - return files - - def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]: - """ - Fetch context - :param node_data: node data - :param variable_pool: variable pool - :return: - """ - if not node_data.context.enabled: - return None - - if not node_data.context.variable_selector: - return None - - context_value = variable_pool.get_any(node_data.context.variable_selector) - if context_value: - if isinstance(context_value, str): - return context_value - elif isinstance(context_value, list): - context_str = '' - original_retriever_resource = [] - for item in context_value: - if isinstance(item, str): - context_str += item + '\n' - else: - if 'content' not in item: - raise ValueError(f'Invalid context structure: {item}') - - context_str += item['content'] + '\n' - - retriever_resource = self._convert_to_original_retriever_resource(item) - if retriever_resource: - original_retriever_resource.append(retriever_resource) - - if self.callbacks and original_retriever_resource: - for callback in self.callbacks: - callback.on_event( - event=QueueRetrieverResourcesEvent( - retriever_resources=original_retriever_resource - ) - ) - - return context_str.strip() - - return None - - def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: - """ - Convert to original retriever resource, temp. 
- :param context_dict: context dict - :return: - """ - if ('metadata' in context_dict and '_source' in context_dict['metadata'] - and context_dict['metadata']['_source'] == 'knowledge'): - metadata = context_dict.get('metadata', {}) - source = { - 'position': metadata.get('position'), - 'dataset_id': metadata.get('dataset_id'), - 'dataset_name': metadata.get('dataset_name'), - 'document_id': metadata.get('document_id'), - 'document_name': metadata.get('document_name'), - 'data_source_type': metadata.get('document_data_source_type'), - 'segment_id': metadata.get('segment_id'), - 'retriever_from': metadata.get('retriever_from'), - 'score': metadata.get('score'), - 'hit_count': metadata.get('segment_hit_count'), - 'word_count': metadata.get('segment_word_count'), - 'segment_position': metadata.get('segment_position'), - 'index_node_hash': metadata.get('segment_index_node_hash'), - 'content': context_dict.get('content'), - } - - return source - - return None - - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: - """ - Fetch model config - :param node_data_model: node data model - :return: - """ - model_name = node_data_model.name - provider_name = node_data_model.provider - - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, - model_type=ModelType.LLM, - provider=provider_name, - model=model_name - ) - - provider_model_bundle = model_instance.provider_model_bundle - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_credentials = model_instance.credentials - - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, - model_type=ModelType.LLM - ) - - if provider_model is None: - raise ValueError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") - - # model config - completion_params = node_data_model.completion_params - stop = [] - if 'stop' in completion_params: - stop = completion_params['stop'] - del completion_params['stop'] - - # get model mode - model_mode = node_data_model.mode - if not model_mode: - raise ValueError("LLM mode is required.") - - model_schema = model_type_instance.get_model_schema( - model_name, - model_credentials - ) - - if not model_schema: - raise ValueError(f"Model {model_name} not exist.") - - return model_instance, ModelConfigWithCredentialsEntity( - provider=provider_name, - model=model_name, - model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, - stop=stop, - ) - - def _fetch_memory(self, node_data_memory: Optional[MemoryConfig], - variable_pool: VariablePool, - model_instance: ModelInstance) -> Optional[TokenBufferMemory]: - """ - Fetch memory - :param node_data_memory: node data memory - :param variable_pool: variable pool - :return: - """ - if not node_data_memory: - return None - - # get conversation id - conversation_id = variable_pool.get_any(['sys', 
SystemVariable.CONVERSATION_ID.value]) - if conversation_id is None: - return None - - # get conversation - conversation = db.session.query(Conversation).filter( - Conversation.app_id == self.app_id, - Conversation.id == conversation_id - ).first() - - if not conversation: - return None - - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) - - return memory - - def _fetch_prompt_messages(self, node_data: LLMNodeData, - query: Optional[str], - query_prompt_template: Optional[str], - inputs: dict[str, str], - files: list["FileVar"], - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: - """ - Fetch prompt messages - :param node_data: node data - :param query: query - :param query_prompt_template: query prompt template - :param inputs: inputs - :param files: files - :param context: context - :param memory: memory - :param model_config: model config - :return: - """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - prompt_messages = prompt_transform.get_prompt( - prompt_template=node_data.prompt_template, - inputs=inputs, - query=query if query else '', - files=files, - context=context, - memory_config=node_data.memory, - memory=memory, - model_config=model_config, - query_prompt_template=query_prompt_template, - ) - stop = model_config.stop - - vision_enabled = node_data.vision.enabled - vision_detail = node_data.vision.configs.detail if node_data.vision.configs else None - filtered_prompt_messages = [] - for prompt_message in prompt_messages: - if prompt_message.is_empty(): - continue - - if not isinstance(prompt_message.content, str): - prompt_message_content = [] - for content_item in prompt_message.content: - if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent): - # Override vision config if LLM node has vision config - if vision_detail: - content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail) - prompt_message_content.append(content_item) - elif content_item.type == PromptMessageContentType.TEXT: - prompt_message_content.append(content_item) - - if len(prompt_message_content) > 1: - prompt_message.content = prompt_message_content - elif (len(prompt_message_content) == 1 - and prompt_message_content[0].type == PromptMessageContentType.TEXT): - prompt_message.content = prompt_message_content[0].data - - filtered_prompt_messages.append(prompt_message) - - if not filtered_prompt_messages: - raise ValueError("No prompt found in the LLM configuration. 
" - "Please ensure a prompt is properly configured before proceeding.") - - return filtered_prompt_messages, stop - - @classmethod - def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: - """ - Deduct LLM quota - :param tenant_id: tenant id - :param model_instance: model instance - :param usage: usage - :return: - """ - provider_model_bundle = model_instance.provider_model_bundle - provider_configuration = provider_model_bundle.configuration - - if provider_configuration.using_provider_type != ProviderType.SYSTEM: - return - - system_configuration = provider_configuration.system_configuration - - quota_unit = None - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.quota_type == system_configuration.current_quota_type: - quota_unit = quota_configuration.quota_unit - - if quota_configuration.quota_limit == -1: - return - - break - - used_quota = None - if quota_unit: - if quota_unit == QuotaUnit.TOKENS: - used_quota = usage.total_tokens - elif quota_unit == QuotaUnit.CREDITS: - used_quota = 1 - - if 'gpt-4' in model_instance.model: - used_quota = 20 - else: - used_quota = 1 - - if used_quota is not None: - db.session.query(Provider).filter( - Provider.tenant_id == tenant_id, - Provider.provider_name == model_instance.provider, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) - db.session.commit() - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]: - """ - Extract variable selector to variable mapping - :param node_data: node data - :return: - """ - - prompt_template = node_data.prompt_template - - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - if prompt.edition_type != 'jinja2': - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - else: - if prompt_template.edition_type != 'jinja2': - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = (VariableTemplateParser(template=memory.query_prompt_template) - .extract_variable_selectors()) - for variable_selector in query_variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - if node_data.context.enabled: - variable_mapping['#context#'] = node_data.context.variable_selector - - if node_data.vision.enabled: - variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value] - - if node_data.memory: - variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value] - - if node_data.prompt_config: - enable_jinja = False - - if isinstance(prompt_template, list): - for prompt in prompt_template: - if prompt.edition_type == 'jinja2': - enable_jinja = True - break - else: - if prompt_template.edition_type == 'jinja2': - enable_jinja = True - - if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: - 
variable_mapping[variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ - return { - "type": "llm", - "config": { - "prompt_templates": { - "chat_model": { - "prompts": [ - { - "role": "system", - "text": "You are a helpful AI assistant.", - "edition_type": "basic" - } - ] - }, - "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, - "prompt": { - "text": "Here is the chat histories between human and assistant, inside " - "<histories></histories> XML tags.\n\n<histories>\n{{" - "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic" - }, - "stop": ["Human:"] - } - } - } - } diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py new file mode 100644 index 00000000000000..eb4d1c9d87aa6a --- /dev/null +++ b/api/core/workflow/nodes/llm/node.py @@ -0,0 +1,717 @@ +import json +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, cast + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.model_entities import ModelStatus +from core.entities.provider_entities import QuotaUnit +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities import ( + AudioPromptMessageContent, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + TextPromptMessageContent, + VideoPromptMessageContent, +) +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils.encoders import jsonable_encoder +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.variables import ( + ArrayAnySegment, + ArrayFileSegment, + ArraySegment, + FileSegment, + NoneSegment, + ObjectSegment, + StringSegment, +) +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import ( + ModelInvokeCompletedEvent, + NodeEvent, + RunCompletedEvent, + RunRetrieverResourceEvent, + RunStreamChunkEvent, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from extensions.ext_database import db +from models.model import Conversation +from models.provider import Provider, ProviderType +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ( + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + LLMNodeData, + ModelConfig, +) +from .exc import ( + InvalidContextStructureError, + InvalidVariableTypeError, +
LLMModeRequiredError, + LLMNodeError, + ModelNotExistError, + NoPromptFoundError, + VariableNotFoundError, +) + +if TYPE_CHECKING: + from core.file.models import File + + +class LLMNode(BaseNode[LLMNodeData]): + _node_data_cls = LLMNodeData + _node_type = NodeType.LLM + + def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None]: + node_inputs = None + process_data = None + + try: + # init messages template + self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) + + # fetch variables and fetch values from variable pool + inputs = self._fetch_inputs(node_data=self.node_data) + + # fetch jinja2 inputs + jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) + + # merge inputs + inputs.update(jinja_inputs) + + node_inputs = {} + + # fetch files + files = ( + self._fetch_files(selector=self.node_data.vision.configs.variable_selector) + if self.node_data.vision.enabled + else [] + ) + + if files: + node_inputs["#files#"] = [file.to_dict() for file in files] + + # fetch context value + generator = self._fetch_context(node_data=self.node_data) + context = None + for event in generator: + if isinstance(event, RunRetrieverResourceEvent): + context = event.context + yield event + + if context: + node_inputs["#context#"] = context + + # fetch model config + model_instance, model_config = self._fetch_model_config(self.node_data.model) + + # fetch memory + memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) + + # fetch prompt messages + if self.node_data.memory: + query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) + if not query: + raise VariableNotFoundError("Query not found") + query = query.text + else: + query = None + + prompt_messages, stop = self._fetch_prompt_messages( + system_query=query, + inputs=inputs, + files=files, + context=context, + memory=memory, + model_config=model_config, + prompt_template=self.node_data.prompt_template, + memory_config=self.node_data.memory, + vision_enabled=self.node_data.vision.enabled, + vision_detail=self.node_data.vision.configs.detail, + ) + + process_data = { + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages + ), + "model_provider": model_config.provider, + "model_name": model_config.model, + } + + # handle invoke result + generator = self._invoke_llm( + node_data_model=self.node_data.model, + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=stop, + ) + + result_text = "" + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, RunStreamChunkEvent): + yield event + elif isinstance(event, ModelInvokeCompletedEvent): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + except LLMNodeError as e: + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data, + ) + ) + return + + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=node_inputs, + process_data=process_data, + outputs=outputs, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, + NodeRunMetadataKey.TOTAL_PRICE: 
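Note the structural change from the deleted `llm_node.py`: `_run` no longer returns a single `NodeRunResult` but yields events, so stream chunks and the final result travel through one generator and the engine can forward text while the node is still running. A schematic consumer, with the event classes reduced to stand-ins:

```python
from collections.abc import Generator
from dataclasses import dataclass


@dataclass
class RunStreamChunkEvent:
    chunk_content: str


@dataclass
class RunCompletedEvent:
    run_result: str


def node_run() -> Generator[object, None, None]:
    # Stand-in for LLMNode._run: stream first, then report completion.
    for token in ("Hello", ", ", "world"):
        yield RunStreamChunkEvent(chunk_content=token)
    yield RunCompletedEvent(run_result="Hello, world")


streamed = ""
final = None
for event in node_run():
    if isinstance(event, RunStreamChunkEvent):
        streamed += event.chunk_content  # forward to the client as it arrives
    elif isinstance(event, RunCompletedEvent):
        final = event.run_result
assert streamed == final == "Hello, world"
```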
usage.total_price, + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, + ) + ) + + def _invoke_llm( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + stop: Optional[list[str]] = None, + ) -> Generator[NodeEvent, None, None]: + db.session.close() + + invoke_result = model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=node_data_model.completion_params, + stop=stop, + stream=True, + user=self.user_id, + ) + + # handle invoke result + generator = self._handle_invoke_result(invoke_result=invoke_result) + + usage = LLMUsage.empty_usage() + for event in generator: + yield event + if isinstance(event, ModelInvokeCompletedEvent): + usage = event.usage + + # deduct quota + self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + + def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: + if isinstance(invoke_result, LLMResult): + return + + model = None + prompt_messages: list[PromptMessage] = [] + full_text = "" + usage = None + finish_reason = None + for result in invoke_result: + text = result.delta.message.content + full_text += text + + yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) + + if not model: + model = result.model + + if not prompt_messages: + prompt_messages = result.prompt_messages + + if not usage and result.delta.usage: + usage = result.delta.usage + + if not finish_reason and result.delta.finish_reason: + finish_reason = result.delta.finish_reason + + if not usage: + usage = LLMUsage.empty_usage() + + yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason) + + def _transform_chat_messages( + self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / + ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: + if isinstance(messages, LLMNodeCompletionModelPromptTemplate): + if messages.edition_type == "jinja2" and messages.jinja2_text: + messages.text = messages.jinja2_text + + return messages + + for message in messages: + if message.edition_type == "jinja2" and message.jinja2_text: + message.text = message.jinja2_text + + return messages + + def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: + variables = {} + + if not node_data.prompt_config: + return variables + + for variable_selector in node_data.prompt_config.jinja2_variables or []: + variable_name = variable_selector.variable + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if variable is None: + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") + + def parse_dict(input_dict: Mapping[str, Any]) -> str: + """ + Parse dict into string + """ + # check if it's a context structure + if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: + return input_dict["content"] + + # else, parse the dict + try: + return json.dumps(input_dict, ensure_ascii=False) + except Exception: + return str(input_dict) + + if isinstance(variable, ArraySegment): + result = "" + for item in variable.value: + if isinstance(item, dict): + result += parse_dict(item) + else: + result += str(item) + result += "\n" + value = result.strip() + elif isinstance(variable, ObjectSegment): + value = parse_dict(variable.value) + else: + value = variable.text + + variables[variable_name] = value + + 
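`_fetch_jinja_inputs` flattens whatever segment a selector resolves to into a string for Jinja rendering: dicts shaped like knowledge-retrieval context (`metadata._source` plus `content`) collapse to their text, and any other dict is JSON-dumped. The branch in isolation:

```python
import json
from collections.abc import Mapping
from typing import Any


def parse_dict(input_dict: Mapping[str, Any]) -> str:
    # Knowledge-context structures collapse to their text content.
    if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
        return input_dict["content"]
    try:
        return json.dumps(input_dict, ensure_ascii=False)
    except Exception:
        return str(input_dict)


chunk = {"content": "Paris is the capital.", "metadata": {"_source": "knowledge"}}
assert parse_dict(chunk) == "Paris is the capital."
assert parse_dict({"k": 1}) == '{"k": 1}'
```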
return variables + + def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]: + inputs = {} + prompt_template = node_data.prompt_template + + variable_selectors = [] + if isinstance(prompt_template, list): + for prompt in prompt_template: + variable_template_parser = VariableTemplateParser(template=prompt.text) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + elif isinstance(prompt_template, CompletionModelPromptTemplate): + variable_template_parser = VariableTemplateParser(template=prompt_template.text) + variable_selectors = variable_template_parser.extract_variable_selectors() + + for variable_selector in variable_selectors: + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if variable is None: + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") + if isinstance(variable, NoneSegment): + inputs[variable_selector.variable] = "" + continue + inputs[variable_selector.variable] = variable.to_object() + + memory = node_data.memory + if memory and memory.query_prompt_template: + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() + for variable_selector in query_variable_selectors: + variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + if variable is None: + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") + if isinstance(variable, NoneSegment): + continue + inputs[variable_selector.variable] = variable.to_object() + + return inputs + + def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]: + variable = self.graph_runtime_state.variable_pool.get(selector) + if variable is None: + return [] + elif isinstance(variable, FileSegment): + return [variable.value] + elif isinstance(variable, ArrayFileSegment): + return variable.value + elif isinstance(variable, NoneSegment | ArrayAnySegment): + return [] + raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") + + def _fetch_context(self, node_data: LLMNodeData): + if not node_data.context.enabled: + return + + if not node_data.context.variable_selector: + return + + context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) + if context_value_variable: + if isinstance(context_value_variable, StringSegment): + yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) + elif isinstance(context_value_variable, ArraySegment): + context_str = "" + original_retriever_resource = [] + for item in context_value_variable.value: + if isinstance(item, str): + context_str += item + "\n" + else: + if "content" not in item: + raise InvalidContextStructureError(f"Invalid context structure: {item}") + + context_str += item["content"] + "\n" + + retriever_resource = self._convert_to_original_retriever_resource(item) + if retriever_resource: + original_retriever_resource.append(retriever_resource) + + yield RunRetrieverResourceEvent( + retriever_resources=original_retriever_resource, context=context_str.strip() + ) + + def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: + if ( + "metadata" in context_dict + and "_source" in context_dict["metadata"] + and context_dict["metadata"]["_source"] == "knowledge" + ): + metadata = context_dict.get("metadata", {}) + + source = { + "position": metadata.get("position"), + "dataset_id": metadata.get("dataset_id"), + "dataset_name":
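`_fetch_files` normalizes every segment type a file selector may resolve to into a flat list: a `FileSegment` becomes a one-element list, an `ArrayFileSegment` passes through, `None`/`NoneSegment`/`ArrayAnySegment` become empty, and anything else raises. The same dispatch with stand-in segment classes:

```python
from dataclasses import dataclass


@dataclass
class FileSegment:
    value: str  # stand-in for a File object


@dataclass
class ArrayFileSegment:
    value: list[str]


class NoneSegment: ...


def fetch_files(variable: object) -> list[str]:
    # Mirrors LLMNode._fetch_files' normalization to "a list of files".
    if variable is None or isinstance(variable, NoneSegment):
        return []
    if isinstance(variable, FileSegment):
        return [variable.value]
    if isinstance(variable, ArrayFileSegment):
        return variable.value
    raise ValueError(f"Invalid variable type: {type(variable)}")


assert fetch_files(FileSegment("a.png")) == ["a.png"]
assert fetch_files(NoneSegment()) == []
```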
metadata.get("dataset_name"), + "document_id": metadata.get("document_id"), + "document_name": metadata.get("document_name"), + "data_source_type": metadata.get("document_data_source_type"), + "segment_id": metadata.get("segment_id"), + "retriever_from": metadata.get("retriever_from"), + "score": metadata.get("score"), + "hit_count": metadata.get("segment_hit_count"), + "word_count": metadata.get("segment_word_count"), + "segment_position": metadata.get("segment_position"), + "index_node_hash": metadata.get("segment_index_node_hash"), + "content": context_dict.get("content"), + "page": metadata.get("page"), + } + + return source + + return None + + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + model_name = node_data_model.name + provider_name = node_data_model.provider + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_name, model_type=ModelType.LLM + ) + + if provider_model is None: + raise ModelNotExistError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + + # model config + completion_params = node_data_model.completion_params + stop = [] + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] + + # get model mode + model_mode = node_data_model.mode + if not model_mode: + raise LLMModeRequiredError("LLM mode is required.") + + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + + if not model_schema: + raise ModelNotExistError(f"Model {model_name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=provider_name, + model=model_name, + model_schema=model_schema, + mode=model_mode, + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + + def _fetch_memory( + self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance + ) -> Optional[TokenBufferMemory]: + if not node_data_memory: + return None + + # get conversation id + conversation_id_variable = self.graph_runtime_state.variable_pool.get( + ["sys", SystemVariableKey.CONVERSATION_ID.value] + ) + if not isinstance(conversation_id_variable, StringSegment): + return None + conversation_id = conversation_id_variable.value + + # get conversation + conversation = ( + db.session.query(Conversation) + .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + .first() + ) + + if not conversation: + return None + + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + return memory + + def 
_fetch_prompt_messages( + self, + *, + system_query: str | None = None, + inputs: dict[str, str] | None = None, + files: Sequence["File"], + context: str | None = None, + memory: TokenBufferMemory | None = None, + model_config: ModelConfigWithCredentialsEntity, + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + memory_config: MemoryConfig | None = None, + vision_enabled: bool = False, + vision_detail: ImagePromptMessageContent.DETAIL, + ) -> tuple[list[PromptMessage], Optional[list[str]]]: + inputs = inputs or {} + + prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + prompt_messages = prompt_transform.get_prompt( + prompt_template=prompt_template, + inputs=inputs, + query=system_query or "", + files=files, + context=context, + memory_config=memory_config, + memory=memory, + model_config=model_config, + ) + stop = model_config.stop + filtered_prompt_messages = [] + for prompt_message in prompt_messages: + if prompt_message.is_empty(): + continue + + if not isinstance(prompt_message.content, str): + prompt_message_content = [] + for content_item in prompt_message.content or []: + # Skip image if vision is disabled + if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + continue + + if isinstance(content_item, ImagePromptMessageContent): + # Override vision config if LLM node has vision config, + # cuz vision detail is related to the configuration from FileUpload feature. + content_item.detail = vision_detail + prompt_message_content.append(content_item) + elif isinstance( + content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent + ): + prompt_message_content.append(content_item) + + if len(prompt_message_content) > 1: + prompt_message.content = prompt_message_content + elif ( + len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT + ): + prompt_message.content = prompt_message_content[0].data + + filtered_prompt_messages.append(prompt_message) + + if not filtered_prompt_messages: + raise NoPromptFoundError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." 
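`_fetch_prompt_messages` now filters multimodal content in one pass: images are dropped when vision is disabled, surviving images get the node-level detail stamped on them, and a content list that shrinks to a single text item collapses back to a plain string. A reduced sketch with stand-in content types:

```python
from dataclasses import dataclass


@dataclass
class TextContent:
    data: str


@dataclass
class ImageContent:
    url: str
    detail: str = "high"


def filter_content(items: list[object], *, vision_enabled: bool, vision_detail: str) -> object:
    kept: list[object] = []
    for item in items:
        if isinstance(item, ImageContent):
            if not vision_enabled:
                continue  # skip images entirely when vision is off
            item.detail = vision_detail  # node-level detail wins
        kept.append(item)
    # A lone text item collapses back to a plain string.
    if len(kept) == 1 and isinstance(kept[0], TextContent):
        return kept[0].data
    return kept


content = [TextContent("Describe this"), ImageContent("https://example.com/cat.png")]
assert filter_content(content, vision_enabled=False, vision_detail="low") == "Describe this"
```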
+ ) + + return filtered_prompt_messages, stop + + @classmethod + def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: + provider_model_bundle = model_instance.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration + + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + + if quota_configuration.quota_limit == -1: + return + + break + + used_quota = None + if quota_unit: + if quota_unit == QuotaUnit.TOKENS: + used_quota = usage.total_tokens + elif quota_unit == QuotaUnit.CREDITS: + used_quota = 1 + + if "gpt-4" in model_instance.model: + used_quota = 20 + else: + used_quota = 1 + + if used_quota is not None and system_configuration.current_quota_type is not None: + db.session.query(Provider).filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == model_instance.provider, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ).update({"quota_used": Provider.quota_used + used_quota}) + db.session.commit() + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: LLMNodeData, + ) -> Mapping[str, Sequence[str]]: + prompt_template = node_data.prompt_template + + variable_selectors = [] + if isinstance(prompt_template, list) and all( + isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template + ): + for prompt in prompt_template: + if prompt.edition_type != "jinja2": + variable_template_parser = VariableTemplateParser(template=prompt.text) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + if prompt_template.edition_type != "jinja2": + variable_template_parser = VariableTemplateParser(template=prompt_template.text) + variable_selectors = variable_template_parser.extract_variable_selectors() + else: + raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") + + variable_mapping = {} + for variable_selector in variable_selectors: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + memory = node_data.memory + if memory and memory.query_prompt_template: + query_variable_selectors = VariableTemplateParser( + template=memory.query_prompt_template + ).extract_variable_selectors() + for variable_selector in query_variable_selectors: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + if node_data.context.enabled: + variable_mapping["#context#"] = node_data.context.variable_selector + + if node_data.vision.enabled: + variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] + + if node_data.memory: + variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] + + if node_data.prompt_config: + enable_jinja = False + + if isinstance(prompt_template, list): + for prompt in prompt_template: + if prompt.edition_type == "jinja2": + enable_jinja = True + break + else: + if prompt_template.edition_type == "jinja2": + enable_jinja = True + + if enable_jinja: + 
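`deduct_llm_quota` keeps the old pricing rules: token-unit quotas meter `usage.total_tokens`, credit-unit quotas charge a flat 1 per call with a 20-credit rate for gpt-4-family models, and the `quota_limit > quota_used` guard on the UPDATE makes the decrement conditional. The arithmetic in isolation (a sketch; unit names are illustrative strings rather than the QuotaUnit enum):

```python
def used_quota(quota_unit: str, model: str, total_tokens: int) -> int | None:
    if quota_unit == "tokens":
        return total_tokens  # metered by actual usage
    if quota_unit == "credits":
        return 20 if "gpt-4" in model else 1  # flat rate, gpt-4 surcharge
    return None  # other units deduct nothing here


assert used_quota("tokens", "gpt-3.5-turbo", 1234) == 1234
assert used_quota("credits", "gpt-4o", 1234) == 20
assert used_quota("credits", "claude-3-haiku", 1234) == 1
```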
for variable_selector in node_data.prompt_config.jinja2_variables or []: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + + return variable_mapping + + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + return { + "type": "llm", + "config": { + "prompt_templates": { + "chat_model": { + "prompts": [ + {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} + ] + }, + "completion_model": { + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, + "prompt": { + "text": "Here is the chat histories between human and assistant, inside " + "<histories></histories> XML tags.\n\n<histories>\n{{" + "#histories#}}\n</histories>\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", + "edition_type": "basic", + }, + "stop": ["Human:"], + }, + } + }, + } diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 8a5684551ef7f3..b7cd7a948e3f8b 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,5 +1,4 @@ - -from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState +from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState class LoopNodeData(BaseIterationNodeData): @@ -7,7 +6,8 @@ class LoopNodeData(BaseIterationNodeData): Loop Node Data. """ + class LoopState(BaseIterationState): """ Loop State. - """ \ No newline at end of file + """ diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 7d53c6f5f2c32e..6fdff966026b63 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,20 +1,37 @@ -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseIterationNode +from typing import Any + +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.loop.entities import LoopNodeData, LoopState +from core.workflow.utils.condition.entities import Condition -class LoopNode(BaseIterationNode): +class LoopNode(BaseNode[LoopNodeData]): """ Loop Node. """ + _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self, variable_pool: VariablePool) -> LoopState: - return super()._run(variable_pool) + def _run(self) -> LoopState: - def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: + return super()._run() + @classmethod + def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: """ - Get next iteration start node id based on the graph. + Get conditions.
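The refactored `_extract_variable_selector_to_variable_mapping` now namespaces every key with the owning node id before returning, so mappings gathered from many nodes merge into one dict without collisions. The final step, in isolation:

```python
def prefix_mapping(node_id: str, mapping: dict[str, list[str]]) -> dict[str, list[str]]:
    # Mirrors: {node_id + "." + key: value for key, value in variable_mapping.items()}
    return {f"{node_id}.{key}": value for key, value in mapping.items()}


assert prefix_mapping("llm_1", {"#context#": ["kb", "result"]}) == {"llm_1.#context#": ["kb", "result"]}
```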
""" + node_id = node_config.get("id") + if not node_id: + return [] + + # TODO waiting for implementation + return [ + Condition( + variable_selector=[node_id, "index"], + comparison_operator="≤", + value_type="value_selector", + value_selector=[], + ) + ] diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py new file mode 100644 index 00000000000000..c13b5ff76f3d2f --- /dev/null +++ b/api/core/workflow/nodes/node_mapping.py @@ -0,0 +1,41 @@ +from core.workflow.nodes.answer import AnswerNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.code import CodeNode +from core.workflow.nodes.document_extractor import DocumentExtractorNode +from core.workflow.nodes.end import EndNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.http_request import HttpRequestNode +from core.workflow.nodes.if_else import IfElseNode +from core.workflow.nodes.iteration import IterationNode, IterationStartNode +from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode +from core.workflow.nodes.list_operator import ListOperatorNode +from core.workflow.nodes.llm import LLMNode +from core.workflow.nodes.parameter_extractor import ParameterExtractorNode +from core.workflow.nodes.question_classifier import QuestionClassifierNode +from core.workflow.nodes.start import StartNode +from core.workflow.nodes.template_transform import TemplateTransformNode +from core.workflow.nodes.tool import ToolNode +from core.workflow.nodes.variable_aggregator import VariableAggregatorNode +from core.workflow.nodes.variable_assigner import VariableAssignerNode + +node_type_classes_mapping: dict[NodeType, type[BaseNode]] = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.ANSWER: AnswerNode, + NodeType.LLM: LLMNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.IF_ELSE: IfElseNode, + NodeType.CODE: CodeNode, + NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.HTTP_REQUEST: HttpRequestNode, + NodeType.TOOL: ToolNode, + NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR + NodeType.ITERATION: IterationNode, + NodeType.ITERATION_START: IterationStartNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, + NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, + NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode, + NodeType.LIST_OPERATOR: ListOperatorNode, +} diff --git a/api/core/workflow/nodes/parameter_extractor/__init__.py b/api/core/workflow/nodes/parameter_extractor/__init__.py index e69de29bb2d1d6..bdbf19a7d36d7e 100644 --- a/api/core/workflow/nodes/parameter_extractor/__init__.py +++ b/api/core/workflow/nodes/parameter_extractor/__init__.py @@ -0,0 +1,3 @@ +from .parameter_extractor_node import ParameterExtractorNode + +__all__ = ["ParameterExtractorNode"] diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 7bb123b1267c91..a001b44dc7dfee 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -1,54 +1,50 @@ from typing import Any, Literal, Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from 
core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm import ModelConfig, VisionConfig -class ModelConfig(BaseModel): - """ - Model Config. - """ - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} - class ParameterConfig(BaseModel): """ Parameter Config. """ + name: str - type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]'] + type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"] options: Optional[list[str]] = None description: str required: bool - @field_validator('name', mode='before') + @field_validator("name", mode="before") @classmethod def validate_name(cls, value) -> str: if not value: - raise ValueError('Parameter name is required') - if value in ['__reason', '__is_success']: - raise ValueError('Invalid parameter name, __reason and __is_success are reserved') + raise ValueError("Parameter name is required") + if value in {"__reason", "__is_success"}: + raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return value + class ParameterExtractorNodeData(BaseNodeData): """ Parameter Extractor Node Data. """ + model: ModelConfig query: list[str] parameters: list[ParameterConfig] instruction: Optional[str] = None memory: Optional[MemoryConfig] = None - reasoning_mode: Literal['function_call', 'prompt'] + reasoning_mode: Literal["function_call", "prompt"] + vision: VisionConfig = Field(default_factory=VisionConfig) - @field_validator('reasoning_mode', mode='before') + @field_validator("reasoning_mode", mode="before") @classmethod def set_reasoning_mode(cls, v) -> str: - return v or 'function_call' + return v or "function_call" def get_parameter_json_schema(self) -> dict: """ @@ -56,32 +52,26 @@ def get_parameter_json_schema(self) -> dict: :return: parameter json schema """ - parameters = { - 'type': 'object', - 'properties': {}, - 'required': [] - } + parameters = {"type": "object", "properties": {}, "required": []} for parameter in self.parameters: - parameter_schema = { - 'description': parameter.description - } - - if parameter.type in ['string', 'select']: - parameter_schema['type'] = 'string' - elif parameter.type.startswith('array'): - parameter_schema['type'] = 'array' + parameter_schema: dict[str, Any] = {"description": parameter.description} + + if parameter.type in {"string", "select"}: + parameter_schema["type"] = "string" + elif parameter.type.startswith("array"): + parameter_schema["type"] = "array" nested_type = parameter.type[6:-1] - parameter_schema['items'] = {'type': nested_type} + parameter_schema["items"] = {"type": nested_type} else: - parameter_schema['type'] = parameter.type + parameter_schema["type"] = parameter.type + + if parameter.type == "select": + parameter_schema["enum"] = parameter.options - if parameter.type == 'select': - parameter_schema['enum'] = parameter.options + parameters["properties"][parameter.name] = parameter_schema - parameters['properties'][parameter.name] = parameter_schema - if parameter.required: - parameters['required'].append(parameter.name) + parameters["required"].append(parameter.name) - return parameters \ No newline at end of file + return parameters diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py new file mode 100644 index 00000000000000..6511aba1856999 --- /dev/null +++ 
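`get_parameter_json_schema` turns the node's `ParameterConfig` list into the JSON-Schema object handed to the model as a function-call signature: `select` becomes an enum-constrained string, `array[x]` becomes an array with typed items, and only `required: true` parameters land in `required`. For hypothetical parameters, the output would look like:

```python
# Illustrative input/output pair, not taken from the repository's tests.
parameters = [
    {"name": "city", "type": "string", "description": "City name", "required": True},
    {"name": "unit", "type": "select", "options": ["C", "F"], "description": "Unit", "required": False},
    {"name": "days", "type": "array[number]", "description": "Forecast days", "required": False},
]

expected_schema = {
    "type": "object",
    "properties": {
        "city": {"description": "City name", "type": "string"},
        "unit": {"description": "Unit", "type": "string", "enum": ["C", "F"]},
        "days": {"description": "Forecast days", "type": "array", "items": {"type": "number"}},
    },
    "required": ["city"],
}
```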
b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -0,0 +1,50 @@ +class ParameterExtractorNodeError(ValueError): + """Base error for ParameterExtractorNode.""" + + +class InvalidModelTypeError(ParameterExtractorNodeError): + """Raised when the model is not a Large Language Model.""" + + +class ModelSchemaNotFoundError(ParameterExtractorNodeError): + """Raised when the model schema is not found.""" + + +class InvalidInvokeResultError(ParameterExtractorNodeError): + """Raised when the invoke result is invalid.""" + + +class InvalidTextContentTypeError(ParameterExtractorNodeError): + """Raised when the text content type is invalid.""" + + +class InvalidNumberOfParametersError(ParameterExtractorNodeError): + """Raised when the number of parameters is invalid.""" + + +class RequiredParameterMissingError(ParameterExtractorNodeError): + """Raised when a required parameter is missing.""" + + +class InvalidSelectValueError(ParameterExtractorNodeError): + """Raised when a select value is invalid.""" + + +class InvalidNumberValueError(ParameterExtractorNodeError): + """Raised when a number value is invalid.""" + + +class InvalidBoolValueError(ParameterExtractorNodeError): + """Raised when a bool value is invalid.""" + + +class InvalidStringValueError(ParameterExtractorNodeError): + """Raised when a string value is invalid.""" + + +class InvalidArrayValueError(ParameterExtractorNodeError): + """Raised when an array value is invalid.""" + + +class InvalidModelModeError(ParameterExtractorNodeError): + """Raised when the model mode is invalid.""" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2876695a825ba0..b64bde8ac5e675 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -1,8 +1,10 @@ import json import uuid -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.file import File from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -21,12 +23,31 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from core.workflow.nodes.parameter_extractor.prompts import ( +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.llm import LLMNode, ModelConfig +from core.workflow.utils import variable_template_parser +from extensions.ext_database import db +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ParameterExtractorNodeData +from .exc import ( + InvalidArrayValueError, + InvalidBoolValueError, + InvalidInvokeResultError, + 
InvalidModelModeError, + InvalidModelTypeError, + InvalidNumberOfParametersError, + InvalidNumberValueError, + InvalidSelectValueError, + InvalidStringValueError, + InvalidTextContentTypeError, + ModelSchemaNotFoundError, + ParameterExtractorNodeError, + RequiredParameterMissingError, +) +from .prompts import ( CHAT_EXAMPLE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, COMPLETION_GENERATE_JSON_PROMPT, @@ -35,15 +56,13 @@ FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, ) -from core.workflow.utils.variable_template_parser import VariableTemplateParser -from extensions.ext_database import db -from models.workflow import WorkflowNodeExecutionStatus class ParameterExtractorNode(LLMNode): """ Parameter Extractor Node. """ + _node_data_cls = ParameterExtractorNodeData _node_type = NodeType.PARAMETER_EXTRACTOR @@ -56,89 +75,109 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: "model": { "prompt_templates": { "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant" - }, - "stop": ["Human:"] + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, + "stop": ["Human:"], } } } } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self): """ Run the node. """ node_data = cast(ParameterExtractorNodeData, self.node_data) - variable = variable_pool.get_any(node_data.query) - if not variable: - raise ValueError("Input variable content not found or is empty") - query = variable + variable = self.graph_runtime_state.variable_pool.get(node_data.query) + query = variable.text if variable else "" - inputs = { - 'query': query, - 'parameters': jsonable_encoder(node_data.parameters), - 'instruction': jsonable_encoder(node_data.instruction), - } + files = ( + self._fetch_files( + selector=node_data.vision.configs.variable_selector, + ) + if node_data.vision.enabled + else [] + ) model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance - model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) + model_schema = llm_model.get_model_schema( + model=model_config.model, + credentials=model_config.credentials, + ) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory( + node_data_memory=node_data.memory, + model_instance=model_instance, + ) - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ - and node_data.reasoning_mode == 'function_call': - # use function call + if ( + set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} + and node_data.reasoning_mode == "function_call" + ): + # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data, query, variable_pool, model_config, memory + node_data=node_data, + query=query, + variable_pool=self.graph_runtime_state.variable_pool, + model_config=model_config, + memory=memory, + files=files, ) else: # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt(node_data, 
query, variable_pool, model_config, - memory) + prompt_messages = self._generate_prompt_engineering_prompt( + data=node_data, + query=query, + variable_pool=self.graph_runtime_state.variable_pool, + model_config=model_config, + memory=memory, + files=files, + ) + prompt_message_tools = [] + inputs = { + "query": query, + "files": [f.to_dict() for f in files], + "parameters": jsonable_encoder(node_data.parameters), + "instruction": jsonable_encoder(node_data.instruction), + } + process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': None, - 'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - 'tool_call': None, + "usage": None, + "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), + "tool_call": None, } try: - text, usage, tool_call = self._invoke_llm( + text, usage, tool_call = self._invoke( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, tools=prompt_message_tools, stop=model_config.stop, ) - process_data['usage'] = jsonable_encoder(usage) - process_data['tool_call'] = jsonable_encoder(tool_call) - process_data['llm_text'] = text - except Exception as e: + process_data["usage"] = jsonable_encoder(usage) + process_data["tool_call"] = jsonable_encoder(tool_call) + process_data["llm_text"] = text + except ParameterExtractorNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 0, - '__reason': str(e) - }, + outputs={"__is_success": 0, "__reason": str(e)}, error=str(e), - metadata={} + metadata={}, ) error = None @@ -152,42 +191,34 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: error = "Failed to extract result from function call or text response, using empty result." 
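# ---- Editor's sketch (illustrative, not part of the patch) ----------------------------
# Assuming a node configured with a single required parameter {"name": "age", "type": "number"},
# the validate/transform helpers defined further down behave roughly like this:
#
#     _validate_result(data=node_data, result={})             # InvalidNumberOfParametersError (0 != 1)
#     _validate_result(data=node_data, result={"age": "30"})  # InvalidNumberValueError ("30" is a str)
#     _transform_result(data=node_data, result={})            # {"age": 0}, the per-type default
#
# That is why _run() tolerates a failed validation here: it records the error message, then
# still emits a type-safe output via _transform_result(). The "age" parameter is hypothetical,
# used only for illustration.
# ----------------------------------------------------------------------------------------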
try: - result = self._validate_result(node_data, result) - except Exception as e: + result = self._validate_result(data=node_data, result=result or {}) + except ParameterExtractorNodeError as e: error = str(e) # transform result into standard format - result = self._transform_result(node_data, result) + result = self._transform_result(data=node_data, result=result or {}) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={ - '__is_success': 1 if not error else 0, - '__reason': error, - **result - }, + outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) - def _invoke_llm(self, node_data_model: ModelConfig, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: - """ - Invoke large language model - :param node_data_model: node data model - :param model_instance: model instance - :param prompt_messages: prompt messages - :param stop: stop - :return: - """ + def _invoke( + self, + node_data_model: ModelConfig, + model_instance: ModelInstance, + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + stop: list[str], + ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: db.session.close() invoke_result = model_instance.invoke_llm( @@ -201,9 +232,12 @@ def _invoke_llm(self, node_data_model: ModelConfig, # handle invoke result if not isinstance(invoke_result, LLMResult): - raise ValueError(f"Invalid invoke result: {invoke_result}") + raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}") text = invoke_result.message.content + if not isinstance(text, str): + raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") + usage = invoke_result.usage tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None @@ -212,32 +246,36 @@ def _invoke_llm(self, node_data_model: ModelConfig, return text, usage, tool_call - def _generate_function_call_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: + def _generate_function_call_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + files: Sequence[File], + ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: """ Generate function call prompt. 
""" - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps( - node_data.get_parameter_json_schema())) + query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( + content=query, structure=json.dumps(node_data.get_parameter_json_schema()) + ) prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + prompt_template = self._get_function_calling_prompt_template( + node_data, query, variable_pool, memory, rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', - files=[], - context='', + query="", + files=files, + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -250,49 +288,53 @@ def _generate_function_call_prompt(self, example_messages = [] for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: id = uuid.uuid4().hex - example_messages.extend([ - UserPromptMessage(content=example['user']['query']), - AssistantPromptMessage( - content=example['assistant']['text'], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example['assistant']['function_call']['name'], - arguments=json.dumps(example['assistant']['function_call']['parameters'] - ) - )) - ] - ), - ToolPromptMessage( - content='Great! You have called the function with the correct parameters.', - tool_call_id=id - ), - AssistantPromptMessage( - content='I have extracted the parameters, let\'s move on.', - ) - ]) + example_messages.extend( + [ + UserPromptMessage(content=example["user"]["query"]), + AssistantPromptMessage( + content=example["assistant"]["text"], + tool_calls=[ + AssistantPromptMessage.ToolCall( + id=id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=example["assistant"]["function_call"]["name"], + arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), + ), + ) + ], + ), + ToolPromptMessage( + content="Great! 
You have called the function with the correct parameters.", tool_call_id=id + ), + AssistantPromptMessage( + content="I have extracted the parameters, let's move on.", + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) # generate tool tool = PromptMessageTool( name=FUNCTION_CALLING_EXTRACTOR_NAME, - description='Extract parameters from the natural language text', + description="Extract parameters from the natural language text", parameters=node_data.get_parameter_json_schema(), ) return prompt_messages, [tool] - def _generate_prompt_engineering_prompt(self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_prompt( + self, + data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + files: Sequence[File], + ) -> list[PromptMessage]: """ Generate prompt engineering prompt. """ @@ -300,74 +342,92 @@ def _generate_prompt_engineering_prompt(self, if model_mode == ModelMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( - data, query, variable_pool, model_config, memory + node_data=data, + query=query, + variable_pool=variable_pool, + model_config=model_config, + memory=memory, + files=files, ) elif model_mode == ModelMode.CHAT: return self._generate_prompt_engineering_chat_prompt( - data, query, variable_pool, model_config, memory + node_data=data, + query=query, + variable_pool=variable_pool, + model_config=model_config, + memory=memory, + files=files, ) else: - raise ValueError(f"Invalid model mode: {model_mode}") - - def _generate_prompt_engineering_completion_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + raise InvalidModelModeError(f"Invalid model mode: {model_mode}") + + def _generate_prompt_engineering_completion_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + files: Sequence[File], + ) -> list[PromptMessage]: """ Generate completion prompt. 
""" prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, - rest_token) + rest_token = self._calculate_rest_token( + node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + ) + prompt_template = self._get_prompt_engineering_prompt_template( + node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token + ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, - inputs={ - 'structure': json.dumps(node_data.get_parameter_json_schema()) - }, - query='', - files=[], - context='', + inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, + query="", + files=files, + context="", memory_config=node_data.memory, memory=memory, - model_config=model_config + model_config=model_config, ) return prompt_messages - def _generate_prompt_engineering_chat_prompt(self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - memory: Optional[TokenBufferMemory], - ) -> list[PromptMessage]: + def _generate_prompt_engineering_chat_prompt( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + memory: Optional[TokenBufferMemory], + files: Sequence[File], + ) -> list[PromptMessage]: """ Generate chat prompt. """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '') + rest_token = self._calculate_rest_token( + node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + ) prompt_template = self._get_prompt_engineering_prompt_template( - node_data, - CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), - text=query + node_data=node_data, + query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(node_data.get_parameter_json_schema()), text=query ), - variable_pool, memory, rest_token + variable_pool=variable_pool, + memory=memory, + max_token_limit=rest_token, ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', - files=[], - context='', + query="", + files=files, + context="", memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) # find last user message @@ -379,18 +439,23 @@ def _generate_prompt_engineering_chat_prompt(self, # add example messages before last user message example_messages = [] for example in CHAT_EXAMPLE: - example_messages.extend([ - UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example['user']['json']), - text=example['user']['query'], - )), - AssistantPromptMessage( - content=json.dumps(example['assistant']['json']), - ) - ]) + example_messages.extend( + [ + UserPromptMessage( + content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( + structure=json.dumps(example["user"]["json"]), + text=example["user"]["query"], + ) + ), + AssistantPromptMessage( + content=json.dumps(example["assistant"]["json"]), + ), + ] + ) - prompt_messages = prompt_messages[:last_user_message_idx] + \ - example_messages + 
prompt_messages[last_user_message_idx:] + prompt_messages = ( + prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] + ) return prompt_messages @@ -399,35 +464,36 @@ def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> di Validate result. """ if len(data.parameters) != len(result): - raise ValueError("Invalid number of parameters") + raise InvalidNumberOfParametersError("Invalid number of parameters") for parameter in data.parameters: if parameter.required and parameter.name not in result: - raise ValueError(f"Parameter {parameter.name} is required") + raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") - if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options: - raise ValueError(f"Invalid `select` value for parameter {parameter.name}") + if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: + raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") - if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float): - raise ValueError(f"Invalid `number` value for parameter {parameter.name}") + if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): + raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}") - if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool): - raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") + if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): + raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}") - if parameter.type == 'string' and not isinstance(result.get(parameter.name), str): - raise ValueError(f"Invalid `string` value for parameter {parameter.name}") + if parameter.type == "string" and not isinstance(result.get(parameter.name), str): + raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}") - if parameter.type.startswith('array'): - if not isinstance(result.get(parameter.name), list): - raise ValueError(f"Invalid `array` value for parameter {parameter.name}") + if parameter.type.startswith("array"): + parameters = result.get(parameter.name) + if not isinstance(parameters, list): + raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] - for item in result.get(parameter.name): - if nested_type == 'number' and not isinstance(item, int | float): - raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") - if nested_type == 'string' and not isinstance(item, str): - raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") - if nested_type == 'object' and not isinstance(item, dict): - raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") + for item in parameters: + if nested_type == "number" and not isinstance(item, int | float): + raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}") + if nested_type == "string" and not isinstance(item, str): + raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}") + if nested_type == "object" and not isinstance(item, dict): + raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result def 
_transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: @@ -438,12 +504,12 @@ def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> d for parameter in data.parameters: if parameter.name in result: # transform value - if parameter.type == 'number': + if parameter.type == "number": if isinstance(result[parameter.name], int | float): transformed_result[parameter.name] = result[parameter.name] elif isinstance(result[parameter.name], str): try: - if '.' in result[parameter.name]: + if "." in result[parameter.name]: result[parameter.name] = float(result[parameter.name]) else: result[parameter.name] = int(result[parameter.name]) @@ -460,40 +526,40 @@ def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> d # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') # elif isinstance(result[parameter.name], int): # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in ['string', 'select']: + elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] - elif parameter.type.startswith('array'): + elif parameter.type.startswith("array"): if isinstance(result[parameter.name], list): nested_type = parameter.type[6:-1] transformed_result[parameter.name] = [] for item in result[parameter.name]: - if nested_type == 'number': + if nested_type == "number": if isinstance(item, int | float): transformed_result[parameter.name].append(item) elif isinstance(item, str): try: - if '.' in item: + if "." in item: transformed_result[parameter.name].append(float(item)) else: transformed_result[parameter.name].append(int(item)) except ValueError: pass - elif nested_type == 'string': + elif nested_type == "string": if isinstance(item, str): transformed_result[parameter.name].append(item) - elif nested_type == 'object': + elif nested_type == "object": if isinstance(item, dict): transformed_result[parameter.name].append(item) if parameter.name not in transformed_result: - if parameter.type == 'number': + if parameter.type == "number": transformed_result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": transformed_result[parameter.name] = False - elif parameter.type in ['string', 'select']: - transformed_result[parameter.name] = '' - elif parameter.type.startswith('array'): + elif parameter.type in {"string", "select"}: + transformed_result[parameter.name] = "" + elif parameter.type.startswith("array"): transformed_result[parameter.name] = [] return transformed_result @@ -509,24 +575,24 @@ def extract_json(text): """ stack = [] for i, c in enumerate(text): - if c == '{' or c == '[': + if c in {"{", "["}: stack.append(c) - elif c == '}' or c == ']': + elif c in {"}", "]"}: # check if stack is empty if not stack: return text[:i] # check if the last element in stack is matching - if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['): + if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): stack.pop() if not stack: - return text[:i + 1] + return text[: i + 1] else: return text[:i] return None # extract json from the text for idx in range(len(result)): - if result[idx] == '{' or result[idx] == '[': + if result[idx] == "{" or result[idx] == "[": json_str = extract_json(result[idx:]) if json_str: try: @@ -549,102 +615,95 @@ def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: """ result = {} for parameter in 
data.parameters: - if parameter.type == 'number': + if parameter.type == "number": result[parameter.name] = 0 - elif parameter.type == 'bool': + elif parameter.type == "bool": result[parameter.name] = False - elif parameter.type in ['string', 'select']: - result[parameter.name] = '' + elif parameter.type in {"string", "select"}: + result[parameter.name] = "" return result - def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str: - """ - Render instruction. - """ - variable_template_parser = VariableTemplateParser(instruction) - inputs = {} - for selector in variable_template_parser.extract_variable_selectors(): - variable = variable_pool.get_any(selector.value_selector) - inputs[selector.variable] = variable - - return variable_template_parser.format(inputs) - - def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: + def _get_function_calling_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ) -> list[ChatModelMessage]: model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = variable_pool.convert_template(node_data.instruction or "").text - if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + if memory and node_data.memory and node_data.memory.window: + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: - raise ValueError(f"Model mode {model_mode} not support.") - - def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> list[ChatModelMessage]: - + raise InvalidModelModeError(f"Model mode {model_mode} not supported.") + + def _get_prompt_engineering_prompt_template( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ): model_mode = ModelMode.value_of(node_data.model.mode) input_text = query - memory_str = '' - instruction = self._render_instruction(node_data.instruction or '', variable_pool) + memory_str = "" + instruction = variable_pool.convert_template(node_data.instruction or "").text - if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) + if memory and node_data.memory and node_data.memory.window: + memory_str = memory.get_history_prompt_text( +
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size + ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction) - ) - user_prompt_message = ChatModelMessage( - role=PromptMessageRole.USER, - text=input_text + text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) + user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] elif model_mode == ModelMode.COMPLETION: return CompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str, - text=input_text, - instruction=instruction) - .replace('{γγγ', '') - .replace('}γγγ', '') + text=COMPLETION_GENERATE_JSON_PROMPT.format( + histories=memory_str, text=input_text, instruction=instruction + ) + .replace("{γγγ", "") + .replace("}γγγ", "") ) else: - raise ValueError(f"Model mode {model_mode} not support.") - - def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str, - variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + raise InvalidModelModeError(f"Model mode {model_mode} not supported.") + + def _calculate_rest_token( + self, + node_data: ParameterExtractorNodeData, + query: str, + variable_pool: VariablePool, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) @@ -654,12 +713,12 @@ def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: st prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 @@ -668,26 +727,28 @@ def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: st model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) - curr_message_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, - prompt_messages - ) + 1000 # add 1000 to ensure tool call messages + curr_message_tokens = ( + model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000 + ) # add 1000 to ensure tool call messages max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template
== 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template or "") + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ - ModelInstance, ModelConfigWithCredentialsEntity]: + def _fetch_model_config( + self, node_data_model: ModelConfig + ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config. """ @@ -697,22 +758,27 @@ def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ return self._model_instance, self._model_config @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[ - str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ParameterExtractorNodeData, + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ - node_data = node_data - - variable_mapping = { - 'query': node_data.query - } + variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) - for selector in variable_template_parser.extract_variable_selectors(): + selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) + for selector in selectors: variable_mapping[selector.variable] = selector.value_selector + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + return variable_mapping diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index 499c58d505832c..58fcecc53b09fd 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -1,4 +1,4 @@ -FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters' +FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. ### Task @@ -23,7 +23,7 @@ To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples. ### Final Output Produce well-formatted function calls in json without XML tags, as shown in the example. -""" +""" # noqa: E501 FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside XML tags. 
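A minimal usage sketch (editor's note, not part of the patch): _generate_function_call_prompt in parameter_extractor_node.py fills this template via .format(content=..., structure=...), substituting the raw user query and the JSON schema built by ParameterExtractorNodeData.get_parameter_json_schema() into the template's XML-tagged slots. The schema and query below are hypothetical, shaped like the weather example that follows.

import json

from core.workflow.nodes.parameter_extractor.prompts import FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE

# Hypothetical schema in the shape produced by get_parameter_json_schema().
structure = {
    "type": "object",
    "properties": {
        "location": {"type": "string", "description": "The location to get the weather information"},
    },
    "required": ["location"],
}

# Mirrors the call in _generate_function_call_prompt.
user_message = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(
    content="What is the weather today in SF?",
    structure=json.dumps(structure),
)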
@@ -33,63 +33,52 @@ \x7bstructure\x7d -""" - -FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True +""" # noqa: E501 + +FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + }, }, + "required": ["location"], }, - 'required': ['location'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters." + " in this case, I need to call the function with the location parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'location': 'San Francisco' - } - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'function': { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "function": { + "name": FUNCTION_CALLING_EXTRACTOR_NAME, + "parameters": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } - } + }, + }, + "assistant": { + "text": "I need always call the function with the correct parameters." + " in this case, I need to call the function with the food parameter.", + "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, + }, }, - 'assistant': { - 'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.', - 'function_call' : { - 'name': FUNCTION_CALLING_EXTRACTOR_NAME, - 'parameters': { - 'food': 'apple pie' - } - } - } -}] +] COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: Some extra information are provided below, I should always follow the instructions as possible as I can. @@ -130,7 +119,7 @@ ### Answer I should always output a valid JSON object. Output nothing other than the JSON object. ```JSON -""" +""" # noqa: E501 CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. The structure of the JSON object you can found in the instructions. 
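Editor's note, not part of the patch: these JSON-mode prompts pair with the extract_json helper shown in parameter_extractor_node.py above, which scans the completion for the first balanced {...} or [...] span before handing it to json.loads. A self-contained re-sketch of that scan, sharing the original's limitation that braces inside string literals are not special-cased:

import json

def first_balanced_json(text: str):
    # Push opening braces/brackets onto a stack; return the prefix that ends
    # where everything opened has been closed again.
    stack = []
    for i, c in enumerate(text):
        if c in "{[":
            stack.append(c)
        elif c in "}]":
            if not stack or (c == "}") != (stack[-1] == "{"):
                return None  # mismatched or unbalanced, give up
            stack.pop()
            if not stack:
                return text[: i + 1]
    return None

raw = 'Sure, here it is: {"location": "San Francisco"} Anything else?'
start = raw.find("{")
print(json.loads(first_balanced_json(raw[start:])))  # {'location': 'San Francisco'}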
@@ -161,46 +150,33 @@ """ -CHAT_EXAMPLE = [{ - 'user': { - 'query': 'What is the weather today in SF?', - 'json': { - 'type': 'object', - 'properties': { - 'location': { - 'type': 'string', - 'description': 'The location to get the weather information', - 'required': True - } +CHAT_EXAMPLE = [ + { + "user": { + "query": "What is the weather today in SF?", + "json": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather information", + "required": True, + } + }, + "required": ["location"], }, - 'required': ['location'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'location': 'San Francisco' - } - } -}, { - 'user': { - 'query': 'I want to eat some apple pie.', - 'json': { - 'type': 'object', - 'properties': { - 'food': { - 'type': 'string', - 'description': 'The food to eat', - 'required': True - } + { + "user": { + "query": "I want to eat some apple pie.", + "json": { + "type": "object", + "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, + "required": ["food"], }, - 'required': ['food'] - } + }, + "assistant": {"text": "I need to output a valid JSON object.", "json": {"result": "apple pie"}}, }, - 'assistant': { - 'text': 'I need to output a valid JSON object.', - 'json': { - 'result': 'apple pie' - } - } -}] \ No newline at end of file +] diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/core/workflow/nodes/question_classifier/__init__.py index e69de29bb2d1d6..70414c4199efdf 100644 --- a/api/core/workflow/nodes/question_classifier/__init__.py +++ b/api/core/workflow/nodes/question_classifier/__init__.py @@ -0,0 +1,4 @@ +from .entities import QuestionClassifierNodeData +from .question_classifier_node import QuestionClassifierNode + +__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"] diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index c0b0a8b6968ead..5219f11d267c07 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -1,36 +1,21 @@ -from typing import Any, Optional +from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class ModelConfig(BaseModel): - """ - Model Config. - """ - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} +from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.llm import ModelConfig, VisionConfig class ClassConfig(BaseModel): - """ - Class Config. - """ id: str name: str class QuestionClassifierNodeData(BaseNodeData): - """ - Knowledge retrieval Node Data. 
- """ query_variable_selector: list[str] - type: str = 'question-classifier' model: ModelConfig classes: list[ClassConfig] instruction: Optional[str] = None memory: Optional[MemoryConfig] = None + vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/question_classifier/exc.py b/api/core/workflow/nodes/question_classifier/exc.py new file mode 100644 index 00000000000000..2c6354e2a70237 --- /dev/null +++ b/api/core/workflow/nodes/question_classifier/exc.py @@ -0,0 +1,6 @@ +class QuestionClassifierNodeError(ValueError): + """Base class for QuestionClassifierNode errors.""" + + +class InvalidModelTypeError(QuestionClassifierNodeError): + """Raised when the model is not a Large Language Model.""" diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 2e1464efcef1e6..744dfd3d8d656b 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,24 +1,32 @@ import json import logging -from typing import Optional, Union, cast +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.llm_generator.output_parser.errors import OutputParserError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole -from core.model_runtime.entities.model_entities import ModelPropertyKey +from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData -from core.workflow.nodes.question_classifier.template_prompts import ( +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.event import ModelInvokeCompletedEvent +from core.workflow.nodes.llm import ( + LLMNode, + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser +from libs.json_in_md_parser import parse_and_check_json_markdown +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import QuestionClassifierNodeData +from .exc import InvalidModelTypeError +from .template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, QUESTION_CLASSIFIER_COMPLETION_PROMPT, @@ -27,55 +35,90 @@ QUESTION_CLASSIFIER_USER_PROMPT_2, QUESTION_CLASSIFIER_USER_PROMPT_3, ) 
-from core.workflow.utils.variable_template_parser import VariableTemplateParser -from libs.json_in_md_parser import parse_and_check_json_markdown -from models.workflow import WorkflowNodeExecutionStatus + +if TYPE_CHECKING: + from core.file import File class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData - node_type = NodeType.QUESTION_CLASSIFIER + _node_type = NodeType.QUESTION_CLASSIFIER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data) - node_data = cast(QuestionClassifierNodeData, node_data) + def _run(self): + node_data = cast(QuestionClassifierNodeData, self.node_data) + variable_pool = self.graph_runtime_state.variable_pool # extract variables - variable = variable_pool.get(node_data.query_variable_selector) + variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None query = variable.value if variable else None - variables = { - 'query': query - } + variables = {"query": query} # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory( + node_data_memory=node_data.memory, + model_instance=model_instance, + ) # fetch instruction - instruction = self._format_instruction(node_data.instruction, variable_pool) if node_data.instruction else '' - node_data.instruction = instruction + node_data.instruction = node_data.instruction or "" + node_data.instruction = variable_pool.convert_template(node_data.instruction).text + + files: Sequence[File] = ( + self._fetch_files( + selector=node_data.vision.configs.variable_selector, + ) + if node_data.vision.enabled + else [] + ) + # fetch prompt messages - prompt_messages, stop = self._fetch_prompt( + rest_token = self._calculate_rest_token( node_data=node_data, - context='', - query=query, + query=query or "", + model_config=model_config, + context="", + ) + prompt_template = self._get_prompt_template( + node_data=node_data, + query=query or "", + memory=memory, + max_token_limit=rest_token, + ) + prompt_messages, stop = self._fetch_prompt_messages( + prompt_template=prompt_template, + system_query=query, memory=memory, - model_config=model_config + model_config=model_config, + files=files, + vision_enabled=node_data.vision.enabled, + vision_detail=node_data.vision.configs.detail, ) # handle invoke result - result_text, usage = self._invoke_llm( + generator = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, - stop=stop + stop=stop, ) + + result_text = "" + usage = LLMUsage.empty_usage() + finish_reason = None + for event in generator: + if isinstance(event, ModelInvokeCompletedEvent): + result_text = event.text + usage = event.usage + finish_reason = event.finish_reason + break + category_name = node_data.classes[0].name category_id = node_data.classes[0].id try: result_text_json = parse_and_check_json_markdown(result_text, []) # result_text_json = json.loads(result_text.strip('```JSON\n')) - if 'category_name' in result_text_json and 'category_id' in result_text_json: - category_id_result = result_text_json['category_id'] + if "category_name" in result_text_json and "category_id" in result_text_json: + category_id_result = result_text_json["category_id"] classes = node_data.classes classes_map = {class_.id: class_.name for class_ in classes} 
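# Editor's note (illustrative, not part of the patch): the classifier prompts ask the
# model for a JSON object that names one of the configured classes, along these lines
# (all values hypothetical):
#
#     {"keywords": ["refund", "invoice"], "category_id": "class-uuid-1", "category_name": "Billing"}
#
# parse_and_check_json_markdown() parses that object even when it arrives wrapped in a
# ```JSON fence, and because category_name/category_id are pre-seeded from
# node_data.classes[0] above, an unparseable or unknown answer falls back to the first
# configured class.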
category_ids = [_class.id for _class in classes] @@ -83,20 +126,18 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: category_name = classes_map[category_id_result] category_id = category_id_result - except Exception: - logging.error(f"Failed to parse result text: {result_text}") + except OutputParserError: + logging.exception(f"Failed to parse result text: {result_text}") try: process_data = { - 'model_mode': model_config.mode, - 'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, - prompt_messages=prompt_messages + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages ), - 'usage': jsonable_encoder(usage), - } - outputs = { - 'class_name': category_name + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, } + outputs = {"class_name": category_name} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -107,8 +148,9 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) except ValueError as e: @@ -119,21 +161,36 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: metadata={ NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency - } + NodeRunMetadataKey.CURRENCY: usage.currency, + }, + llm_usage=usage, ) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - node_data = node_data - node_data = cast(cls._node_data_cls, node_data) - variable_mapping = {'query': node_data.query_variable_selector} + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: QuestionClassifierNodeData, + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + variable_mapping = {"query": node_data.query_variable_selector} variable_selectors = [] if node_data.instruction: variable_template_parser = VariableTemplateParser(template=node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector + + variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} + return variable_mapping @classmethod @@ -143,169 +200,115 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: :param filters: filter by node config parameters. 
:return: """ - return { - "type": "question-classifier", - "config": { - "instructions": "" - } - } + return {"type": "question-classifier", "config": {"instructions": ""}} - def _fetch_prompt(self, node_data: QuestionClassifierNodeData, - query: str, - context: Optional[str], - memory: Optional[TokenBufferMemory], - model_config: ModelConfigWithCredentialsEntity) \ - -> tuple[list[PromptMessage], Optional[list[str]]]: - """ - Fetch prompt - :param node_data: node data - :param query: inputs - :param context: context - :param memory: memory - :param model_config: model config - :return: - """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, model_config, context) - prompt_template = self._get_prompt_template(node_data, query, memory, rest_token) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query='', - files=[], - context=context, - memory_config=node_data.memory, - memory=None, - model_config=model_config - ) - stop = model_config.stop - - return prompt_messages, stop - - def _calculate_rest_token(self, node_data: QuestionClassifierNodeData, query: str, - model_config: ModelConfigWithCredentialsEntity, - context: Optional[str]) -> int: + def _calculate_rest_token( + self, + node_data: QuestionClassifierNodeData, + query: str, + model_config: ModelConfigWithCredentialsEntity, + context: Optional[str], + ) -> int: prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_template = self._get_prompt_template(node_data, query, None, 2000) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, inputs={}, - query='', + query="", files=[], context=context, memory_config=node_data.memory, memory=None, - model_config=model_config + model_config=model_config, ) rest_tokens = 2000 model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) - curr_message_tokens = model_instance.get_llm_num_tokens( - prompt_messages - ) + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(parameter_rule.use_template or "") + ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens rest_tokens = max(rest_tokens, 0) return rest_tokens - def _get_prompt_template(self, node_data: QuestionClassifierNodeData, query: str, - memory: Optional[TokenBufferMemory], - max_token_limit: int = 2000) \ - -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]: + def _get_prompt_template( + self, + node_data: QuestionClassifierNodeData, + query: str, + memory: Optional[TokenBufferMemory], + max_token_limit: int = 2000, + ): model_mode = ModelMode.value_of(node_data.model.mode) classes = 
node_data.classes categories = [] for class_ in classes: - category = { - 'category_id': class_.id, - 'category_name': class_.name - } + category = {"category_id": class_.id, "category_name": class_.name} categories.append(category) - instruction = node_data.instruction if node_data.instruction else '' + instruction = node_data.instruction or "" input_text = query - memory_str = '' + memory_str = "" if memory: - memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size) - prompt_messages = [] + memory_str = memory.get_history_prompt_text( + max_token_limit=max_token_limit, + message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, + ) + prompt_messages: list[LLMNodeChatModelMessage] = [] if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) + system_prompt_messages = LLMNodeChatModelMessage( + role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_1 + user_prompt_message_1 = LLMNodeChatModelMessage( + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 ) prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 + assistant_prompt_message_1 = LLMNodeChatModelMessage( + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 ) prompt_messages.append(assistant_prompt_message_1) - user_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_2 + user_prompt_message_2 = LLMNodeChatModelMessage( + role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 ) prompt_messages.append(user_prompt_message_2) - assistant_prompt_message_2 = ChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 + assistant_prompt_message_2 = LLMNodeChatModelMessage( + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 ) prompt_messages.append(assistant_prompt_message_2) - user_prompt_message_3 = ChatModelMessage( + user_prompt_message_3 = LLMNodeChatModelMessage( role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction) + text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( + input_text=input_text, + categories=json.dumps(categories, ensure_ascii=False), + classification_instructions=instruction, + ), ) prompt_messages.append(user_prompt_message_3) return prompt_messages elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str, - input_text=input_text, - categories=json.dumps(categories), - classification_instructions=instruction, - ensure_ascii=False) + return LLMNodeCompletionModelPromptTemplate( + text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( + histories=memory_str, + input_text=input_text, + categories=json.dumps(categories), + classification_instructions=instruction, + ensure_ascii=False, + ) ) else: - raise ValueError(f"Model mode {model_mode} not support.") - - 
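In chat mode the classifier builds a fixed few-shot scaffold: one system message carrying the conversation history, two canned user/assistant exchanges that demonstrate the expected JSON answer, then the real input rendered into the final user message. A sketch of that assembly using plain (role, text) tuples, with the canned example texts elided as placeholders:

import json

def build_classifier_messages(histories: str, input_text: str, categories: list[dict], instruction: str) -> list[tuple[str, str]]:
    messages = [("system", f"...classification engine...### Memory {histories}")]
    # Two hard-coded exchanges teach the output format before the real query.
    for example_in, example_out in [("<user prompt 1>", "<assistant prompt 1>"),
                                    ("<user prompt 2>", "<assistant prompt 2>")]:
        messages.append(("user", example_in))
        messages.append(("assistant", example_out))
    # The actual request mirrors the examples' JSON envelope.
    messages.append(("user", json.dumps({
        "input_text": [input_text],
        "categories": categories,
        "classification_instructions": [instruction],
    }, ensure_ascii=False)))
    return messages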
def _format_instruction(self, instruction: str, variable_pool: VariablePool) -> str: - inputs = {} - - variable_selectors = [] - variable_template_parser = VariableTemplateParser(template=instruction) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - for variable_selector in variable_selectors: - variable = variable_pool.get(variable_selector.value_selector) - variable_value = variable.value if variable else None - if variable_value is None: - raise ValueError(f'Variable {variable_selector.variable} not found') - - inputs[variable_selector.variable] = variable_value - - prompt_template = PromptTemplateParser(template=instruction, with_variable_tmpl=True) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - instruction = prompt_template.format( - prompt_inputs - ) - return instruction + raise InvalidModelTypeError(f"Model mode {model_mode} not support.") diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index e0de148cc2c08c..4bca2d9dd4edc9 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -1,12 +1,10 @@ - - QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ ### Job Description', You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. ### Task - Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output.Additionally, you need to extract the key words from the text that are related to the classification. + Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. ### Format - The input text is in the variable input_text.Categories are specified as a category list with two filed category_id and category_name in the variable categories .Classification instructions may be included to improve the classification accuracy. + The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy. ### Constraint DO NOT include anything other than the JSON array in your response. ### Memory @@ -14,13 +12,13 @@ {histories} -""" +""" # noqa: E501 QUESTION_CLASSIFIER_USER_PROMPT_1 = """ { "input_text": ["I recently had a great experience with your company. 
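The deleted _format_instruction resolved {{#node_id.variable#}} references in the instruction text against the variable pool before prompting. A self-contained sketch of that substitution step, with a plain dict standing in for the variable pool and the regex capture group moved inside the hash marks for brevity:

import re

# Same selector shape the parser recognizes: {{#node_id.field[...]#}}
SELECTOR = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")

def format_instruction(instruction: str, pool: dict[tuple[str, ...], str]) -> str:
    def replacer(match: re.Match) -> str:
        selector = tuple(match.group(1).split("."))
        if selector not in pool:
            raise ValueError(f"Variable {match.group(1)} not found")
        return pool[selector]
    return SELECTOR.sub(replacer, instruction)

assert format_instruction("Focus on {{#start.topic#}}.", {("start", "topic"): "billing"}) == "Focus on billing."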
The service was prompt and the staff was very friendly."], "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}], "classification_instructions": ["classify the text based on the feedback provided by customer"]} -""" +""" # noqa: E501 QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ ```json @@ -34,7 +32,7 @@ {"input_text": ["bad service, slow to bring the food"], "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}], "classification_instructions": []} -""" +""" # noqa: E501 QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ ```json @@ -54,7 +52,7 @@ ### Job Description You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. ### Task -Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. +Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. ### Format The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy. 
### Constraint @@ -75,4 +73,4 @@ ### User Input {{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}} ### Assistant Output -""" +""" # noqa: E501 diff --git a/api/core/workflow/nodes/start/__init__.py b/api/core/workflow/nodes/start/__init__.py index e69de29bb2d1d6..54117804231aa9 100644 --- a/api/core/workflow/nodes/start/__init__.py +++ b/api/core/workflow/nodes/start/__init__.py @@ -0,0 +1,3 @@ +from .start_node import StartNode + +__all__ = ["StartNode"] diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 0bd5f203bf72a5..594d1b7bab8d68 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,9 +1,14 @@ +from collections.abc import Sequence + +from pydantic import Field + from core.app.app_config.entities import VariableEntity -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class StartNodeData(BaseNodeData): """ Start Node Data """ - variables: list[VariableEntity] = [] + + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 661b403d32f8ca..a7b91e82bbdd92 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,39 +1,35 @@ +from collections.abc import Mapping, Sequence +from typing import Any -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus -class StartNode(BaseNode): +class StartNode(BaseNode[StartNodeData]): _node_data_cls = StartNodeData _node_type = NodeType.START - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - """ - Run node - :param variable_pool: variable pool - :return: - """ - # Get cleaned inputs - cleaned_inputs = dict(variable_pool.user_inputs) + def _run(self) -> NodeRunResult: + node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + system_inputs = self.graph_runtime_state.variable_pool.system_variables - for var in variable_pool.system_variables: - cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var] + # TODO: System variables should be directly accessible, no need for special handling + # Set system variables as node outputs. + for var in system_inputs: + node_inputs[SYSTEM_VARIABLE_NODE_ID + "." 
+ var] = system_inputs[var] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=cleaned_inputs, - outputs=cleaned_inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - """ - Extract variable selector to variable mapping - :param node_data: node data - :return: - """ + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: StartNodeData, + ) -> Mapping[str, Sequence[str]]: return {} diff --git a/api/core/workflow/nodes/template_transform/__init__.py b/api/core/workflow/nodes/template_transform/__init__.py index e69de29bb2d1d6..43863b9d59aaf3 100644 --- a/api/core/workflow/nodes/template_transform/__init__.py +++ b/api/core/workflow/nodes/template_transform/__init__.py @@ -0,0 +1,3 @@ +from .template_transform_node import TemplateTransformNode + +__all__ = ["TemplateTransformNode"] diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py index d9099a8118498e..96adff6ffaa953 100644 --- a/api/core/workflow/nodes/template_transform/entities.py +++ b/api/core/workflow/nodes/template_transform/entities.py @@ -1,12 +1,11 @@ - - -from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.nodes.base import BaseNodeData class TemplateTransformNodeData(BaseNodeData): """ Code Node Data. """ + variables: list[VariableSelector] - template: str \ No newline at end of file + template: str diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 21f71db6c549aa..22a1b218880db9 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -1,16 +1,18 @@ import os -from typing import Optional, cast +from collections.abc import Mapping, Sequence +from typing import Any, Optional -from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus -MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) +MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) -class TemplateTransformNode(BaseNode): + +class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): _node_data_cls = TemplateTransformNodeData _node_type = NodeType.TEMPLATE_TRANSFORM @@ -23,66 +25,47 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ return { "type": "template-transform", - "config": { - "variables": [ - { - "variable": "arg1", - "value_selector": [] 
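The reworked StartNode merge is small: user inputs pass through untouched, while every system variable is re-exposed under the sys. prefix so selectors such as ["sys", "query"] keep resolving downstream. A dict-based sketch of that merge:

SYSTEM_VARIABLE_NODE_ID = "sys"  # matches core.workflow.constants

def merge_start_inputs(user_inputs: dict, system_inputs: dict) -> dict:
    node_inputs = dict(user_inputs)
    # System values become "sys.<name>" keys alongside the raw user inputs.
    for name, value in system_inputs.items():
        node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{name}"] = value
    return node_inputs

assert merge_start_inputs({"topic": "billing"}, {"query": "hi"}) == {"topic": "billing", "sys.query": "hi"}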
- } - ], - "template": "{{ arg1 }}" - } + "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - """ - Run node - """ - node_data = self.node_data - node_data: TemplateTransformNodeData = cast(self._node_data_cls, node_data) - + def _run(self) -> NodeRunResult: # Get variables variables = {} - for variable_selector in node_data.variables: + for variable_selector in self.node_data.variables: variable_name = variable_selector.variable - value = variable_pool.get_any(variable_selector.value_selector) - variables[variable_name] = value + value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) + variables[variable_name] = value.to_object() if value else None # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=node_data.template, - inputs=variables + language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables ) - except CodeExecutionException as e: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) - - if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: + except CodeExecutionError as e: + return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) + + if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters" + error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters", ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - outputs={ - 'output': result['result'] - } + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]} ) - + @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } \ No newline at end of file + node_id + "." 
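The mapping returned by _extract_variable_selector_to_variable_mapping is now namespaced with the owning node id, so two nodes that both declare an arg1 no longer collide when mappings are merged across the workflow. A sketch with selectors reduced to (variable, value_selector) tuples:

def extract_variable_mapping(node_id: str, variables: list[tuple[str, list[str]]]) -> dict[str, list[str]]:
    # Keys become "<node_id>.<variable>" instead of bare variable names.
    return {f"{node_id}.{variable}": selector for variable, selector in variables}

assert extract_variable_mapping("tpl_1", [("arg1", ["start", "topic"])]) == {"tpl_1.arg1": ["start", "topic"]}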
+ variable_selector.variable: variable_selector.value_selector + for variable_selector in node_data.variables + } diff --git a/api/core/workflow/nodes/tool/__init__.py b/api/core/workflow/nodes/tool/__init__.py index e69de29bb2d1d6..f4982e655d193f 100644 --- a/api/core/workflow/nodes/tool/__init__.py +++ b/api/core/workflow/nodes/tool/__init__.py @@ -0,0 +1,3 @@ +from .tool_node import ToolNode + +__all__ = ["ToolNode"] diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 5da5cd07271bae..9e29791481436e 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -3,54 +3,52 @@ from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class ToolEntity(BaseModel): provider_id: str - provider_type: Literal['builtin', 'api', 'workflow'] - provider_name: str # redundancy + provider_type: Literal["builtin", "api", "workflow"] + provider_name: str # redundancy tool_name: str - tool_label: str # redundancy + tool_label: str # redundancy tool_configurations: dict[str, Any] - @field_validator('tool_configurations', mode='before') + @field_validator("tool_configurations", mode="before") @classmethod def validate_tool_configurations(cls, value, values: ValidationInfo): if not isinstance(value, dict): - raise ValueError('tool_configurations must be a dictionary') - - for key in values.data.get('tool_configurations', {}).keys(): - value = values.data.get('tool_configurations', {}).get(key) + raise ValueError("tool_configurations must be a dictionary") + + for key in values.data.get("tool_configurations", {}): + value = values.data.get("tool_configurations", {}).get(key) if not isinstance(value, str | int | float | bool): - raise ValueError(f'{key} must be a string') - + raise ValueError(f"{key} must be a string") + return value + class ToolNodeData(BaseNodeData, ToolEntity): class ToolInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] - type: Literal['mixed', 'variable', 'constant'] + type: Literal["mixed", "variable", "constant"] - @field_validator('type', mode='before') + @field_validator("type", mode="before") @classmethod def check_type(cls, value, validation_info: ValidationInfo): typ = value - value = validation_info.data.get('value') - if typ == 'mixed' and not isinstance(value, str): - raise ValueError('value must be a string') - elif typ == 'variable': + value = validation_info.data.get("value") + if typ == "mixed" and not isinstance(value, str): + raise ValueError("value must be a string") + elif typ == "variable": if not isinstance(value, list): - raise ValueError('value must be a list') + raise ValueError("value must be a list") for val in value: if not isinstance(val, str): - raise ValueError('value must be a list of strings') - elif typ == 'constant' and not isinstance(value, str | int | float | bool): - raise ValueError('value must be a string, int, float, or bool') + raise ValueError("value must be a list of strings") + elif typ == "constant" and not isinstance(value, str | int | float | bool): + raise ValueError("value must be a string, int, float, or bool") return typ - """ - Tool Node Schema - """ tool_parameters: dict[str, ToolInput] diff --git a/api/core/workflow/nodes/tool/exc.py b/api/core/workflow/nodes/tool/exc.py new file mode 100644 index 00000000000000..7212e8bfc071bf --- /dev/null +++ 
b/api/core/workflow/nodes/tool/exc.py @@ -0,0 +1,16 @@ +class ToolNodeError(ValueError): + """Base exception for tool node errors.""" + + pass + + +class ToolParameterError(ToolNodeError): + """Exception raised for errors in tool parameters.""" + + pass + + +class ToolFileError(ToolNodeError): + """Exception raised for errors related to tool files.""" + + pass diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 554e3b6074ed58..6870b7467d11a4 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,24 +1,35 @@ from collections.abc import Mapping, Sequence from os import path -from typing import Any, cast +from typing import Any + +from sqlalchemy import select +from sqlalchemy.orm import Session -from core.app.segments import ArrayAnyVariable, parser from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariable -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser -from models import WorkflowNodeExecutionStatus +from extensions.ext_database import db +from factories import file_factory +from models import ToolFile +from models.workflow import WorkflowNodeExecutionStatus + +from .entities import ToolNodeData +from .exc import ( + ToolFileError, + ToolNodeError, + ToolParameterError, +) -class ToolNode(BaseNode): +class ToolNode(BaseNode[ToolNodeData]): """ Tool Node """ @@ -26,38 +37,41 @@ class ToolNode(BaseNode): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - """ - Run the tool node - """ - - node_data = cast(ToolNodeData, self.node_data) - + def _run(self) -> NodeRunResult: # fetch tool icon tool_info = { - 'provider_type': node_data.provider_type, - 'provider_id': node_data.provider_id + "provider_type": self.node_data.provider_type, + "provider_id": self.node_data.provider_id, } # get tool runtime try: tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from + self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) - except Exception as e: + except ToolNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info + NodeRunMetadataKey.TOOL_INFO: tool_info, }, - error=f'Failed to get tool runtime: {str(e)}' + error=f"Failed to get tool runtime: {str(e)}", ) # get parameters - tool_parameters = tool_runtime.get_runtime_parameters() or [] - parameters = self._generate_parameters(tool_parameters=tool_parameters, 
variable_pool=variable_pool, node_data=node_data) - parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True) + tool_parameters = tool_runtime.parameters or [] + parameters = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, + ) + parameters_for_log = self._generate_parameters( + tool_parameters=tool_parameters, + variable_pool=self.graph_runtime_state.variable_pool, + node_data=self.node_data, + for_log=True, + ) try: messages = ToolEngine.workflow_invoke( @@ -66,15 +80,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: user_id=self.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, + thread_pool_id=self.thread_pool_id, ) - except Exception as e: + except ToolNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info + NodeRunMetadataKey.TOOL_INFO: tool_info, }, - error=f'Failed to invoke tool: {str(e)}', + error=f"Failed to invoke tool: {str(e)}", ) # convert tool messages @@ -83,14 +98,14 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ - 'text': plain_text, - 'files': files, - 'json': json + "text": plain_text, + "files": files, + "json": json, }, metadata={ - NodeRunMetadataKey.TOOL_INFO: tool_info + NodeRunMetadataKey.TOOL_INFO: tool_info, }, - inputs=parameters_for_log + inputs=parameters_for_log, ) def _generate_parameters( @@ -121,31 +136,25 @@ def _generate_parameters( if not parameter: result[parameter_name] = None continue - if parameter.type == ToolParameter.ToolParameterType.FILE: - result[parameter_name] = [ - v.to_dict() for v in self._fetch_files(variable_pool) - ] + tool_input = node_data.tool_parameters[parameter_name] + if tool_input.type == "variable": + variable = variable_pool.get(tool_input.value) + if variable is None: + raise ToolParameterError(f"Variable {tool_input.value} does not exist") + parameter_value = variable.value + elif tool_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(tool_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text else: - tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == 'variable': - # TODO: check if the variable exists in the variable pool - parameter_value = variable_pool.get(tool_input.value).value - else: - segment_group = parser.convert_template( - template=str(tool_input.value), - variable_pool=variable_pool, - ) - parameter_value = segment_group.log if for_log else segment_group.text - result[parameter_name] = parameter_value + raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") + result[parameter_name] = parameter_value return result - def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]: - variable = variable_pool.get(['sys', SystemVariable.FILES.value]) - assert isinstance(variable, ArrayAnyVariable) - return list(variable.value) if variable else [] - - def _convert_tool_messages(self, messages: list[ToolInvokeMessage]): + def _convert_tool_messages( + self, + messages: list[ToolInvokeMessage], + ): """ Convert ToolInvokeMessages into tuple[plain_text, files] """ @@ -163,47 +172,81 @@ def _convert_tool_messages(self, messages: 
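_generate_parameters resolves each declared input one of two ways: a "variable" input is looked up directly in the pool, while "mixed" and "constant" inputs go through template conversion so embedded selectors render to text (or to a redacted log form when for_log is set). A dict-backed sketch, with the template rendering elided to str():

def generate_parameters(tool_parameters: dict[str, dict], pool: dict[tuple, object]) -> dict:
    result = {}
    for name, tool_input in tool_parameters.items():
        if tool_input["type"] == "variable":
            selector = tuple(tool_input["value"])
            if selector not in pool:
                raise ValueError(f"Variable {tool_input['value']} does not exist")
            result[name] = pool[selector]
        elif tool_input["type"] in {"mixed", "constant"}:
            result[name] = str(tool_input["value"])  # convert_template(...) elided
        else:
            raise ValueError(f"Unknown tool input type '{tool_input['type']}'")
    return result

assert generate_parameters(
    {"query": {"type": "variable", "value": ["start", "query"]}},
    {("start", "query"): "weather in Berlin"},
) == {"query": "weather in Berlin"}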
list[ToolInvokeMessage]): return plain_text, files, json - def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]: + def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[File]: """ Extract tool response binary """ result = [] - for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ - response.type == ToolInvokeMessage.MessageType.IMAGE: - url = response.message - ext = path.splitext(url)[1] - mimetype = response.meta.get('mime_type', 'image/jpeg') - filename = response.save_as or url.split('/')[-1] - transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE) + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + url = str(response.message) if response.message else None + ext = path.splitext(url)[1] if url else ".bin" + tool_file_id = str(url).split("/")[-1].split(".")[0] + transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - # get tool file id - tool_file_id = url.split('/')[-1].split('.')[0] - result.append(FileVar( + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"Tool file {tool_file_id} does not exist") + + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=url, - related_id=tool_file_id, - filename=filename, - extension=ext, - mime_type=mimetype, - )) + ) + result.append(file) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id - tool_file_id = response.message.split('/')[-1].split('.')[0] - result.append(FileVar( + tool_file_id = str(response.message).split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ValueError(f"tool file {tool_file_id} not exists") + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + file = file_factory.build_from_mapping( + mapping=mapping, tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file_id, - filename=response.save_as, - extension=path.splitext(response.save_as)[1], - mime_type=response.meta.get('mime_type', 'application/octet-stream'), - )) + ) + result.append(file) elif response.type == ToolInvokeMessage.MessageType.LINK: - pass # TODO: + url = str(response.message) + transfer_method = FileTransferMethod.TOOL_FILE + tool_file_id = url.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileError(f"Tool file {tool_file_id} does not exist") + if "." in url: + extension = "." 
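Binary and link responses are mapped back to stored files by parsing the tool-file id out of the message URL: the id is the last path segment with its extension stripped, and the extension falls back to ".bin" when none is present. A sketch of that parsing which checks the filename rather than the whole URL for a dot, so a dotted hostname cannot produce a bogus extension:

def parse_tool_file_url(url: str) -> tuple[str, str]:
    filename = url.split("/")[-1]          # "<tool_file_id>.<ext>"
    tool_file_id = filename.split(".")[0]
    extension = "." + filename.split(".")[1] if "." in filename else ".bin"
    return tool_file_id, extension

assert parse_tool_file_url("https://host/files/abc123.png") == ("abc123", ".png")
assert parse_tool_file_url("https://host/files/abc123") == ("abc123", ".bin")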
+ url.split("/")[-1].split(".")[1] + else: + extension = ".bin" + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + result.append(file) + + elif response.type == ToolInvokeMessage.MessageType.FILE: + assert response.meta is not None + result.append(response.meta["file"]) return result @@ -211,32 +254,47 @@ def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> """ Extract tool response text """ - return '\n'.join([ - f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else - f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else '' - for message in tool_response - ]) + return "\n".join( + [ + f"{message.message}" + if message.type == ToolInvokeMessage.MessageType.TEXT + else f"Link: {message.message}" + if message.type == ToolInvokeMessage.MessageType.LINK + else "" + for message in tool_response + ] + ) - def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]: + def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]): return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON] @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: ToolNodeData, + ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id :param node_data: node data :return: """ result = {} for parameter_name in node_data.tool_parameters: input = node_data.tool_parameters[parameter_name] - if input.type == 'mixed': - selectors = VariableTemplateParser(input.value).extract_variable_selectors() + if input.type == "mixed": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() for selector in selectors: result[selector.variable] = selector.value_selector - elif input.type == 'variable': + elif input.type == "variable": result[parameter_name] = input.value - elif input.type == 'constant': + elif input.type == "constant": pass + result = {node_id + "." + key: value for key, value in result.items()} + return result diff --git a/api/core/workflow/nodes/variable_aggregator/__init__.py b/api/core/workflow/nodes/variable_aggregator/__init__.py index e69de29bb2d1d6..0b6bf2a5b62ada 100644 --- a/api/core/workflow/nodes/variable_aggregator/__init__.py +++ b/api/core/workflow/nodes/variable_aggregator/__init__.py @@ -0,0 +1,3 @@ +from .variable_aggregator_node import VariableAggregatorNode + +__all__ = ["VariableAggregatorNode"] diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index cea88334b90738..71a930e6b0a5cb 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,33 +1,35 @@ - - from typing import Literal, Optional from pydantic import BaseModel -from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.nodes.base import BaseNodeData class AdvancedSettings(BaseModel): """ Advanced setting. """ + group_enabled: bool class Group(BaseModel): """ Group. 
""" - output_type: Literal['string', 'number', 'array', 'object'] + + output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] variables: list[list[str]] group_name: str groups: list[Group] + class VariableAssignerNodeData(BaseNodeData): """ Knowledge retrieval Node Data. """ - type: str = 'variable-assigner' + + type: str = "variable-assigner" output_type: str variables: list[list[str]] - advanced_settings: Optional[AdvancedSettings] = None \ No newline at end of file + advanced_settings: Optional[AdvancedSettings] = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 885f7d76170f94..031a7b83095541 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,53 +1,51 @@ -from typing import cast +from collections.abc import Mapping, Sequence +from typing import Any -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from models.workflow import WorkflowNodeExecutionStatus -class VariableAggregatorNode(BaseNode): +class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_AGGREGATOR - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data = cast(VariableAssignerNodeData, self.node_data) + def _run(self) -> NodeRunResult: # Get variables outputs = {} inputs = {} - if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: - for selector in node_data.variables: - variable = variable_pool.get_any(selector) + if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: + for selector in self.node_data.variables: + variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs = { - "output": variable - } + outputs = {"output": variable.to_object()} - inputs = { - '.'.join(selector[1:]): variable - } + inputs = {".".join(selector[1:]): variable.to_object()} break else: - for group in node_data.advanced_settings.groups: + for group in self.node_data.advanced_settings.groups: for selector in group.variables: - variable = variable_pool.get_any(selector) + variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: - outputs[group.group_name] = { - 'output': variable - } - inputs['.'.join(selector[1:])] = variable + outputs[group.group_name] = {"output": variable.to_object()} + inputs[".".join(selector[1:])] = variable.to_object() break - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - inputs=inputs - ) + return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping( + cls, *, graph_config: Mapping[str, Any], node_id: 
str, node_data: VariableAssignerNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ return {} diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index 552cc367f2674f..83da4bdc79bb21 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -1,109 +1,8 @@ -from collections.abc import Sequence -from enum import Enum -from typing import Optional, cast - -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.app.segments import SegmentType, Variable, factory -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import BaseNode -from extensions.ext_database import db -from models import ConversationVariable, WorkflowNodeExecutionStatus - - -class VariableAssignerNodeError(Exception): - pass - - -class WriteMode(str, Enum): - OVER_WRITE = 'over-write' - APPEND = 'append' - CLEAR = 'clear' - - -class VariableAssignerData(BaseNodeData): - title: str = 'Variable Assigner' - desc: Optional[str] = 'Assign a value to a variable' - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] - - -class VariableAssignerNode(BaseNode): - _node_data_cls: type[BaseNodeData] = VariableAssignerData - _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER - - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - data = cast(VariableAssignerData, self.node_data) - - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = variable_pool.get(data.assigned_variable_selector) - if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError('assigned variable not found') - - match data.write_mode: - case WriteMode.OVER_WRITE: - income_value = variable_pool.get(data.input_variable_selector) - if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_variable = original_variable.model_copy(update={'value': income_value.value}) - - case WriteMode.APPEND: - income_value = variable_pool.get(data.input_variable_selector) - if not income_value: - raise VariableAssignerNodeError('input value not found') - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={'value': updated_value}) - - case WriteMode.CLEAR: - income_value = get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={'value': income_value.to_object()}) - - case _: - raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}') - - # Over write the variable. - variable_pool.add(data.assigned_variable_selector, updated_variable) - - # Update conversation variable. - # TODO: Find a better way to use the database. 
- conversation_id = variable_pool.get(['sys', 'conversation_id']) - if not conversation_id: - raise VariableAssignerNodeError('conversation_id not found') - update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - 'value': income_value.to_object(), - }, - ) - - -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableAssignerNodeError('conversation variable not found in the database') - row.data = variable.model_dump_json() - session.commit() - - -def get_zero_value(t: SegmentType): - match t: - case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: - return factory.build_segment([]) - case SegmentType.OBJECT: - return factory.build_segment({}) - case SegmentType.STRING: - return factory.build_segment('') - case SegmentType.NUMBER: - return factory.build_segment(0) - case _: - raise VariableAssignerNodeError(f'unsupported variable type: {t}') +from .node import VariableAssignerNode +from .node_data import VariableAssignerData, WriteMode + +__all__ = [ + "VariableAssignerNode", + "VariableAssignerData", + "WriteMode", +] diff --git a/api/core/workflow/nodes/variable_assigner/exc.py b/api/core/workflow/nodes/variable_assigner/exc.py new file mode 100644 index 00000000000000..914be2225642cd --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/exc.py @@ -0,0 +1,2 @@ +class VariableAssignerNodeError(Exception): + pass diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/node.py new file mode 100644 index 00000000000000..4e66f640dff963 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node.py @@ -0,0 +1,89 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.variables import SegmentType, Variable +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.enums import NodeType +from extensions.ext_database import db +from factories import variable_factory +from models import ConversationVariable +from models.workflow import WorkflowNodeExecutionStatus + +from .exc import VariableAssignerNodeError +from .node_data import VariableAssignerData, WriteMode + + +class VariableAssignerNode(BaseNode[VariableAssignerData]): + _node_data_cls: type[BaseNodeData] = VariableAssignerData + _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER + + def _run(self) -> NodeRunResult: + # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject + original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) + if not isinstance(original_variable, Variable): + raise VariableAssignerNodeError("assigned variable not found") + + match self.node_data.write_mode: + case WriteMode.OVER_WRITE: + income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError("input value not found") + updated_variable = original_variable.model_copy(update={"value": income_value.value}) + + case WriteMode.APPEND: + income_value = 
self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + if not income_value: + raise VariableAssignerNodeError("input value not found") + updated_value = original_variable.value + [income_value.value] + updated_variable = original_variable.model_copy(update={"value": updated_value}) + + case WriteMode.CLEAR: + income_value = get_zero_value(original_variable.value_type) + updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) + + case _: + raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}") + + # Over write the variable. + self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) + + # TODO: Move database operation to the pipeline. + # Update conversation variable. + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) + if not conversation_id: + raise VariableAssignerNodeError("conversation_id not found") + update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + "value": income_value.to_object(), + }, + ) + + +def update_conversation_variable(conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(db.engine) as session: + row = session.scalar(stmt) + if not row: + raise VariableAssignerNodeError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() + + +def get_zero_value(t: SegmentType): + match t: + case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: + return variable_factory.build_segment([]) + case SegmentType.OBJECT: + return variable_factory.build_segment({}) + case SegmentType.STRING: + return variable_factory.build_segment("") + case SegmentType.NUMBER: + return variable_factory.build_segment(0) + case _: + raise VariableAssignerNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py new file mode 100644 index 00000000000000..70ae29d45f47be --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -0,0 +1,19 @@ +from collections.abc import Sequence +from enum import Enum +from typing import Optional + +from core.workflow.nodes.base import BaseNodeData + + +class WriteMode(str, Enum): + OVER_WRITE = "over-write" + APPEND = "append" + CLEAR = "clear" + + +class VariableAssignerData(BaseNodeData): + title: str = "Variable Assigner" + desc: Optional[str] = "Assign a value to a variable" + assigned_variable_selector: Sequence[str] + write_mode: WriteMode + input_variable_selector: Sequence[str] diff --git a/api/core/workflow/utils/condition/__init__.py b/api/core/workflow/utils/condition/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py new file mode 100644 index 00000000000000..799c735f5409ee --- /dev/null +++ b/api/core/workflow/utils/condition/entities.py @@ -0,0 +1,49 @@ +from collections.abc import Sequence +from typing import Literal + +from pydantic import BaseModel, Field + +SupportedComparisonOperator = Literal[ + # for string or array + "contains", + "not contains", + "start with", + 
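The three write modes reduce to plain operations on the stored value: over-write replaces it, append extends an array, clear resets it to a type-appropriate zero. A sketch over bare Python values, where the zero value is derived from the value's own Python type as a loose stand-in for get_zero_value (which switches on SegmentType instead):

from enum import Enum

class WriteMode(str, Enum):
    OVER_WRITE = "over-write"
    APPEND = "append"
    CLEAR = "clear"

def apply_write(mode: WriteMode, current, income=None):
    match mode:
        case WriteMode.OVER_WRITE:
            return income
        case WriteMode.APPEND:
            return current + [income]
        case WriteMode.CLEAR:
            return type(current)()  # [] / {} / "" / 0, by the value's own type
    raise ValueError(f"unsupported write mode: {mode}")

assert apply_write(WriteMode.APPEND, ["a"], "b") == ["a", "b"]
assert apply_write(WriteMode.CLEAR, ["a", "b"]) == []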
"end with", + "is", + "is not", + "empty", + "not empty", + "in", + "not in", + "all of", + # for number + "=", + "≠", + ">", + "<", + "≥", + "≤", + "null", + "not null", + # for file + "exists", + "not exists", +] + + +class SubCondition(BaseModel): + key: str + comparison_operator: SupportedComparisonOperator + value: str | Sequence[str] | None = None + + +class SubVariableCondition(BaseModel): + logical_operator: Literal["and", "or"] + conditions: list[SubCondition] = Field(default=list) + + +class Condition(BaseModel): + variable_selector: list[str] + comparison_operator: SupportedComparisonOperator + value: str | Sequence[str] | None = None + sub_variable_condition: SubVariableCondition | None = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py new file mode 100644 index 00000000000000..19473f39d2299a --- /dev/null +++ b/api/core/workflow/utils/condition/processor.py @@ -0,0 +1,385 @@ +from collections.abc import Sequence +from typing import Any, Literal + +from core.file import FileAttribute, file_manager +from core.variables import ArrayFileSegment +from core.workflow.entities.variable_pool import VariablePool + +from .entities import Condition, SubCondition, SupportedComparisonOperator + + +class ConditionProcessor: + def process_conditions( + self, + *, + variable_pool: VariablePool, + conditions: Sequence[Condition], + operator: Literal["and", "or"], + ): + input_conditions = [] + group_results = [] + + for condition in conditions: + variable = variable_pool.get(condition.variable_selector) + if variable is None: + raise ValueError(f"Variable {condition.variable_selector} not found") + + if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in { + "contains", + "not contains", + "all of", + }: + # check sub conditions + if not condition.sub_variable_condition: + raise ValueError("Sub variable is required") + result = _process_sub_conditions( + variable=variable, + sub_conditions=condition.sub_variable_condition.conditions, + operator=condition.sub_variable_condition.logical_operator, + ) + elif condition.comparison_operator in { + "exists", + "not exists", + }: + result = _evaluate_condition( + value=variable.value, + operator=condition.comparison_operator, + expected=None, + ) + else: + actual_value = variable.value if variable else None + expected_value = condition.value + if isinstance(expected_value, str): + expected_value = variable_pool.convert_template(expected_value).text + input_conditions.append( + { + "actual_value": actual_value, + "expected_value": expected_value, + "comparison_operator": condition.comparison_operator, + } + ) + result = _evaluate_condition( + value=actual_value, + operator=condition.comparison_operator, + expected=expected_value, + ) + group_results.append(result) + + final_result = all(group_results) if operator == "and" else any(group_results) + return input_conditions, group_results, final_result + + +def _evaluate_condition( + *, + operator: SupportedComparisonOperator, + value: Any, + expected: str | Sequence[str] | None, +) -> bool: + match operator: + case "contains": + return _assert_contains(value=value, expected=expected) + case "not contains": + return _assert_not_contains(value=value, expected=expected) + case "start with": + return _assert_start_with(value=value, expected=expected) + case "end with": + return _assert_end_with(value=value, expected=expected) + case "is": + return _assert_is(value=value, expected=expected) + case "is not": + return 
_assert_is_not(value=value, expected=expected) + case "empty": + return _assert_empty(value=value) + case "not empty": + return _assert_not_empty(value=value) + case "=": + return _assert_equal(value=value, expected=expected) + case "≠": + return _assert_not_equal(value=value, expected=expected) + case ">": + return _assert_greater_than(value=value, expected=expected) + case "<": + return _assert_less_than(value=value, expected=expected) + case "≥": + return _assert_greater_than_or_equal(value=value, expected=expected) + case "≤": + return _assert_less_than_or_equal(value=value, expected=expected) + case "null": + return _assert_null(value=value) + case "not null": + return _assert_not_null(value=value) + case "in": + return _assert_in(value=value, expected=expected) + case "not in": + return _assert_not_in(value=value, expected=expected) + case "all of" if isinstance(expected, list): + return _assert_all_of(value=value, expected=expected) + case "exists": + return _assert_exists(value=value) + case "not exists": + return _assert_not_exists(value=value) + case _: + raise ValueError(f"Unsupported operator: {operator}") + + +def _assert_contains(*, value: Any, expected: Any) -> bool: + if not value: + return False + + if not isinstance(value, str | list): + raise ValueError("Invalid actual value type: string or array") + + if expected not in value: + return False + return True + + +def _assert_not_contains(*, value: Any, expected: Any) -> bool: + if not value: + return True + + if not isinstance(value, str | list): + raise ValueError("Invalid actual value type: string or array") + + if expected in value: + return False + return True + + +def _assert_start_with(*, value: Any, expected: Any) -> bool: + if not value: + return False + + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") + + if not value.startswith(expected): + return False + return True + + +def _assert_end_with(*, value: Any, expected: Any) -> bool: + if not value: + return False + + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") + + if not value.endswith(expected): + return False + return True + + +def _assert_is(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") + + if value != expected: + return False + return True + + +def _assert_is_not(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, str): + raise ValueError("Invalid actual value type: string") + + if value == expected: + return False + return True + + +def _assert_empty(*, value: Any) -> bool: + if not value: + return True + return False + + +def _assert_not_empty(*, value: Any) -> bool: + if value: + return True + return False + + +def _assert_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value != expected: + return False + return True + + +def _assert_not_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value == expected: + return False + return True + + +def 
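All the numeric assertions share one coercion step: the expected side arrives as text from the node configuration and is parsed with the same numeric type as the actual value. Isolated, the pattern looks like this; note that an integer actual compared against a fractional expected fails loudly, since int("3.5") raises ValueError rather than truncating:

def coerce_expected(value: int | float, expected: str) -> int | float:
    # Parse the configured text with the actual value's own numeric type.
    return int(expected) if isinstance(value, int) else float(expected)

assert coerce_expected(3, "3") == 3
assert coerce_expected(3.5, "3.5") == 3.5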
_assert_greater_than(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value <= expected: + return False + return True + + +def _assert_less_than(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value >= expected: + return False + return True + + +def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value < expected: + return False + return True + + +def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool: + if value is None: + return False + + if not isinstance(value, int | float): + raise ValueError("Invalid actual value type: number") + + if isinstance(value, int): + expected = int(expected) + else: + expected = float(expected) + + if value > expected: + return False + return True + + +def _assert_null(*, value: Any) -> bool: + if value is None: + return True + return False + + +def _assert_not_null(*, value: Any) -> bool: + if value is not None: + return True + return False + + +def _assert_in(*, value: Any, expected: Any) -> bool: + if not value: + return False + + if not isinstance(expected, list): + raise ValueError("Invalid expected value type: array") + + if value not in expected: + return False + return True + + +def _assert_not_in(*, value: Any, expected: Any) -> bool: + if not value: + return True + + if not isinstance(expected, list): + raise ValueError("Invalid expected value type: array") + + if value in expected: + return False + return True + + +def _assert_all_of(*, value: Any, expected: Sequence[str]) -> bool: + if not value: + return False + + if not all(item in value for item in expected): + return False + return True + + +def _assert_exists(*, value: Any) -> bool: + return value is not None + + +def _assert_not_exists(*, value: Any) -> bool: + return value is None + + +def _process_sub_conditions( + variable: ArrayFileSegment, + sub_conditions: Sequence[SubCondition], + operator: Literal["and", "or"], +) -> bool: + files = variable.value + group_results = [] + for condition in sub_conditions: + key = FileAttribute(condition.key) + values = [file_manager.get_attr(file=file, attr=key) for file in files] + sub_group_results = [ + _evaluate_condition( + value=value, + operator=condition.comparison_operator, + expected=condition.value, + ) + for value in values + ] + # Determine the result based on the presence of "not" in the comparison operator + result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results) + group_results.append(result) + return all(group_results) if operator == "and" else any(group_results) diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index c43fde172c7b7a..1d8fb38ebf8237 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -1,42 +1,21 @@ import re -from 
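_process_sub_conditions evaluates each sub-condition against the chosen attribute of every file, then folds the per-file booleans with a sign-aware rule: operators containing "not" must hold for all files, while positive operators need only one match. A plain-list sketch of that per-condition fold:

from collections.abc import Callable

def evaluate_over_files(values: list, predicate: Callable[[object], bool], operator: str) -> bool:
    results = [predicate(v) for v in values]
    # Negated operators must hold for every file; positive ones for any one.
    return all(results) if "not" in operator else any(results)

names = ["report.pdf", "image.png"]
assert evaluate_over_files(names, lambda n: n.endswith(".pdf"), "end with") is True
assert evaluate_over_files(names, lambda n: "draft" not in n, "not contains") is True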
collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import Any from core.workflow.entities.variable_entities import VariableSelector -from core.workflow.entities.variable_pool import VariablePool -REGEX = re.compile(r'\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}') +REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") +SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") -def parse_mixed_template(*, template: str, variable_pool: VariablePool) -> str: - """ - This is an alternative to the VariableTemplateParser class, - offering the same functionality but with better readability and ease of use. - """ - variable_keys = [match[0] for match in re.findall(REGEX, template)] - variable_keys = list(set(variable_keys)) - - # This key_selector is a tuple of (key, selector) where selector is a list of keys - # e.g. ('#node_id.query.name#', ['node_id', 'query', 'name']) - key_selectors = filter( - lambda t: len(t[1]) >= 2, - ((key, selector.replace('#', '').split('.')) for key, selector in zip(variable_keys, variable_keys)), - ) - inputs = {key: variable_pool.get_any(selector) for key, selector in key_selectors} - - def replacer(match): - key = match.group(1) - # return original matched string if key not found - value = inputs.get(key, match.group(0)) - if value is None: - value = '' - value = str(value) - # remove template variables if required - return re.sub(REGEX, r'{\1}', value) - - result = re.sub(REGEX, replacer, template) - result = re.sub(r'<\|.*?\|>', '', result) - return result + +def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]: + parts = SELECTOR_PATTERN.split(template) + selectors = [] + for part in filter(lambda x: x, parts): + if "." in part and part[0] == "#" and part[-1] == "#": + selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split("."))) + return selectors class VariableTemplateParser: @@ -101,8 +80,8 @@ def extract_variable_selectors(self) -> list[VariableSelector]: """ variable_selectors = [] for variable_key in self.variable_keys: - remove_hash = variable_key.replace('#', '') - split_result = remove_hash.split('.') + remove_hash = variable_key.replace("#", "") + split_result = remove_hash.split(".") if len(split_result) < 2: continue @@ -127,7 +106,7 @@ def replacer(match): value = inputs.get(key, match.group(0)) # return original matched string if key not found if value is None: - value = '' + value = "" # convert the value to string if isinstance(value, list | dict | bool | int | float): value = str(value) @@ -136,7 +115,7 @@ def replacer(match): return VariableTemplateParser.remove_template_variables(value) prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r'<\|.*?\|>', '', prompt) + return re.sub(r"<\|.*?\|>", "", prompt) @classmethod def remove_template_variables(cls, text: str): @@ -149,4 +128,4 @@ def remove_template_variables(cls, text: str): Returns: The text with template variables removed. 
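A note on the new SELECTOR_PATTERN above: because the pattern keeps a single capturing group, re.split returns the matched `#node.key#` selector bodies interleaved with the literal text, so extract_selectors_from_template only has to filter the parts that look like selectors. A minimal, self-contained sketch of that behavior follows; plain tuples stand in for the VariableSelector model, and the template and node name are hypothetical.

import re

# Same pattern as SELECTOR_PATTERN in the diff above.
SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}")


def extract_selectors(template: str) -> list[tuple[str, list[str]]]:
    # Stand-in for extract_selectors_from_template: returns
    # (variable, value_selector) tuples instead of VariableSelector models.
    selectors = []
    for part in filter(None, SELECTOR_PATTERN.split(template)):
        if "." in part and part.startswith("#") and part.endswith("#"):
            selectors.append((part, part[1:-1].split(".")))
    return selectors


# Hypothetical template referencing an LLM node's text output:
print(extract_selectors("Answer: {{#llm_node.text#}} (end)"))
# -> [('#llm_node.text#', ['llm_node', 'text'])]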
""" - return re.sub(REGEX, r'{\1}', text) + return re.sub(REGEX, r"{\1}", text) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 3157eedfee5238..e69de29bb2d1d6 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,1005 +0,0 @@ -import logging -import time -from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast - -import contexts -from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException -from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar -from core.workflow.callbacks.base_workflow_callback import WorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool, VariableValue -from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iteration.entities import IterationState -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode -from core.workflow.nodes.variable_assigner import VariableAssignerNode -from extensions.ext_database import db -from models.workflow import ( - Workflow, - WorkflowNodeExecutionStatus, -) - -node_classes: Mapping[NodeType, type[BaseNode]] = { - NodeType.START: StartNode, - NodeType.END: EndNode, - NodeType.ANSWER: AnswerNode, - NodeType.LLM: LLMNode, - NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, - NodeType.IF_ELSE: IfElseNode, - NodeType.CODE: CodeNode, - NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, - NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, - NodeType.HTTP_REQUEST: HttpRequestNode, - NodeType.TOOL: ToolNode, - NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, - NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, - NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, - NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, -} - -logger = logging.getLogger(__name__) - - -class WorkflowEngineManager: - def get_default_configs(self) -> list[dict]: - """ - Get default block configs - """ - default_block_configs = [] - for node_type, 
node_class in node_classes.items(): - default_config = node_class.get_default_config() - if default_config: - default_block_configs.append(default_config) - - return default_block_configs - - def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]: - """ - Get default config of node. - :param node_type: node type - :param filters: filter by node config parameters. - :return: - """ - node_class = node_classes.get(node_type) - if not node_class: - return None - - default_config = node_class.get_default_config(filters=filters) - if not default_config: - return None - - return default_config - - def run_workflow( - self, - *, - workflow: Workflow, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - callbacks: Sequence[WorkflowCallback], - call_depth: int = 0, - variable_pool: VariablePool | None = None, - ) -> None: - """ - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param invoke_from: invoke from - :param callbacks: workflow callbacks - :param call_depth: call depth - :param variable_pool: variable pool - """ - # fetch workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - if 'nodes' not in graph or 'edges' not in graph: - raise ValueError('nodes or edges not found in workflow graph') - - if not isinstance(graph.get('nodes'), list): - raise ValueError('nodes in workflow graph must be a list') - - if not isinstance(graph.get('edges'), list): - raise ValueError('edges in workflow graph must be a list') - - - workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH - if call_depth > workflow_call_max_depth: - raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) - - # init workflow run state - if not variable_pool: - variable_pool = contexts.workflow_variable_pool.get() - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, - workflow_call_depth=call_depth - ) - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - # run workflow - self._run_workflow( - workflow=workflow, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - ) - - def _run_workflow(self, workflow: Workflow, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback], - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> None: - """ - Run workflow - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files - :param callbacks: workflow callbacks - :param call_depth: call depth - :param start_at: force specific start node - :param end_at: force specific end node - :return: - """ - graph = workflow.graph_dict - - try: - answer_prov_node_ids = [] - for node in graph.get('nodes', []): - if node.get('id', '') == 'answer': - try: - answer_prov_node_ids.append(node.get('data', {}) - .get('answer', '') - .replace('#', '') - .replace('.text', '') - .replace('{{', '') - .replace('}}', '').split('.')[0]) - except Exception as e: - logger.error(e) - - predecessor_node: BaseNode | None = None - current_iteration_node: BaseIterationNode | None = None - has_entry_node = False - max_execution_steps = dify_config.WORKFLOW_MAX_EXECUTION_STEPS - max_execution_time = 
dify_config.WORKFLOW_MAX_EXECUTION_TIME - while True: - # get next node, multiple target nodes in the future - next_node = self._get_next_overall_node( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=predecessor_node, - callbacks=callbacks, - start_at=start_at, - end_at=end_at - ) - - if not next_node: - # reached loop/iteration end or overall end - if current_iteration_node and workflow_run_state.current_iteration_state: - # reached loop/iteration end - # get next iteration - next_iteration = current_iteration_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_iteration, NodeRunResult): - if next_iteration.outputs: - for variable_key, variable_value in next_iteration.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=current_iteration_node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - # iteration has ended - next_node = self._get_next_overall_node( - workflow_run_state=workflow_run_state, - graph=graph, - predecessor_node=current_iteration_node, - callbacks=callbacks, - start_at=start_at, - end_at=end_at - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - # continue overall process - elif isinstance(next_iteration, str): - # move to next iteration - next_node_id = next_iteration - # get next id - next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) - - if not next_node: - break - - # check is already ran - if self._check_node_has_ran(workflow_run_state, next_node.node_id): - predecessor_node = next_node - continue - - has_entry_node = True - - # max steps reached - if workflow_run_state.workflow_node_steps > max_execution_steps: - raise ValueError('Max steps {} reached.'.format(max_execution_steps)) - - # or max execution time reached - if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): - raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) - - # handle iteration nodes - if isinstance(next_node, BaseIterationNode): - current_iteration_node = next_node - workflow_run_state.current_iteration_state = next_node.run( - variable_pool=workflow_run_state.variable_pool - ) - self._workflow_iteration_started( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None, - callbacks=callbacks - ) - predecessor_node = next_node - # move to start node of iteration - next_node_id = next_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_node_id, NodeRunResult): - # iteration has ended - current_iteration_node.set_output( - 
variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - continue - else: - next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) - - if next_node and next_node.node_id in answer_prov_node_ids: - next_node.is_answer_previous_node = True - - # run workflow, run multiple target nodes in the future - self._run_workflow_node( - workflow_run_state=workflow_run_state, - node=next_node, - predecessor_node=predecessor_node, - callbacks=callbacks - ) - - if next_node.node_type in [NodeType.END]: - break - - predecessor_node = next_node - - if not has_entry_node: - self._workflow_run_failed( - error='Start node not found in workflow graph.', - callbacks=callbacks - ) - return - except GenerateTaskStoppedException as e: - return - except Exception as e: - self._workflow_run_failed( - error=str(e), - callbacks=callbacks - ) - return - - # workflow run success - self._workflow_run_success( - callbacks=callbacks - ) - - def single_step_run_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict) -> tuple[BaseNode, NodeRunResult]: - """ - Single step run workflow node - :param workflow: Workflow instance - :param node_id: node id - :param user_id: user id - :param user_inputs: user inputs - :return: - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - # fetch node config from node id - node_config = None - for node in nodes: - if node.get('id') == node_id: - node_config = node - break - - if not node_config: - raise ValueError('node id not found in workflow graph') - - # Get node class - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes.get(node_type) - - # init workflow run state - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - workflow_call_depth=0 - ) - - try: - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=workflow.conversation_variables, - ) - - if node_cls is None: - raise ValueError('Node class not found') - # variable selector to variable mapping - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # run node - node_run_result = node_instance.run( - variable_pool=variable_pool - ) - - # sign output files - node_run_result.outputs = self.handle_special_values(node_run_result.outputs) - except Exception as e: - raise WorkflowNodeRunFailedError( - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_title=node_instance.node_data.title, - error=str(e) - ) - - return node_instance, node_run_result - - def 
single_step_run_iteration_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict, - callbacks: Sequence[WorkflowCallback], - ) -> None: - """ - Single iteration run workflow node - """ - # fetch node info from workflow graph - graph = workflow.graph_dict - if not graph: - raise ValueError('workflow graph not found') - - nodes = graph.get('nodes') - if not nodes: - raise ValueError('nodes not found in workflow graph') - - for node in nodes: - if node.get('id') == node_id: - if node.get('data', {}).get('type') in [ - NodeType.ITERATION.value, - NodeType.LOOP.value, - ]: - node_config = node - else: - raise ValueError('node id is not an iteration node') - - # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=workflow.conversation_variables, - ) - - # variable selector to variable mapping - iteration_nested_nodes = [ - node for node in nodes - if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id - ] - iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() - - for node_config in iteration_nested_nodes: - # mapping user inputs to variable pool - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - if node_cls is None: - raise ValueError('Node class not found') - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config) - - # remove iteration variables - variable_mapping = { - f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() - if value[0] != node_id - } - - # remove variable out from iteration - variable_mapping = { - key: value for key, value in variable_mapping.items() - if value[0] not in iteration_nested_node_ids - } - - # append variables to variable pool - node_instance = node_cls( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - config=node_config, - callbacks=callbacks, - workflow_call_depth=0 - ) - - self._mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_instance=node_instance - ) - - # fetch end node of iteration - end_node_id = None - for edge in graph.get('edges'): - if edge.get('source') == node_id: - end_node_id = edge.get('target') - break - - if not end_node_id: - raise ValueError('end node of iteration not found') - - # init workflow run state - workflow_run_state = WorkflowRunState( - workflow=workflow, - start_at=time.perf_counter(), - variable_pool=variable_pool, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - workflow_call_depth=0 - ) - - # run workflow - self._run_workflow( - workflow=workflow, - workflow_run_state=workflow_run_state, - callbacks=callbacks, - start_at=node_id, - end_at=end_node_id - ) - - def _workflow_run_success(self, callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run success - :param callbacks: workflow callbacks - :return: - """ - - if callbacks: - for callback in callbacks: - callback.on_workflow_run_succeeded() - - def _workflow_run_failed(self, error: str, 
- callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow run failed - :param error: error message - :param callbacks: workflow callbacks - :return: - """ - if callbacks: - for callback in callbacks: - callback.on_workflow_run_failed( - error=error - ) - - def _workflow_iteration_started(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - predecessor_node_id: Optional[str] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration started - :param current_iteration_node: current iteration node - :param workflow_run_state: workflow run state - :param callbacks: workflow callbacks - :return: - """ - # get nested nodes - iteration_nested_nodes = [ - node for node in graph.get('nodes') - if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id - ] - - if not iteration_nested_nodes: - raise ValueError('iteration has no nested nodes') - - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_started( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - node_data=current_iteration_node.node_data, - inputs=workflow_run_state.current_iteration_state.inputs, - predecessor_node_id=predecessor_node_id, - metadata=workflow_run_state.current_iteration_state.metadata.model_dump() - ) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - def _workflow_iteration_next(self, *, graph: Mapping[str, Any], - current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback]) -> None: - """ - Workflow iteration next - :param workflow_run_state: workflow run state - :return: - """ - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_next( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - index=workflow_run_state.current_iteration_state.index, - node_run_index=workflow_run_state.workflow_node_steps, - output=workflow_run_state.current_iteration_state.get_current_output() - ) - # clear ran nodes - workflow_run_state.workflow_node_runs = [ - node_run for node_run in workflow_run_state.workflow_node_runs - if node_run.iteration_node_id != current_iteration_node.node_id - ] - - # clear variables in current iteration - nodes = graph.get('nodes') - nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id] - - for node in nodes: - workflow_run_state.variable_pool.remove((node.get('id'),)) - - def _workflow_iteration_completed(self, *, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - if isinstance(workflow_run_state.current_iteration_state, IterationState): - for callback in callbacks: - callback.on_workflow_iteration_completed( - node_id=current_iteration_node.node_id, - node_type=NodeType.ITERATION, - node_run_index=workflow_run_state.workflow_node_steps, - outputs={ - 'output': workflow_run_state.current_iteration_state.outputs - } - ) - - def _get_next_overall_node(self, *, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback], - start_at: Optional[str] = None, - end_at: 
Optional[str] = None) -> Optional[BaseNode]: - """ - Get next node - multiple target nodes in the future. - :param graph: workflow graph - :param predecessor_node: predecessor node - :param callbacks: workflow callbacks - :return: - """ - nodes = graph.get('nodes') - if not nodes: - return None - - if not predecessor_node: - for node_config in nodes: - node_cls = None - if start_at: - if node_config.get('id') == start_at: - node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) - else: - if node_config.get('data', {}).get('type', '') == NodeType.START.value: - node_cls = StartNode - if node_cls: - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - else: - edges = graph.get('edges') - source_node_id = predecessor_node.node_id - - # fetch all outgoing edges from source node - outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] - if not outgoing_edges: - return None - - # fetch target node id from outgoing edges - outgoing_edge = None - source_handle = predecessor_node.node_run_result.edge_source_handle \ - if predecessor_node.node_run_result else None - if source_handle: - for edge in outgoing_edges: - if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle: - outgoing_edge = edge - break - else: - outgoing_edge = outgoing_edges[0] - - if not outgoing_edge: - return None - - target_node_id = outgoing_edge.get('target') - - if end_at and target_node_id == end_at: - return None - - # fetch target node from target node id - target_node_config = None - for node in nodes: - if node.get('id') == target_node_id: - target_node_config = node - break - - if not target_node_config: - return None - - # get next node - target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) - - return target_node( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=target_node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _get_node(self, workflow_run_state: WorkflowRunState, - graph: Mapping[str, Any], - node_id: str, - callbacks: Sequence[WorkflowCallback]): - """ - Get node from graph by node id - """ - nodes = graph.get('nodes') - if not nodes: - return None - - for node_config in nodes: - if node_config.get('id') == node_id: - node_type = NodeType.value_of(node_config.get('data', {}).get('type')) - node_cls = node_classes[node_type] - return node_cls( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: - """ - Check timeout - :param start_at: start time - :param max_execution_time: max execution time - :return: - """ - return 
time.perf_counter() - start_at > max_execution_time - - def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool: - """ - Check node has ran - """ - return bool([ - node_and_result for node_and_result in workflow_run_state.workflow_node_runs - if node_and_result.node_id == node_id - ]) - - def _run_workflow_node(self, *, workflow_run_state: WorkflowRunState, - node: BaseNode, - predecessor_node: Optional[BaseNode] = None, - callbacks: Sequence[WorkflowCallback]) -> None: - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_started( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - node_run_index=workflow_run_state.workflow_node_steps, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None - ) - - db.session.close() - - workflow_nodes_and_result = WorkflowNodeAndResult( - node=node, - result=None - ) - - # add to workflow_nodes_and_results - workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) - - # add steps - workflow_run_state.workflow_node_steps += 1 - - # mark node as running - if workflow_run_state.current_iteration_state: - workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun( - node_id=node.node_id, - iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id - )) - - try: - # run node, result must have inputs, process_data, outputs, execution_metadata - node_run_result = node.run( - variable_pool=workflow_run_state.variable_pool - ) - except GenerateTaskStoppedException as e: - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error='Workflow stopped.' - ) - except Exception as e: - logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") - node_run_result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e) - ) - - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: - # node run failed - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_failed( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - error=node_run_result.error, - inputs=node_run_result.inputs, - outputs=node_run_result.outputs, - process_data=node_run_result.process_data, - ) - - raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") - - if node.is_answer_previous_node and not isinstance(node, LLMNode): - if not node_run_result.metadata: - node_run_result.metadata = {} - node_run_result.metadata["is_answer_previous_node"]=True - workflow_nodes_and_result.result = node_run_result - - # node run success - if callbacks: - for callback in callbacks: - callback.on_workflow_node_execute_succeeded( - node_id=node.node_id, - node_type=node.node_type, - node_data=node.node_data, - inputs=node_run_result.inputs, - process_data=node_run_result.process_data, - outputs=node_run_result.outputs, - execution_metadata=node_run_result.metadata - ) - - if node_run_result.outputs: - for variable_key, variable_value in node_run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - variable_pool=workflow_run_state.variable_pool, - node_id=node.node_id, - variable_key_list=[variable_key], - variable_value=variable_value - ) - - if node_run_result.metadata and node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - workflow_run_state.total_tokens += int(node_run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)) - - 
db.session.close() - - def _append_variables_recursively(self, variable_pool: VariablePool, - node_id: str, - variable_key_list: list[str], - variable_value: VariableValue): - """ - Append variables recursively - :param variable_pool: variable pool - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - variable_pool.add( - [node_id] + variable_key_list, variable_value - ) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - variable_pool=variable_pool, - node_id=node_id, - variable_key_list=new_key_list, - variable_value=value - ) - - @classmethod - def handle_special_values(cls, value: Optional[dict]) -> Optional[dict]: - """ - Handle special values - :param value: value - :return: - """ - if not value: - return None - - new_value = value.copy() - if isinstance(new_value, dict): - for key, val in new_value.items(): - if isinstance(val, FileVar): - new_value[key] = val.to_dict() - elif isinstance(val, list): - new_val = [] - for v in val: - if isinstance(v, FileVar): - new_val.append(v.to_dict()) - else: - new_val.append(v) - - new_value[key] = new_val - - return new_value - - def _mapping_user_inputs_to_variable_pool(self, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: dict, - variable_pool: VariablePool, - tenant_id: str, - node_instance: BaseNode): - for variable_key, variable_selector in variable_mapping.items(): - if variable_key not in user_inputs and not variable_pool.get(variable_selector): - raise ValueError(f'Variable key {variable_key} not found in user inputs.') - - # fetch variable node id from variable selector - variable_node_id = variable_selector[0] - variable_key_list = variable_selector[1:] - - # get value - value = user_inputs.get(variable_key) - - # FIXME: temp fix for image type - if node_instance.node_type == NodeType.LLM: - new_value = [] - if isinstance(value, list): - node_data = node_instance.node_data - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in value: - if isinstance(item, dict) and 'type' in item and item['type'] == 'image': - transfer_method = FileTransferMethod.value_of(item.get('transfer_method')) - file = FileVar( - tenant_id=tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None, - related_id=item.get( - 'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None, - extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None), - ) - new_value.append(file) - - if new_value: - value = new_value - - # append variable and value to variable pool - variable_pool.add([variable_node_id]+variable_key_list, value) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py new file mode 100644 index 00000000000000..84b251223f96f1 --- /dev/null +++ b/api/core/workflow/workflow_entry.py @@ -0,0 +1,293 @@ +import logging +import time +import uuid +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Optional, cast + +from configs import dify_config +from core.app.app_config.entities import FileUploadConfig +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from 
core.app.entities.app_invoke_entities import InvokeFrom +from core.file.models import File, FileTransferMethod, ImageConfig +from core.workflow.callbacks import WorkflowCallback +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.errors import WorkflowNodeRunFailedError +from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes import NodeType +from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.event import NodeEvent +from core.workflow.nodes.llm import LLMNodeData +from core.workflow.nodes.node_mapping import node_type_classes_mapping +from factories import file_factory +from models.enums import UserFrom +from models.workflow import ( + Workflow, + WorkflowType, +) + +logger = logging.getLogger(__name__) + + +class WorkflowEntry: + def __init__( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + workflow_type: WorkflowType, + graph_config: Mapping[str, Any], + graph: Graph, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + variable_pool: VariablePool, + thread_pool_id: Optional[str] = None, + ) -> None: + """ + Init workflow entry + :param tenant_id: tenant id + :param app_id: app id + :param workflow_id: workflow id + :param workflow_type: workflow type + :param graph_config: workflow graph config + :param graph: workflow graph + :param user_id: user id + :param user_from: user from + :param invoke_from: invoke from + :param call_depth: call depth + :param variable_pool: variable pool + :param thread_pool_id: thread pool id + """ + # check call depth + workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH + if call_depth > workflow_call_max_depth: + raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) + + # init workflow run state + self.graph_engine = GraphEngine( + tenant_id=tenant_id, + app_id=app_id, + workflow_type=workflow_type, + workflow_id=workflow_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth, + graph=graph, + graph_config=graph_config, + variable_pool=variable_pool, + max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + thread_pool_id=thread_pool_id, + ) + + def run( + self, + *, + callbacks: Sequence[WorkflowCallback], + ) -> Generator[GraphEngineEvent, None, None]: + """ + :param callbacks: workflow callbacks + """ + graph_engine = self.graph_engine + + try: + # run workflow + generator = graph_engine.run() + for event in generator: + if callbacks: + for callback in callbacks: + callback.on_event(event=event) + yield event + except GenerateTaskStoppedError: + pass + except Exception as e: + logger.exception("Unknown Error when workflow entry running") + if callbacks: + for callback in callbacks: + callback.on_event(event=GraphRunFailedEvent(error=str(e))) + return + + @classmethod + def single_step_run( + cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict + ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: + """ + Single step run workflow node + :param workflow: Workflow instance + :param node_id: 
node id + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # fetch node info from workflow graph + graph = workflow.graph_dict + if not graph: + raise ValueError("workflow graph not found") + + nodes = graph.get("nodes") + if not nodes: + raise ValueError("nodes not found in workflow graph") + + # fetch node config from node id + node_config = None + for node in nodes: + if node.get("id") == node_id: + node_config = node + break + + if not node_config: + raise ValueError("node id not found in workflow graph") + + # Get node class + node_type = NodeType(node_config.get("data", {}).get("type")) + node_cls = node_type_classes_mapping.get(node_type) + node_cls = cast(type[BaseNode], node_cls) + + if not node_cls: + raise ValueError(f"Node class not found for node type {node_type}") + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=workflow.environment_variables, + ) + + # init graph + graph = Graph.init(graph_config=workflow.graph_dict) + + # init workflow run state + node_instance = node_cls( + id=str(uuid.uuid4()), + config=node_config, + graph_init_params=GraphInitParams( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + workflow_type=WorkflowType.value_of(workflow.type), + workflow_id=workflow.id, + graph_config=workflow.graph_dict, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + ) + + try: + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, config=node_config + ) + except NotImplementedError: + variable_mapping = {} + + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + node_type=node_type, + node_data=node_instance.node_data, + ) + + # run node + generator = node_instance.run() + + return node_instance, generator + except Exception as e: + raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + + @staticmethod + def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: + return WorkflowEntry._handle_special_values(value) + + @staticmethod + def _handle_special_values(value: Any) -> Any: + if value is None: + return value + if isinstance(value, dict): + res = {} + for k, v in value.items(): + res[k] = WorkflowEntry._handle_special_values(v) + return res + if isinstance(value, list): + res = [] + for item in value: + res.append(WorkflowEntry._handle_special_values(item)) + return res + if isinstance(value, File): + return value.to_dict() + return value + + @classmethod + def mapping_user_inputs_to_variable_pool( + cls, + variable_mapping: Mapping[str, Sequence[str]], + user_inputs: dict, + variable_pool: VariablePool, + tenant_id: str, + node_type: NodeType, + node_data: BaseNodeData, + ) -> None: + for node_variable, variable_selector in variable_mapping.items(): + # fetch node id and variable key from node_variable + node_variable_list = node_variable.split(".") + if len(node_variable_list) < 1: + raise ValueError(f"Invalid node variable {node_variable}") + + node_variable_key = ".".join(node_variable_list[1:]) + + if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get( + 
variable_selector + ): + raise ValueError(f"Variable key {node_variable} not found in user inputs.") + + # fetch variable node id from variable selector + variable_node_id = variable_selector[0] + variable_key_list = variable_selector[1:] + variable_key_list = cast(list[str], variable_key_list) + + # get input value + input_value = user_inputs.get(node_variable) + if not input_value: + input_value = user_inputs.get(node_variable_key) + + # FIXME: temp fix for image type + if node_type == NodeType.LLM: + new_value = [] + if isinstance(input_value, list): + node_data = cast(LLMNodeData, node_data) + + detail = node_data.vision.configs.detail if node_data.vision.configs else None + + for item in input_value: + if isinstance(item, dict) and "type" in item and item["type"] == "image": + transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) + mapping = { + "id": item.get("id"), + "transfer_method": transfer_method, + "upload_file_id": item.get("upload_file_id"), + "url": item.get("url"), + } + config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None) + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + ) + new_value.append(file) + + if new_value: + input_value = new_value + + # append variable and value to variable pool + variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 9cf5c505d138af..1edc558676747a 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -20,11 +20,11 @@ if [[ "${MODE}" == "worker" ]]; then CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}" fi - exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel INFO \ + exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL} \ -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace,app_deletion} elif [[ "${MODE}" == "beat" ]]; then - exec celery -A app.celery beat --loglevel INFO + exec celery -A app.celery beat --loglevel ${LOG_LEVEL} else if [[ "${DEBUG}" == "true" ]]; then exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 7ee7146d0983cf..1d6ad35333c014 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -3,8 +3,8 @@ from .create_document_index import handle from .create_installed_app_when_app_created import handle from .create_site_record_when_app_created import handle -from .deduct_quota_when_messaeg_created import handle +from .deduct_quota_when_message_created import handle from .delete_tool_parameters_cache_when_sync_draft_workflow import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle -from .update_provider_last_used_at_when_messaeg_created import handle +from .update_provider_last_used_at_when_message_created import handle diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 72a135e73d4ca5..5af45e1e5026df 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -5,7 +5,7 @@ import click from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from 
core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db from models.dataset import Document @@ -14,7 +14,7 @@ @document_index_created.connect def handle(sender, **kwargs): dataset_id = sender - document_ids = kwargs.get("document_ids", None) + document_ids = kwargs.get("document_ids") documents = [] start_at = time.perf_counter() for document_id in document_ids: @@ -43,7 +43,7 @@ def handle(sender, **kwargs): indexing_runner.run(documents) end_at = time.perf_counter() logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) - except DocumentIsPausedException as ex: + except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index abaf0e41ec30e6..1515661b2d45b8 100644 --- a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -11,11 +11,14 @@ def handle(sender, **kwargs): site = Site( app_id=app.id, title=app.name, + icon_type=app.icon_type, icon=app.icon, icon_background=app.icon_background, default_language=account.interface_language, customize_token_strategy="not_allow", code=Site.generate_code(16), + created_by=app.created_by, + updated_by=app.updated_by, ) db.session.add(site) diff --git a/api/events/event_handlers/deduct_quota_when_messaeg_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py similarity index 100% rename from api/events/event_handlers/deduct_quota_when_messaeg_created.py rename to api/events/event_handlers/deduct_quota_when_message_created.py diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index f96bb5ef74b62e..9c5955c8c5a1a5 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,6 +1,6 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 59375b1a0b1a81..de7c0f4dfeb74f 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -18,8 +18,7 @@ def handle(sender, **kwargs): added_dataset_ids = dataset_ids else: old_dataset_ids = set() - for app_dataset_join in app_dataset_joins: - old_dataset_ids.add(app_dataset_join.dataset_id) + old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids removed_dataset_ids = old_dataset_ids - dataset_ids diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py 
b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 333b85ecb2907a..453395e8d7dc1c 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,6 +1,6 @@ from typing import cast -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db @@ -22,8 +22,7 @@ def handle(sender, **kwargs): added_dataset_ids = dataset_ids else: old_dataset_ids = set() - for app_dataset_join in app_dataset_joins: - old_dataset_ids.add(app_dataset_join.dataset_id) + old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) added_dataset_ids = dataset_ids - old_dataset_ids removed_dataset_ids = old_dataset_ids - dataset_ids diff --git a/api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py similarity index 100% rename from api/events/event_handlers/update_provider_last_used_at_when_messaeg_created.py rename to api/events/event_handlers/update_provider_last_used_at_when_message_created.py diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index f5ec7c1759cb9d..c5de7395b8cd29 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,8 +1,12 @@ from datetime import timedelta +import pytz from celery import Celery, Task +from celery.schedules import crontab from flask import Flask +from configs import dify_config + def init_app(app: Flask) -> Celery: class FlaskTask(Task): @@ -10,11 +14,21 @@ def __call__(self, *args: object, **kwargs: object) -> object: with app.app_context(): return self.run(*args, **kwargs) + broker_transport_options = {} + + if dify_config.CELERY_USE_SENTINEL: + broker_transport_options = { + "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, + "sentinel_kwargs": { + "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, + }, + } + celery_app = Celery( app.name, task_cls=FlaskTask, - broker=app.config["CELERY_BROKER_URL"], - backend=app.config["CELERY_BACKEND"], + broker=dify_config.CELERY_BROKER_URL, + backend=dify_config.CELERY_BACKEND, task_ignore_result=True, ) @@ -27,11 +41,17 @@ def __call__(self, *args: object, **kwargs: object) -> object: } celery_app.conf.update( - result_backend=app.config["CELERY_RESULT_BACKEND"], + result_backend=dify_config.CELERY_RESULT_BACKEND, + broker_transport_options=broker_transport_options, broker_connection_retry_on_startup=True, + worker_log_format=dify_config.LOG_FORMAT, + worker_task_log_format=dify_config.LOG_FORMAT, + worker_logfile=dify_config.LOG_FILE, + worker_hijack_root_logger=False, + timezone=pytz.timezone(dify_config.LOG_TZ), ) - if app.config["BROKER_USE_SSL"]: + if dify_config.BROKER_USE_SSL: celery_app.conf.update( broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration ) @@ -42,8 +62,10 @@ def __call__(self, *args: object, **kwargs: object) -> object: imports = [ "schedule.clean_embedding_cache_task", "schedule.clean_unused_datasets_task", + "schedule.create_tidb_serverless_task", + "schedule.update_tidb_serverless_status_task", ] - day = app.config["CELERY_BEAT_SCHEDULER_TIME"] + day = 
dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { "clean_embedding_cache_task": { "task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task", @@ -53,6 +75,14 @@ def __call__(self, *args: object, **kwargs: object) -> object: "task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task", "schedule": timedelta(days=day), }, + "create_tidb_serverless_task": { + "task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task", + "schedule": crontab(minute="0", hour="*"), + }, + "update_tidb_serverless_status_task": { + "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task", + "schedule": crontab(minute="30", hour="*"), + }, } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 38e67749fcc994..a6de28597bc08a 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -1,8 +1,10 @@ from flask import Flask +from configs import dify_config + def init_app(app: Flask): - if app.config.get("API_COMPRESSION_ENABLED"): + if dify_config.API_COMPRESSION_ENABLED: from flask_compress import Compress app.config["COMPRESS_MIMETYPES"] = [ diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py new file mode 100644 index 00000000000000..56b1d6bd28ba90 --- /dev/null +++ b/api/extensions/ext_logging.py @@ -0,0 +1,45 @@ +import logging +import os +import sys +from logging.handlers import RotatingFileHandler + +from flask import Flask + +from configs import dify_config + + +def init_app(app: Flask): + log_handlers = None + log_file = dify_config.LOG_FILE + if log_file: + log_dir = os.path.dirname(log_file) + os.makedirs(log_dir, exist_ok=True) + log_handlers = [ + RotatingFileHandler( + filename=log_file, + maxBytes=dify_config.LOG_FILE_MAX_SIZE * 1024 * 1024, + backupCount=dify_config.LOG_FILE_BACKUP_COUNT, + ), + logging.StreamHandler(sys.stdout), + ] + + logging.basicConfig( + level=dify_config.LOG_LEVEL, + format=dify_config.LOG_FORMAT, + datefmt=dify_config.LOG_DATEFORMAT, + handlers=log_handlers, + force=True, + ) + log_tz = dify_config.LOG_TZ + if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + + for handler in logging.root.handlers: + handler.formatter.converter = time_converter diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index b435294abc23ba..5c5b331d8ab95f 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -4,6 +4,8 @@ import resend from flask import Flask +from configs import dify_config + class Mail: def __init__(self): @@ -14,41 +16,44 @@ def is_inited(self) -> bool: return self._client is not None def init_app(self, app: Flask): - if app.config.get("MAIL_TYPE"): - if app.config.get("MAIL_DEFAULT_SEND_FROM"): - self._default_send_from = app.config.get("MAIL_DEFAULT_SEND_FROM") + mail_type = dify_config.MAIL_TYPE + if not mail_type: + logging.warning("MAIL_TYPE is not set") + return - if app.config.get("MAIL_TYPE") == "resend": - api_key = app.config.get("RESEND_API_KEY") + if dify_config.MAIL_DEFAULT_SEND_FROM: + self._default_send_from = dify_config.MAIL_DEFAULT_SEND_FROM + + match mail_type: + case "resend": + api_key = dify_config.RESEND_API_KEY if not api_key: raise ValueError("RESEND_API_KEY is not set") - api_url = app.config.get("RESEND_API_URL") + api_url = 
dify_config.RESEND_API_URL if api_url: resend.api_url = api_url resend.api_key = api_key self._client = resend.Emails - elif app.config.get("MAIL_TYPE") == "smtp": + case "smtp": from libs.smtp import SMTPClient - if not app.config.get("SMTP_SERVER") or not app.config.get("SMTP_PORT"): + if not dify_config.SMTP_SERVER or not dify_config.SMTP_PORT: raise ValueError("SMTP_SERVER and SMTP_PORT are required for smtp mail type") - if not app.config.get("SMTP_USE_TLS") and app.config.get("SMTP_OPPORTUNISTIC_TLS"): + if not dify_config.SMTP_USE_TLS and dify_config.SMTP_OPPORTUNISTIC_TLS: raise ValueError("SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS") self._client = SMTPClient( - server=app.config.get("SMTP_SERVER"), - port=app.config.get("SMTP_PORT"), - username=app.config.get("SMTP_USERNAME"), - password=app.config.get("SMTP_PASSWORD"), - _from=app.config.get("MAIL_DEFAULT_SEND_FROM"), - use_tls=app.config.get("SMTP_USE_TLS"), - opportunistic_tls=app.config.get("SMTP_OPPORTUNISTIC_TLS"), + server=dify_config.SMTP_SERVER, + port=dify_config.SMTP_PORT, + username=dify_config.SMTP_USERNAME, + password=dify_config.SMTP_PASSWORD, + _from=dify_config.MAIL_DEFAULT_SEND_FROM, + use_tls=dify_config.SMTP_USE_TLS, + opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, ) - else: - raise ValueError("Unsupported mail type {}".format(app.config.get("MAIL_TYPE"))) - else: - logging.warning("MAIL_TYPE is not set") + case _: + raise ValueError("Unsupported mail type {}".format(mail_type)) def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): if not self._client: diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py new file mode 100644 index 00000000000000..c106a4384a156f --- /dev/null +++ b/api/extensions/ext_proxy_fix.py @@ -0,0 +1,10 @@ +from flask import Flask + +from configs import dify_config + + +def init_app(app: Flask): + if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: + from werkzeug.middleware.proxy_fix import ProxyFix + + app.wsgi_app = ProxyFix(app.wsgi_app) diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index d5fb162fd8f2fb..e1f8409f2190fe 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,26 +1,85 @@ import redis from redis.connection import Connection, SSLConnection +from redis.sentinel import Sentinel -redis_client = redis.Redis() +from configs import dify_config + + +class RedisClientWrapper(redis.Redis): + """ + A wrapper class for the Redis client that addresses the issue where the global + `redis_client` variable cannot be updated when a new Redis instance is returned + by Sentinel. + + This class allows for deferred initialization of the Redis client, enabling the + client to be re-initialized with a new instance when necessary. This is particularly + useful in scenarios where the Redis instance may change dynamically, such as during + a failover in a Sentinel-managed Redis setup. + + Attributes: + _client (redis.Redis): The actual Redis client instance. It remains None until + initialized with the `initialize` method. + + Methods: + initialize(client): Initializes the Redis client if it hasn't been initialized already. + __getattr__(item): Delegates attribute access to the Redis client, raising an error + if the client is not initialized. 
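The deferred-initialization idea this docstring describes can be sketched without a live Redis. The proxy below mirrors only the initialize()/__getattr__ pair from RedisClientWrapper, and FakeRedis is a hypothetical stand-in for the master client that Sentinel.master_for() would return, so the snippet runs on its own.

class ClientProxy:
    # Importable at module load time, like the module-level redis_client;
    # bound to a concrete client later, once configuration is known.
    def __init__(self):
        self._client = None

    def initialize(self, client):
        if self._client is None:
            self._client = client

    def __getattr__(self, item):
        # Only invoked for attributes not found normally, so lookups of
        # _client itself never recurse here once __init__ has run.
        if self._client is None:
            raise RuntimeError("Redis client is not initialized. Call init_app first.")
        return getattr(self._client, item)


class FakeRedis:
    # Hypothetical stand-in for the client returned by Sentinel.master_for().
    def ping(self):
        return True


redis_client = ClientProxy()
redis_client.initialize(FakeRedis())
assert redis_client.ping()  # attribute access is delegated to the wrapped client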
+ """ + + def __init__(self): + self._client = None + + def initialize(self, client): + if self._client is None: + self._client = client + + def __getattr__(self, item): + if self._client is None: + raise RuntimeError("Redis client is not initialized. Call init_app first.") + return getattr(self._client, item) + + +redis_client = RedisClientWrapper() def init_app(app): + global redis_client connection_class = Connection - if app.config.get("REDIS_USE_SSL"): + if dify_config.REDIS_USE_SSL: connection_class = SSLConnection - redis_client.connection_pool = redis.ConnectionPool( - **{ - "host": app.config.get("REDIS_HOST"), - "port": app.config.get("REDIS_PORT"), - "username": app.config.get("REDIS_USERNAME"), - "password": app.config.get("REDIS_PASSWORD"), - "db": app.config.get("REDIS_DB"), - "encoding": "utf-8", - "encoding_errors": "strict", - "decode_responses": False, - }, - connection_class=connection_class, - ) + redis_params = { + "username": dify_config.REDIS_USERNAME, + "password": dify_config.REDIS_PASSWORD, + "db": dify_config.REDIS_DB, + "encoding": "utf-8", + "encoding_errors": "strict", + "decode_responses": False, + } + + if dify_config.REDIS_USE_SENTINEL: + sentinel_hosts = [ + (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") + ] + sentinel = Sentinel( + sentinel_hosts, + sentinel_kwargs={ + "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, + "username": dify_config.REDIS_SENTINEL_USERNAME, + "password": dify_config.REDIS_SENTINEL_PASSWORD, + }, + ) + master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) + redis_client.initialize(master) + else: + redis_params.update( + { + "host": dify_config.REDIS_HOST, + "port": dify_config.REDIS_PORT, + "connection_class": connection_class, + } + ) + pool = redis.ConnectionPool(**redis_params) + redis_client.initialize(redis.Redis(connection_pool=pool)) app.extensions["redis"] = redis_client diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 227c6635f0eb11..11f1dd93c6a670 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -1,17 +1,38 @@ +import openai import sentry_sdk +from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException +from configs import dify_config +from core.model_runtime.errors.invoke import InvokeRateLimitError + + +def before_send(event, hint): + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if parse_error.defaultErrorResponse in str(exc_value): + return None + + return event + def init_app(app): - if app.config.get("SENTRY_DSN"): + if dify_config.SENTRY_DSN: sentry_sdk.init( - dsn=app.config.get("SENTRY_DSN"), + dsn=dify_config.SENTRY_DSN, integrations=[FlaskIntegration(), CeleryIntegration()], - ignore_errors=[HTTPException, ValueError], - traces_sample_rate=app.config.get("SENTRY_TRACES_SAMPLE_RATE", 1.0), - profiles_sample_rate=app.config.get("SENTRY_PROFILES_SAMPLE_RATE", 1.0), - environment=app.config.get("DEPLOY_ENV"), - release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}", + ignore_errors=[ + HTTPException, + ValueError, + openai.APIStatusError, + InvokeRateLimitError, + parse_error.defaultErrorResponse, + ], + traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE, + profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE, + environment=dify_config.DEPLOY_ENV, + 
release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}", + before_send=before_send, ) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index e6c4352577fc3f..86fadf23d787f3 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -1,15 +1,12 @@ +import logging from collections.abc import Generator from typing import Union from flask import Flask -from extensions.storage.aliyun_storage import AliyunStorage -from extensions.storage.azure_storage import AzureStorage -from extensions.storage.google_storage import GoogleStorage -from extensions.storage.local_storage import LocalStorage -from extensions.storage.oci_storage import OCIStorage -from extensions.storage.s3_storage import S3Storage -from extensions.storage.tencent_storage import TencentStorage +from configs import dify_config +from extensions.storage.base_storage import BaseStorage +from extensions.storage.storage_type import StorageType class Storage: @@ -17,45 +14,109 @@ def __init__(self): self.storage_runner = None def init_app(self, app: Flask): - storage_type = app.config.get("STORAGE_TYPE") - if storage_type == "s3": - self.storage_runner = S3Storage(app=app) - elif storage_type == "azure-blob": - self.storage_runner = AzureStorage(app=app) - elif storage_type == "aliyun-oss": - self.storage_runner = AliyunStorage(app=app) - elif storage_type == "google-storage": - self.storage_runner = GoogleStorage(app=app) - elif storage_type == "tencent-cos": - self.storage_runner = TencentStorage(app=app) - elif storage_type == "oci-storage": - self.storage_runner = OCIStorage(app=app) - else: - self.storage_runner = LocalStorage(app=app) + storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE) + with app.app_context(): + self.storage_runner = storage_factory() - def save(self, filename, data): - self.storage_runner.save(filename, data) + @staticmethod + def get_storage_factory(storage_type: str) -> type[BaseStorage]: + match storage_type: + case StorageType.S3: + from extensions.storage.aws_s3_storage import AwsS3Storage + + return AwsS3Storage + case StorageType.AZURE_BLOB: + from extensions.storage.azure_blob_storage import AzureBlobStorage + + return AzureBlobStorage + case StorageType.ALIYUN_OSS: + from extensions.storage.aliyun_oss_storage import AliyunOssStorage + + return AliyunOssStorage + case StorageType.GOOGLE_STORAGE: + from extensions.storage.google_cloud_storage import GoogleCloudStorage + + return GoogleCloudStorage + case StorageType.TENCENT_COS: + from extensions.storage.tencent_cos_storage import TencentCosStorage + + return TencentCosStorage + case StorageType.OCI_STORAGE: + from extensions.storage.oracle_oci_storage import OracleOCIStorage + + return OracleOCIStorage + case StorageType.HUAWEI_OBS: + from extensions.storage.huawei_obs_storage import HuaweiObsStorage - def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]: - if stream: - return self.load_stream(filename) - else: - return self.load_once(filename) + return HuaweiObsStorage + case StorageType.BAIDU_OBS: + from extensions.storage.baidu_obs_storage import BaiduObsStorage + + return BaiduObsStorage + case StorageType.VOLCENGINE_TOS: + from extensions.storage.volcengine_tos_storage import VolcengineTosStorage + + return VolcengineTosStorage + case StorageType.SUPBASE: + from extensions.storage.supabase_storage import SupabaseStorage + + return SupabaseStorage + case StorageType.LOCAL | _: + from extensions.storage.local_fs_storage import LocalFsStorage + + 
return LocalFsStorage + + def save(self, filename, data): + try: + self.storage_runner.save(filename, data) + except Exception as e: + logging.exception("Failed to save file: %s", e) + raise e + + def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: + try: + if stream: + return self.load_stream(filename) + else: + return self.load_once(filename) + except Exception as e: + logging.exception("Failed to load file: %s", e) + raise e def load_once(self, filename: str) -> bytes: - return self.storage_runner.load_once(filename) + try: + return self.storage_runner.load_once(filename) + except Exception as e: + logging.exception("Failed to load_once file: %s", e) + raise e def load_stream(self, filename: str) -> Generator: - return self.storage_runner.load_stream(filename) + try: + return self.storage_runner.load_stream(filename) + except Exception as e: + logging.exception("Failed to load_stream file: %s", e) + raise e def download(self, filename, target_filepath): - self.storage_runner.download(filename, target_filepath) + try: + self.storage_runner.download(filename, target_filepath) + except Exception as e: + logging.exception("Failed to download file: %s", e) + raise e def exists(self, filename): - return self.storage_runner.exists(filename) + try: + return self.storage_runner.exists(filename) + except Exception as e: + logging.exception("Failed to check file exists: %s", e) + raise e def delete(self, filename): - return self.storage_runner.delete(filename) + try: + return self.storage_runner.delete(filename) + except Exception as e: + logging.exception("Failed to delete file: %s", e) + raise e storage = Storage() diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py new file mode 100644 index 00000000000000..58c917dbd386bc --- /dev/null +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -0,0 +1,54 @@ +import posixpath +from collections.abc import Generator + +import oss2 as aliyun_s3 + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class AliyunOssStorage(BaseStorage): + """Implementation for Aliyun OSS storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.ALIYUN_OSS_BUCKET_NAME + self.folder = dify_config.ALIYUN_OSS_PATH + oss_auth_method = aliyun_s3.Auth + region = None + if dify_config.ALIYUN_OSS_AUTH_VERSION == "v4": + oss_auth_method = aliyun_s3.AuthV4 + region = dify_config.ALIYUN_OSS_REGION + oss_auth = oss_auth_method(dify_config.ALIYUN_OSS_ACCESS_KEY, dify_config.ALIYUN_OSS_SECRET_KEY) + self.client = aliyun_s3.Bucket( + oss_auth, + dify_config.ALIYUN_OSS_ENDPOINT, + self.bucket_name, + connect_timeout=30, + region=region, + ) + + def save(self, filename, data): + self.client.put_object(self.__wrapper_folder_filename(filename), data) + + def load_once(self, filename: str) -> bytes: + obj = self.client.get_object(self.__wrapper_folder_filename(filename)) + data = obj.read() + return data + + def load_stream(self, filename: str) -> Generator: + obj = self.client.get_object(self.__wrapper_folder_filename(filename)) + while chunk := obj.read(4096): + yield chunk + + def download(self, filename, target_filepath): + self.client.get_object_to_file(self.__wrapper_folder_filename(filename), target_filepath) + + def exists(self, filename): + return self.client.object_exists(self.__wrapper_folder_filename(filename)) + + def delete(self, filename): + self.client.delete_object(self.__wrapper_folder_filename(filename)) + + def 
__wrapper_folder_filename(self, filename) -> str: + return posixpath.join(self.folder, filename) if self.folder else filename diff --git a/api/extensions/storage/aliyun_storage.py b/api/extensions/storage/aliyun_storage.py deleted file mode 100644 index b962cedc55178d..00000000000000 --- a/api/extensions/storage/aliyun_storage.py +++ /dev/null @@ -1,55 +0,0 @@ -from collections.abc import Generator -from contextlib import closing - -import oss2 as aliyun_s3 -from flask import Flask - -from extensions.storage.base_storage import BaseStorage - - -class AliyunStorage(BaseStorage): - """Implementation for aliyun storage.""" - - def __init__(self, app: Flask): - super().__init__(app) - - app_config = self.app.config - self.bucket_name = app_config.get("ALIYUN_OSS_BUCKET_NAME") - oss_auth_method = aliyun_s3.Auth - region = None - if app_config.get("ALIYUN_OSS_AUTH_VERSION") == "v4": - oss_auth_method = aliyun_s3.AuthV4 - region = app_config.get("ALIYUN_OSS_REGION") - oss_auth = oss_auth_method(app_config.get("ALIYUN_OSS_ACCESS_KEY"), app_config.get("ALIYUN_OSS_SECRET_KEY")) - self.client = aliyun_s3.Bucket( - oss_auth, - app_config.get("ALIYUN_OSS_ENDPOINT"), - self.bucket_name, - connect_timeout=30, - region=region, - ) - - def save(self, filename, data): - self.client.put_object(filename, data) - - def load_once(self, filename: str) -> bytes: - with closing(self.client.get_object(filename)) as obj: - data = obj.read() - return data - - def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - with closing(self.client.get_object(filename)) as obj: - while chunk := obj.read(4096): - yield chunk - - return generate() - - def download(self, filename, target_filepath): - self.client.get_object_to_file(filename, target_filepath) - - def exists(self, filename): - return self.client.object_exists(filename) - - def delete(self, filename): - self.client.delete_object(filename) diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py new file mode 100644 index 00000000000000..ab2d0fba3b19f3 --- /dev/null +++ b/api/extensions/storage/aws_s3_storage.py @@ -0,0 +1,85 @@ +import logging +from collections.abc import Generator + +import boto3 +from botocore.client import Config +from botocore.exceptions import ClientError + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + +logger = logging.getLogger(__name__) + + +class AwsS3Storage(BaseStorage): + """Implementation for Amazon Web Services S3 storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.S3_BUCKET_NAME + if dify_config.S3_USE_AWS_MANAGED_IAM: + logger.info("Using AWS managed IAM role for S3") + + session = boto3.Session() + region_name = dify_config.S3_REGION + self.client = session.client(service_name="s3", region_name=region_name) + else: + logger.info("Using ak and sk for S3") + + self.client = boto3.client( + "s3", + aws_secret_access_key=dify_config.S3_SECRET_KEY, + aws_access_key_id=dify_config.S3_ACCESS_KEY, + endpoint_url=dify_config.S3_ENDPOINT, + region_name=dify_config.S3_REGION, + config=Config(s3={"addressing_style": dify_config.S3_ADDRESS_STYLE}), + ) + # create bucket + try: + self.client.head_bucket(Bucket=self.bucket_name) + except ClientError as e: + # if bucket not exists, create it + if e.response["Error"]["Code"] == "404": + self.client.create_bucket(Bucket=self.bucket_name) + # if bucket is not accessible, pass, maybe the bucket is existing but not accessible + 
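+                # head_bucket surfaces errors as ClientError status codes: "404"
+                # means the bucket is missing (so it is created), "403" means it
+                # exists but these credentials cannot access it; anything else
+                # is re-raised below.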
elif e.response["Error"]["Code"] == "403": + pass + else: + # other error, raise exception + raise + + def save(self, filename, data): + self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) + + def load_once(self, filename: str) -> bytes: + try: + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise + return data + + def load_stream(self, filename: str) -> Generator: + try: + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise + + def download(self, filename, target_filepath): + self.client.download_file(self.bucket_name, filename, target_filepath) + + def exists(self, filename): + try: + self.client.head_object(Bucket=self.bucket_name, Key=filename) + return True + except: + return False + + def delete(self, filename): + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py new file mode 100644 index 00000000000000..11a7544274d452 --- /dev/null +++ b/api/extensions/storage/azure_blob_storage.py @@ -0,0 +1,73 @@ +from collections.abc import Generator +from datetime import datetime, timedelta, timezone + +from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas + +from configs import dify_config +from extensions.ext_redis import redis_client +from extensions.storage.base_storage import BaseStorage + + +class AzureBlobStorage(BaseStorage): + """Implementation for Azure Blob storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.AZURE_BLOB_CONTAINER_NAME + self.account_url = dify_config.AZURE_BLOB_ACCOUNT_URL + self.account_name = dify_config.AZURE_BLOB_ACCOUNT_NAME + self.account_key = dify_config.AZURE_BLOB_ACCOUNT_KEY + + def save(self, filename, data): + client = self._sync_client() + blob_container = client.get_container_client(container=self.bucket_name) + blob_container.upload_blob(filename, data) + + def load_once(self, filename: str) -> bytes: + client = self._sync_client() + blob = client.get_container_client(container=self.bucket_name) + blob = blob.get_blob_client(blob=filename) + data = blob.download_blob().readall() + return data + + def load_stream(self, filename: str) -> Generator: + client = self._sync_client() + blob = client.get_blob_client(container=self.bucket_name, blob=filename) + blob_data = blob.download_blob() + yield from blob_data.chunks() + + def download(self, filename, target_filepath): + client = self._sync_client() + + blob = client.get_blob_client(container=self.bucket_name, blob=filename) + with open(target_filepath, "wb") as my_blob: + blob_data = blob.download_blob() + blob_data.readinto(my_blob) + + def exists(self, filename): + client = self._sync_client() + + blob = client.get_blob_client(container=self.bucket_name, blob=filename) + return blob.exists() + + def delete(self, filename): + client = self._sync_client() + + blob_container = client.get_container_client(container=self.bucket_name) + blob_container.delete_blob(filename) + + def _sync_client(self): + cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) + cache_result = 
redis_client.get(cache_key) + if cache_result is not None: + sas_token = cache_result.decode("utf-8") + else: + sas_token = generate_account_sas( + account_name=self.account_name, + account_key=self.account_key, + resource_types=ResourceTypes(service=True, container=True, object=True), + permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), + expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), + ) + redis_client.set(cache_key, sas_token, ex=3000) + return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/azure_storage.py b/api/extensions/storage/azure_storage.py deleted file mode 100644 index ca8cbb9188b5c9..00000000000000 --- a/api/extensions/storage/azure_storage.py +++ /dev/null @@ -1,78 +0,0 @@ -from collections.abc import Generator -from datetime import datetime, timedelta, timezone - -from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas -from flask import Flask - -from extensions.ext_redis import redis_client -from extensions.storage.base_storage import BaseStorage - - -class AzureStorage(BaseStorage): - """Implementation for azure storage.""" - - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("AZURE_BLOB_CONTAINER_NAME") - self.account_url = app_config.get("AZURE_BLOB_ACCOUNT_URL") - self.account_name = app_config.get("AZURE_BLOB_ACCOUNT_NAME") - self.account_key = app_config.get("AZURE_BLOB_ACCOUNT_KEY") - - def save(self, filename, data): - client = self._sync_client() - blob_container = client.get_container_client(container=self.bucket_name) - blob_container.upload_blob(filename, data) - - def load_once(self, filename: str) -> bytes: - client = self._sync_client() - blob = client.get_container_client(container=self.bucket_name) - blob = blob.get_blob_client(blob=filename) - data = blob.download_blob().readall() - return data - - def load_stream(self, filename: str) -> Generator: - client = self._sync_client() - - def generate(filename: str = filename) -> Generator: - blob = client.get_blob_client(container=self.bucket_name, blob=filename) - blob_data = blob.download_blob() - yield from blob_data.chunks() - - return generate(filename) - - def download(self, filename, target_filepath): - client = self._sync_client() - - blob = client.get_blob_client(container=self.bucket_name, blob=filename) - with open(target_filepath, "wb") as my_blob: - blob_data = blob.download_blob() - blob_data.readinto(my_blob) - - def exists(self, filename): - client = self._sync_client() - - blob = client.get_blob_client(container=self.bucket_name, blob=filename) - return blob.exists() - - def delete(self, filename): - client = self._sync_client() - - blob_container = client.get_container_client(container=self.bucket_name) - blob_container.delete_blob(filename) - - def _sync_client(self): - cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) - cache_result = redis_client.get(cache_key) - if cache_result is not None: - sas_token = cache_result.decode("utf-8") - else: - sas_token = generate_account_sas( - account_name=self.account_name, - account_key=self.account_key, - resource_types=ResourceTypes(service=True, container=True, object=True), - permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + 
timedelta(hours=1), - ) - redis_client.set(cache_key, sas_token, ex=3000) - return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/baidu_obs_storage.py b/api/extensions/storage/baidu_obs_storage.py new file mode 100644 index 00000000000000..e0d2140e91272c --- /dev/null +++ b/api/extensions/storage/baidu_obs_storage.py @@ -0,0 +1,56 @@ +import base64 +import hashlib +from collections.abc import Generator + +from baidubce.auth.bce_credentials import BceCredentials +from baidubce.bce_client_configuration import BceClientConfiguration +from baidubce.services.bos.bos_client import BosClient + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class BaiduObsStorage(BaseStorage): + """Implementation for Baidu OBS storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.BAIDU_OBS_BUCKET_NAME + client_config = BceClientConfiguration( + credentials=BceCredentials( + access_key_id=dify_config.BAIDU_OBS_ACCESS_KEY, + secret_access_key=dify_config.BAIDU_OBS_SECRET_KEY, + ), + endpoint=dify_config.BAIDU_OBS_ENDPOINT, + ) + + self.client = BosClient(config=client_config) + + def save(self, filename, data): + md5 = hashlib.md5() + md5.update(data) + content_md5 = base64.standard_b64encode(md5.digest()) + self.client.put_object( + bucket_name=self.bucket_name, key=filename, data=data, content_length=len(data), content_md5=content_md5 + ) + + def load_once(self, filename: str) -> bytes: + response = self.client.get_object(bucket_name=self.bucket_name, key=filename) + return response.data.read() + + def load_stream(self, filename: str) -> Generator: + response = self.client.get_object(bucket_name=self.bucket_name, key=filename).data + while chunk := response.read(4096): + yield chunk + + def download(self, filename, target_filepath): + self.client.get_object_to_file(bucket_name=self.bucket_name, key=filename, file_name=target_filepath) + + def exists(self, filename): + res = self.client.get_object_meta_data(bucket_name=self.bucket_name, key=filename) + if res is None: + return False + return True + + def delete(self, filename): + self.client.delete_object(bucket_name=self.bucket_name, key=filename) diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index c3fe9ec82a5b41..50abab8537ffa4 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -3,16 +3,12 @@ from abc import ABC, abstractmethod from collections.abc import Generator -from flask import Flask - class BaseStorage(ABC): """Interface for file storage.""" - app = None - - def __init__(self, app: Flask): - self.app = app + def __init__(self): # noqa: B027 + pass @abstractmethod def save(self, filename, data): diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py new file mode 100644 index 00000000000000..26b662d2f04daf --- /dev/null +++ b/api/extensions/storage/google_cloud_storage.py @@ -0,0 +1,60 @@ +import base64 +import io +import json +from collections.abc import Generator + +from google.cloud import storage as google_cloud_storage + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class GoogleCloudStorage(BaseStorage): + """Implementation for Google Cloud storage.""" + + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.GOOGLE_STORAGE_BUCKET_NAME + service_account_json_str = 
dify_config.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 + # if service_account_json_str is empty, use Application Default Credentials + if service_account_json_str: + service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") + # convert str to object + service_account_obj = json.loads(service_account_json) + self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj) + else: + self.client = google_cloud_storage.Client() + + def save(self, filename, data): + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.blob(filename) + with io.BytesIO(data) as stream: + blob.upload_from_file(stream) + + def load_once(self, filename: str) -> bytes: + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.get_blob(filename) + data = blob.download_as_bytes() + return data + + def load_stream(self, filename: str) -> Generator: + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.get_blob(filename) + with blob.open(mode="rb") as blob_stream: + while chunk := blob_stream.read(4096): + yield chunk + + def download(self, filename, target_filepath): + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.get_blob(filename) + blob.download_to_filename(target_filepath) + + def exists(self, filename): + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.blob(filename) + return blob.exists() + + def delete(self, filename): + bucket = self.client.get_bucket(self.bucket_name) + bucket.delete_blob(filename) diff --git a/api/extensions/storage/google_storage.py b/api/extensions/storage/google_storage.py deleted file mode 100644 index 9ed1fcf0b4e118..00000000000000 --- a/api/extensions/storage/google_storage.py +++ /dev/null @@ -1,64 +0,0 @@ -import base64 -import io -import json -from collections.abc import Generator -from contextlib import closing - -from flask import Flask -from google.cloud import storage as GoogleCloudStorage - -from extensions.storage.base_storage import BaseStorage - - -class GoogleStorage(BaseStorage): - """Implementation for google storage.""" - - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("GOOGLE_STORAGE_BUCKET_NAME") - service_account_json_str = app_config.get("GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64") - # if service_account_json_str is empty, use Application Default Credentials - if service_account_json_str: - service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") - # convert str to object - service_account_obj = json.loads(service_account_json) - self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) - else: - self.client = GoogleCloudStorage.Client() - - def save(self, filename, data): - bucket = self.client.get_bucket(self.bucket_name) - blob = bucket.blob(filename) - with io.BytesIO(data) as stream: - blob.upload_from_file(stream) - - def load_once(self, filename: str) -> bytes: - bucket = self.client.get_bucket(self.bucket_name) - blob = bucket.get_blob(filename) - data = blob.download_as_bytes() - return data - - def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - bucket = self.client.get_bucket(self.bucket_name) - blob = bucket.get_blob(filename) - with closing(blob.open(mode="rb")) as blob_stream: - while chunk := blob_stream.read(4096): - yield chunk - - return generate() - - def download(self, filename, target_filepath): - bucket = 
self.client.get_bucket(self.bucket_name) - blob = bucket.get_blob(filename) - blob.download_to_filename(target_filepath) - - def exists(self, filename): - bucket = self.client.get_bucket(self.bucket_name) - blob = bucket.blob(filename) - return blob.exists() - - def delete(self, filename): - bucket = self.client.get_bucket(self.bucket_name) - bucket.delete_blob(filename) diff --git a/api/extensions/storage/huawei_obs_storage.py b/api/extensions/storage/huawei_obs_storage.py new file mode 100644 index 00000000000000..20be70ef83dd7a --- /dev/null +++ b/api/extensions/storage/huawei_obs_storage.py @@ -0,0 +1,51 @@ +from collections.abc import Generator + +from obs import ObsClient + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class HuaweiObsStorage(BaseStorage): + """Implementation for Huawei OBS storage.""" + + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.HUAWEI_OBS_BUCKET_NAME + self.client = ObsClient( + access_key_id=dify_config.HUAWEI_OBS_ACCESS_KEY, + secret_access_key=dify_config.HUAWEI_OBS_SECRET_KEY, + server=dify_config.HUAWEI_OBS_SERVER, + ) + + def save(self, filename, data): + self.client.putObject(bucketName=self.bucket_name, objectKey=filename, content=data) + + def load_once(self, filename: str) -> bytes: + data = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response.read() + return data + + def load_stream(self, filename: str) -> Generator: + response = self.client.getObject(bucketName=self.bucket_name, objectKey=filename)["body"].response + while chunk := response.read(4096): + yield chunk + + def download(self, filename, target_filepath): + self.client.getObject(bucketName=self.bucket_name, objectKey=filename, downloadPath=target_filepath) + + def exists(self, filename): + res = self._get_meta(filename) + if res is None: + return False + return True + + def delete(self, filename): + self.client.deleteObject(bucketName=self.bucket_name, objectKey=filename) + + def _get_meta(self, filename): + res = self.client.getObjectMetadata(bucketName=self.bucket_name, objectKey=filename) + if res.status < 300: + return res + else: + return None diff --git a/api/extensions/storage/local_fs_storage.py b/api/extensions/storage/local_fs_storage.py new file mode 100644 index 00000000000000..5a495ca4d41042 --- /dev/null +++ b/api/extensions/storage/local_fs_storage.py @@ -0,0 +1,62 @@ +import os +import shutil +from collections.abc import Generator +from pathlib import Path + +from flask import current_app + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class LocalFsStorage(BaseStorage): + """Implementation for local filesystem storage.""" + + def __init__(self): + super().__init__() + folder = dify_config.STORAGE_LOCAL_PATH + if not os.path.isabs(folder): + folder = os.path.join(current_app.root_path, folder) + self.folder = folder + + def _build_filepath(self, filename: str) -> str: + """Build the full file path based on the folder and filename.""" + if not self.folder or self.folder.endswith("/"): + return self.folder + filename + else: + return self.folder + "/" + filename + + def save(self, filename, data): + filepath = self._build_filepath(filename) + folder = os.path.dirname(filepath) + os.makedirs(folder, exist_ok=True) + Path(os.path.join(os.getcwd(), filepath)).write_bytes(data) + + def load_once(self, filename: str) -> bytes: + filepath = self._build_filepath(filename) + if not os.path.exists(filepath): + raise 
FileNotFoundError("File not found") + return Path(filepath).read_bytes() + + def load_stream(self, filename: str) -> Generator: + filepath = self._build_filepath(filename) + if not os.path.exists(filepath): + raise FileNotFoundError("File not found") + with open(filepath, "rb") as f: + while chunk := f.read(4096): # Read in chunks of 4KB + yield chunk + + def download(self, filename, target_filepath): + filepath = self._build_filepath(filename) + if not os.path.exists(filepath): + raise FileNotFoundError("File not found") + shutil.copyfile(filepath, target_filepath) + + def exists(self, filename): + filepath = self._build_filepath(filename) + return os.path.exists(filepath) + + def delete(self, filename): + filepath = self._build_filepath(filename) + if os.path.exists(filepath): + os.remove(filepath) diff --git a/api/extensions/storage/local_storage.py b/api/extensions/storage/local_storage.py deleted file mode 100644 index 46ee4bf80f8e6d..00000000000000 --- a/api/extensions/storage/local_storage.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -import shutil -from collections.abc import Generator - -from flask import Flask - -from extensions.storage.base_storage import BaseStorage - - -class LocalStorage(BaseStorage): - """Implementation for local storage.""" - - def __init__(self, app: Flask): - super().__init__(app) - folder = self.app.config.get("STORAGE_LOCAL_PATH") - if not os.path.isabs(folder): - folder = os.path.join(app.root_path, folder) - self.folder = folder - - def save(self, filename, data): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - folder = os.path.dirname(filename) - os.makedirs(folder, exist_ok=True) - - with open(os.path.join(os.getcwd(), filename), "wb") as f: - f.write(data) - - def load_once(self, filename: str) -> bytes: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - if not os.path.exists(filename): - raise FileNotFoundError("File not found") - - with open(filename, "rb") as f: - data = f.read() - - return data - - def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - if not os.path.exists(filename): - raise FileNotFoundError("File not found") - - with open(filename, "rb") as f: - while chunk := f.read(4096): # Read in chunks of 4KB - yield chunk - - return generate() - - def download(self, filename, target_filepath): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - if not os.path.exists(filename): - raise FileNotFoundError("File not found") - - shutil.copyfile(filename, target_filepath) - - def exists(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - - return os.path.exists(filename) - - def delete(self, filename): - if not self.folder or self.folder.endswith("/"): - filename = self.folder + filename - else: - filename = self.folder + "/" + filename - if os.path.exists(filename): - os.remove(filename) diff --git a/api/extensions/storage/oci_storage.py b/api/extensions/storage/oci_storage.py deleted file mode 100644 index e32fa0a0ae78a9..00000000000000 --- a/api/extensions/storage/oci_storage.py 
+++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Generator -from contextlib import closing - -import boto3 -from botocore.exceptions import ClientError -from flask import Flask - -from extensions.storage.base_storage import BaseStorage - - -class OCIStorage(BaseStorage): - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("OCI_BUCKET_NAME") - self.client = boto3.client( - "s3", - aws_secret_access_key=app_config.get("OCI_SECRET_KEY"), - aws_access_key_id=app_config.get("OCI_ACCESS_KEY"), - endpoint_url=app_config.get("OCI_ENDPOINT"), - region_name=app_config.get("OCI_REGION"), - ) - - def save(self, filename, data): - self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) - - def load_once(self, filename: str) -> bytes: - try: - with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() - except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") - else: - raise - return data - - def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - try: - with closing(self.client) as client: - response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].iter_chunks() - except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") - else: - raise - - return generate() - - def download(self, filename, target_filepath): - with closing(self.client) as client: - client.download_file(self.bucket_name, filename, target_filepath) - - def exists(self, filename): - with closing(self.client) as client: - try: - client.head_object(Bucket=self.bucket_name, Key=filename) - return True - except: - return False - - def delete(self, filename): - self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py new file mode 100644 index 00000000000000..b59f83b8de90bf --- /dev/null +++ b/api/extensions/storage/oracle_oci_storage.py @@ -0,0 +1,59 @@ +from collections.abc import Generator + +import boto3 +from botocore.exceptions import ClientError + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class OracleOCIStorage(BaseStorage): + """Implementation for Oracle OCI storage.""" + + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.OCI_BUCKET_NAME + self.client = boto3.client( + "s3", + aws_secret_access_key=dify_config.OCI_SECRET_KEY, + aws_access_key_id=dify_config.OCI_ACCESS_KEY, + endpoint_url=dify_config.OCI_ENDPOINT, + region_name=dify_config.OCI_REGION, + ) + + def save(self, filename, data): + self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) + + def load_once(self, filename: str) -> bytes: + try: + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise + return data + + def load_stream(self, filename: str) -> Generator: + try: + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() + except ClientError as ex: + if ex.response["Error"]["Code"] == "NoSuchKey": + raise FileNotFoundError("File not found") + else: + raise + + def 
download(self, filename, target_filepath): + self.client.download_file(self.bucket_name, filename, target_filepath) + + def exists(self, filename): + try: + self.client.head_object(Bucket=self.bucket_name, Key=filename) + return True + except: + return False + + def delete(self, filename): + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/s3_storage.py b/api/extensions/storage/s3_storage.py deleted file mode 100644 index 022ce5b14a7b88..00000000000000 --- a/api/extensions/storage/s3_storage.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Generator -from contextlib import closing - -import boto3 -from botocore.client import Config -from botocore.exceptions import ClientError -from flask import Flask - -from extensions.storage.base_storage import BaseStorage - - -class S3Storage(BaseStorage): - """Implementation for s3 storage.""" - - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("S3_BUCKET_NAME") - if app_config.get("S3_USE_AWS_MANAGED_IAM"): - session = boto3.Session() - self.client = session.client("s3") - else: - self.client = boto3.client( - "s3", - aws_secret_access_key=app_config.get("S3_SECRET_KEY"), - aws_access_key_id=app_config.get("S3_ACCESS_KEY"), - endpoint_url=app_config.get("S3_ENDPOINT"), - region_name=app_config.get("S3_REGION"), - config=Config(s3={"addressing_style": app_config.get("S3_ADDRESS_STYLE")}), - ) - - def save(self, filename, data): - self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data) - - def load_once(self, filename: str) -> bytes: - try: - with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() - except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") - else: - raise - return data - - def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - try: - with closing(self.client) as client: - response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].iter_chunks() - except ClientError as ex: - if ex.response["Error"]["Code"] == "NoSuchKey": - raise FileNotFoundError("File not found") - else: - raise - - return generate() - - def download(self, filename, target_filepath): - with closing(self.client) as client: - client.download_file(self.bucket_name, filename, target_filepath) - - def exists(self, filename): - with closing(self.client) as client: - try: - client.head_object(Bucket=self.bucket_name, Key=filename) - return True - except: - return False - - def delete(self, filename): - self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py new file mode 100644 index 00000000000000..415bf251f6e280 --- /dev/null +++ b/api/extensions/storage/storage_type.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class StorageType(str, Enum): + ALIYUN_OSS = "aliyun-oss" + AZURE_BLOB = "azure-blob" + BAIDU_OBS = "baidu-obs" + GOOGLE_STORAGE = "google-storage" + HUAWEI_OBS = "huawei-obs" + LOCAL = "local" + OCI_STORAGE = "oci-storage" + S3 = "s3" + TENCENT_COS = "tencent-cos" + VOLCENGINE_TOS = "volcengine-tos" + SUPBASE = "supabase" diff --git a/api/extensions/storage/supabase_storage.py b/api/extensions/storage/supabase_storage.py new file mode 100644 index 00000000000000..9f7c69a9ae6312 --- /dev/null +++ 
b/api/extensions/storage/supabase_storage.py @@ -0,0 +1,59 @@ +import io +from collections.abc import Generator +from pathlib import Path + +from supabase import Client + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class SupabaseStorage(BaseStorage): + """Implementation for Supabase storage.""" + + def __init__(self): + super().__init__() + if dify_config.SUPABASE_URL is None: + raise ValueError("SUPABASE_URL is not set") + if dify_config.SUPABASE_API_KEY is None: + raise ValueError("SUPABASE_API_KEY is not set") + if dify_config.SUPABASE_BUCKET_NAME is None: + raise ValueError("SUPABASE_BUCKET_NAME is not set") + + self.bucket_name = dify_config.SUPABASE_BUCKET_NAME + self.client = Client(supabase_url=dify_config.SUPABASE_URL, supabase_key=dify_config.SUPABASE_API_KEY) + self.create_bucket(id=dify_config.SUPABASE_BUCKET_NAME, bucket_name=dify_config.SUPABASE_BUCKET_NAME) + + def create_bucket(self, id, bucket_name): + if not self.bucket_exists(): + self.client.storage.create_bucket(id=id, name=bucket_name) + + def save(self, filename, data): + self.client.storage.from_(self.bucket_name).upload(filename, data) + + def load_once(self, filename: str) -> bytes: + content = self.client.storage.from_(self.bucket_name).download(filename) + return content + + def load_stream(self, filename: str) -> Generator: + result = self.client.storage.from_(self.bucket_name).download(filename) + byte_stream = io.BytesIO(result) + while chunk := byte_stream.read(4096): # Read in chunks of 4KB + yield chunk + + def download(self, filename, target_filepath): + result = self.client.storage.from_(self.bucket_name).download(filename) + Path(target_filepath).write_bytes(result) + + def exists(self, filename): + result = self.client.storage.from_(self.bucket_name).list(filename) + if len(result) > 0: + return True + return False + + def delete(self, filename): + self.client.storage.from_(self.bucket_name).remove(filename) + + def bucket_exists(self): + buckets = self.client.storage.list_buckets() + return any(bucket.name == self.bucket_name for bucket in buckets) diff --git a/api/extensions/storage/tencent_cos_storage.py b/api/extensions/storage/tencent_cos_storage.py new file mode 100644 index 00000000000000..13a6c9239c2d1e --- /dev/null +++ b/api/extensions/storage/tencent_cos_storage.py @@ -0,0 +1,43 @@ +from collections.abc import Generator + +from qcloud_cos import CosConfig, CosS3Client + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class TencentCosStorage(BaseStorage): + """Implementation for Tencent Cloud COS storage.""" + + def __init__(self): + super().__init__() + + self.bucket_name = dify_config.TENCENT_COS_BUCKET_NAME + config = CosConfig( + Region=dify_config.TENCENT_COS_REGION, + SecretId=dify_config.TENCENT_COS_SECRET_ID, + SecretKey=dify_config.TENCENT_COS_SECRET_KEY, + Scheme=dify_config.TENCENT_COS_SCHEME, + ) + self.client = CosS3Client(config) + + def save(self, filename, data): + self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) + + def load_once(self, filename: str) -> bytes: + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() + return data + + def load_stream(self, filename: str) -> Generator: + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].get_stream(chunk_size=4096) + + def download(self, filename, target_filepath): + response = 
self.client.get_object(Bucket=self.bucket_name, Key=filename) + response["Body"].get_stream_to_file(target_filepath) + + def exists(self, filename): + return self.client.object_exists(Bucket=self.bucket_name, Key=filename) + + def delete(self, filename): + self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/tencent_storage.py b/api/extensions/storage/tencent_storage.py deleted file mode 100644 index 1d499cd3bcea1c..00000000000000 --- a/api/extensions/storage/tencent_storage.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Generator - -from flask import Flask -from qcloud_cos import CosConfig, CosS3Client - -from extensions.storage.base_storage import BaseStorage - - -class TencentStorage(BaseStorage): - """Implementation for tencent cos storage.""" - - def __init__(self, app: Flask): - super().__init__(app) - app_config = self.app.config - self.bucket_name = app_config.get("TENCENT_COS_BUCKET_NAME") - config = CosConfig( - Region=app_config.get("TENCENT_COS_REGION"), - SecretId=app_config.get("TENCENT_COS_SECRET_ID"), - SecretKey=app_config.get("TENCENT_COS_SECRET_KEY"), - Scheme=app_config.get("TENCENT_COS_SCHEME"), - ) - self.client = CosS3Client(config) - - def save(self, filename, data): - self.client.put_object(Bucket=self.bucket_name, Body=data, Key=filename) - - def load_once(self, filename: str) -> bytes: - data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].get_raw_stream().read() - return data - - def load_stream(self, filename: str) -> Generator: - def generate(filename: str = filename) -> Generator: - response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].get_stream(chunk_size=4096) - - return generate() - - def download(self, filename, target_filepath): - response = self.client.get_object(Bucket=self.bucket_name, Key=filename) - response["Body"].get_stream_to_file(target_filepath) - - def exists(self, filename): - return self.client.object_exists(Bucket=self.bucket_name, Key=filename) - - def delete(self, filename): - self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/volcengine_tos_storage.py b/api/extensions/storage/volcengine_tos_storage.py new file mode 100644 index 00000000000000..de82be04ea87b7 --- /dev/null +++ b/api/extensions/storage/volcengine_tos_storage.py @@ -0,0 +1,44 @@ +from collections.abc import Generator + +import tos + +from configs import dify_config +from extensions.storage.base_storage import BaseStorage + + +class VolcengineTosStorage(BaseStorage): + """Implementation for Volcengine TOS storage.""" + + def __init__(self): + super().__init__() + self.bucket_name = dify_config.VOLCENGINE_TOS_BUCKET_NAME + self.client = tos.TosClientV2( + ak=dify_config.VOLCENGINE_TOS_ACCESS_KEY, + sk=dify_config.VOLCENGINE_TOS_SECRET_KEY, + endpoint=dify_config.VOLCENGINE_TOS_ENDPOINT, + region=dify_config.VOLCENGINE_TOS_REGION, + ) + + def save(self, filename, data): + self.client.put_object(bucket=self.bucket_name, key=filename, content=data) + + def load_once(self, filename: str) -> bytes: + data = self.client.get_object(bucket=self.bucket_name, key=filename).read() + return data + + def load_stream(self, filename: str) -> Generator: + response = self.client.get_object(bucket=self.bucket_name, key=filename) + while chunk := response.read(4096): + yield chunk + + def download(self, filename, target_filepath): + self.client.get_object_to_file(bucket=self.bucket_name, key=filename, 
file_path=target_filepath) + + def exists(self, filename): + res = self.client.head_object(bucket=self.bucket_name, key=filename) + if res.status_code != 200: + return False + return True + + def delete(self, filename): + self.client.delete_object(bucket=self.bucket_name, key=filename) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py new file mode 100644 index 00000000000000..738b2b3478f46a --- /dev/null +++ b/api/factories/file_factory.py @@ -0,0 +1,231 @@ +import mimetypes +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +import httpx +from sqlalchemy import select + +from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig +from core.helper import ssrf_proxy +from extensions.ext_database import db +from models import MessageFile, ToolFile, UploadFile + + +def build_from_message_files( + *, + message_files: Sequence["MessageFile"], + tenant_id: str, + config: FileUploadConfig, +) -> Sequence[File]: + results = [ + build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) + for file in message_files + if file.belongs_to != FileBelongsTo.ASSISTANT + ] + return results + + +def build_from_message_file( + *, + message_file: "MessageFile", + tenant_id: str, + config: FileUploadConfig, +): + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "id": message_file.id, + "type": message_file.type, + "upload_file_id": message_file.upload_file_id, + } + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + ) + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileUploadConfig | None = None, +) -> File: + config = config or FileUploadConfig() + + transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) + + build_functions: dict[FileTransferMethod, Callable] = { + FileTransferMethod.LOCAL_FILE: _build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + } + + build_func = build_functions.get(transfer_method) + if not build_func: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + ) + + if not _is_file_valid_with_config(file=file, config=config): + raise ValueError(f"File validation failed for file: {file.filename}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileUploadConfig | None, + tenant_id: str, +) -> Sequence[File]: + if not config: + return [] + + files = [ + build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + ) + for mapping in mappings + ] + + if ( + # If image config is set. 
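+            # (image_config.number_limits bounds image files only; the separate
+            # config.number_limits check below bounds the total file count)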
+ config.image_config + # And the number of image files exceeds the maximum limit + and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, +) -> File: + file_type = FileType.value_of(mapping.get("type")) + stmt = select(UploadFile).where( + UploadFile.id == mapping.get("upload_file_id"), + UploadFile.tenant_id == tenant_id, + ) + + row = db.session.scalar(stmt) + + if row is None: + raise ValueError("Invalid upload file") + + return File( + id=mapping.get("id"), + filename=row.name, + extension="." + row.extension, + mime_type=row.mime_type, + tenant_id=tenant_id, + type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + related_id=mapping.get("upload_file_id"), + size=row.size, + ) + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, +) -> File: + url = mapping.get("url") + if not url: + raise ValueError("Invalid file url") + + mime_type, filename, file_size = _get_remote_file_info(url) + extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") + + return File( + id=mapping.get("id"), + filename=filename, + tenant_id=tenant_id, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=url, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + + +def _get_remote_file_info(url: str): + mime_type = mimetypes.guess_type(url)[0] or "" + file_size = -1 + filename = url.split("/")[-1].split("?")[0] or "unknown_file" + + resp = ssrf_proxy.head(url, follow_redirects=True) + if resp.status_code == httpx.codes.OK: + if content_disposition := resp.headers.get("Content-Disposition"): + filename = str(content_disposition.split("filename=")[-1].strip('"')) + file_size = int(resp.headers.get("Content-Length", file_size)) + mime_type = mime_type or str(resp.headers.get("Content-Type", "")) + + return mime_type, filename, file_size + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, +) -> File: + tool_file = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == mapping.get("tool_file_id"), + ToolFile.tenant_id == tenant_id, + ) + .first() + ) + + if tool_file is None: + raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") + + extension = "." + tool_file.file_key.split(".")[-1] if "." 
in tool_file.file_key else ".bin" + + return File( + id=mapping.get("id"), + tenant_id=tenant_id, + filename=tool_file.name, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=tool_file.original_url, + related_id=tool_file.id, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + ) + + +def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: + if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: + return False + + if config.allowed_extensions and file.extension not in config.allowed_extensions: + return False + + if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods: + return False + + if file.type == FileType.IMAGE and config.image_config: + if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: + return False + + return True diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py new file mode 100644 index 00000000000000..0191102b902cc9 --- /dev/null +++ b/api/factories/variable_factory.py @@ -0,0 +1,98 @@ +from collections.abc import Mapping +from typing import Any + +from configs import dify_config +from core.file import File +from core.variables import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayNumberVariable, + ArrayObjectSegment, + ArrayObjectVariable, + ArraySegment, + ArrayStringSegment, + ArrayStringVariable, + FileSegment, + FloatSegment, + FloatVariable, + IntegerSegment, + IntegerVariable, + NoneSegment, + ObjectSegment, + ObjectVariable, + SecretVariable, + Segment, + SegmentType, + StringSegment, + StringVariable, + Variable, +) +from core.variables.exc import VariableError + + +def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: + if (value_type := mapping.get("value_type")) is None: + raise VariableError("missing value type") + if not mapping.get("name"): + raise VariableError("missing name") + if (value := mapping.get("value")) is None: + raise VariableError("missing value") + match value_type: + case SegmentType.STRING: + result = StringVariable.model_validate(mapping) + case SegmentType.SECRET: + result = SecretVariable.model_validate(mapping) + case SegmentType.NUMBER if isinstance(value, int): + result = IntegerVariable.model_validate(mapping) + case SegmentType.NUMBER if isinstance(value, float): + result = FloatVariable.model_validate(mapping) + case SegmentType.NUMBER if not isinstance(value, float | int): + raise VariableError(f"invalid number value {value}") + case SegmentType.OBJECT if isinstance(value, dict): + result = ObjectVariable.model_validate(mapping) + case SegmentType.ARRAY_STRING if isinstance(value, list): + result = ArrayStringVariable.model_validate(mapping) + case SegmentType.ARRAY_NUMBER if isinstance(value, list): + result = ArrayNumberVariable.model_validate(mapping) + case SegmentType.ARRAY_OBJECT if isinstance(value, list): + result = ArrayObjectVariable.model_validate(mapping) + case _: + raise VariableError(f"not supported value type {value_type}") + if result.size > dify_config.MAX_VARIABLE_SIZE: + raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") + return result + + +def build_segment(value: Any, /) -> Segment: + if value is None: + return NoneSegment() + if isinstance(value, str): + return StringSegment(value=value) + if isinstance(value, int): + return 
IntegerSegment(value=value) + if isinstance(value, float): + return FloatSegment(value=value) + if isinstance(value, dict): + return ObjectSegment(value=value) + if isinstance(value, File): + return FileSegment(value=value) + if isinstance(value, list): + items = [build_segment(item) for item in value] + types = {item.value_type for item in items} + if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): + return ArrayAnySegment(value=value) + match types.pop(): + case SegmentType.STRING: + return ArrayStringSegment(value=value) + case SegmentType.NUMBER: + return ArrayNumberSegment(value=value) + case SegmentType.OBJECT: + return ArrayObjectSegment(value=value) + case SegmentType.FILE: + return ArrayFileSegment(value=value) + case SegmentType.NONE: + return ArrayAnySegment(value=value) + case _: + raise ValueError(f"not supported value {value}") + raise ValueError(f"not supported value {value}") diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 7036d58e4ab3a2..aa353a3cc191dd 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -1,14 +1,17 @@ from flask_restful import fields -from libs.helper import TimestampField +from fields.workflow_fields import workflow_partial_fields +from libs.helper import AppIconUrlField, TimestampField app_detail_kernel_fields = { "id": fields.String, "name": fields.String, "description": fields.String, "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, "icon": fields.String, "icon_background": fields.String, + "icon_url": AppIconUrlField, } related_app_list = { @@ -37,7 +40,10 @@ "completion_prompt_config": fields.Raw(attribute="completion_prompt_config_dict"), "dataset_configs": fields.Raw(attribute="dataset_configs_dict"), "file_upload": fields.Raw(attribute="file_upload_dict"), + "created_by": fields.String, "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } app_detail_fields = { @@ -50,8 +56,13 @@ "enable_site": fields.Boolean, "enable_api": fields.Boolean, "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), "tracing": fields.Raw, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } prompt_config_fields = { @@ -61,6 +72,10 @@ model_config_partial_fields = { "model": fields.Raw(attribute="model_dict"), "pre_prompt": fields.String, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} @@ -71,10 +86,17 @@ "max_active_requests": fields.Raw(), "description": fields.String(attribute="desc_or_prompt"), "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, "icon": fields.String, "icon_background": fields.String, + "icon_url": AppIconUrlField, "model_config": fields.Nested(model_config_partial_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, "tags": fields.List(fields.Nested(tag_fields)), } @@ -104,8 +126,10 @@ "access_token": fields.String(attribute="code"), 
"code": fields.String, "title": fields.String, + "icon_type": fields.String, "icon": fields.String, "icon_background": fields.String, + "icon_url": AppIconUrlField, "description": fields.String, "default_language": fields.String, "chat_color_theme": fields.String, @@ -118,6 +142,11 @@ "prompt_public": fields.Boolean, "app_base_url": fields.String, "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, + "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, } app_detail_fields_with_site = { @@ -125,14 +154,21 @@ "name": fields.String, "description": fields.String, "mode": fields.String(attribute="mode_compatible_with_agent"), + "icon_type": fields.String, "icon": fields.String, "icon_background": fields.String, + "icon_url": AppIconUrlField, "enable_site": fields.Boolean, "enable_api": fields.Boolean, "model_config": fields.Nested(model_config_fields, attribute="app_model_config", allow_null=True), + "workflow": fields.Nested(workflow_partial_fields, allow_null=True), "site": fields.Nested(site_fields), "api_base_url": fields.String, + "use_icon_as_answer_icon": fields.Boolean, + "created_by": fields.String, "created_at": TimestampField, + "updated_by": fields.String, + "updated_at": TimestampField, "deleted_tools": fields.List(fields.String), } @@ -152,4 +188,5 @@ "customize_token_strategy": fields.String, "prompt_public": fields.Boolean, "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, } diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index 1b15fe38800b3e..5bd21be80779a4 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,6 +3,8 @@ from fields.member_fields import simple_account_fields from libs.helper import TimestampField +from .raws import FilesContainedField + class MessageTextField(fields.Raw): def format(self, value): @@ -33,8 +35,12 @@ def format(self, value): message_file_fields = { "id": fields.String, + "filename": fields.String, "type": fields.String, "url": fields.String, + "mime_type": fields.String, + "size": fields.Integer, + "transfer_method": fields.String, "belongs_to": fields.String(default="user"), } @@ -55,7 +61,7 @@ def format(self, value): message_detail_fields = { "id": fields.String, "conversation_id": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "message": fields.Raw, "message_tokens": fields.Integer, @@ -71,10 +77,11 @@ def format(self, value): "annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True), "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, + "parent_message_id": fields.String, } feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer} @@ -98,7 +105,7 @@ def format(self, value): } simple_message_detail_fields = { - "inputs": fields.Raw, + "inputs": FilesContainedField, "query": fields.String, "message": MessageTextField, "answer": fields.String, @@ -111,8 +118,10 @@ def format(self, value): "from_end_user_id": fields.String, "from_end_user_session_id": fields.String(), "from_account_id": fields.String, + "from_account_name": fields.String, 
"read_at": TimestampField, "created_at": TimestampField, + "updated_at": TimestampField, "annotation": fields.Nested(annotation_fields, allow_null=True), "model_config": fields.Nested(simple_model_config_fields), "user_feedback_stats": fields.Nested(feedback_stat_fields), @@ -146,10 +155,12 @@ def format(self, value): "from_end_user_id": fields.String, "from_end_user_session_id": fields.String, "from_account_id": fields.String, + "from_account_name": fields.String, "name": fields.String, "summary": fields.String(attribute="summary_or_query"), "read_at": TimestampField, "created_at": TimestampField, + "updated_at": TimestampField, "annotated": fields.Boolean, "model_config": fields.Nested(simple_model_config_fields), "message_count": fields.Integer, @@ -172,6 +183,7 @@ def format(self, value): "from_end_user_id": fields.String, "from_account_id": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, "annotated": fields.Boolean, "introduction": fields.String, "model_config": fields.Nested(model_config_fields), @@ -183,10 +195,15 @@ def format(self, value): simple_conversation_fields = { "id": fields.String, "name": fields.String, - "inputs": fields.Raw, + "inputs": FilesContainedField, "status": fields.String, "introduction": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, +} + +conversation_delete_fields = { + "result": fields.String, } conversation_infinite_scroll_pagination_fields = { diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 9cf8da7acdc984..b32423f10c9dd1 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -38,9 +38,20 @@ "score_threshold_enabled": fields.Boolean, "score_threshold": fields.Float, } +external_retrieval_model_fields = { + "top_k": fields.Integer, + "score_threshold": fields.Float, +} tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} +external_knowledge_info_fields = { + "external_knowledge_id": fields.String, + "external_knowledge_api_id": fields.String, + "external_knowledge_api_name": fields.String, + "external_knowledge_api_endpoint": fields.String, +} + dataset_detail_fields = { "id": fields.String, "name": fields.String, @@ -61,6 +72,8 @@ "embedding_available": fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), "tags": fields.List(fields.Nested(tag_fields)), + "external_knowledge_info": fields.Nested(external_knowledge_info_fields), + "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), } dataset_query_detail_fields = { diff --git a/api/fields/external_dataset_fields.py b/api/fields/external_dataset_fields.py new file mode 100644 index 00000000000000..2281460fe22146 --- /dev/null +++ b/api/fields/external_dataset_fields.py @@ -0,0 +1,11 @@ +from flask_restful import fields + +from libs.helper import TimestampField + +external_knowledge_api_query_detail_fields = { + "id": fields.String, + "name": fields.String, + "setting": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index e5a03ce77ed5f0..afaacc0568ea0c 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -6,6 +6,9 @@ "file_size_limit": fields.Integer, "batch_count_limit": fields.Integer, "image_file_size_limit": fields.Integer, + "video_file_size_limit": fields.Integer, + "audio_file_size_limit": fields.Integer, + "workflow_file_upload_limit": fields.Integer, } file_fields = 
{ @@ -17,3 +20,20 @@ "created_by": fields.String, "created_at": TimestampField, } + +remote_file_info_fields = { + "file_type": fields.String(attribute="file_type"), + "file_length": fields.Integer(attribute="file_length"), +} + + +file_fields_with_signed_url = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "url": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index b87cc653240a71..e0b3e340f67b8c 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -1,13 +1,16 @@ from flask_restful import fields -from libs.helper import TimestampField +from libs.helper import AppIconUrlField, TimestampField app_fields = { "id": fields.String, "name": fields.String, "mode": fields.String, + "icon_type": fields.String, "icon": fields.String, "icon_background": fields.String, + "icon_url": AppIconUrlField, + "use_icon_as_answer_icon": fields.Boolean, } installed_app_fields = { diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 3d2df87afb9b19..5f6e7884a69c5e 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -3,6 +3,8 @@ from fields.conversation_fields import message_file_fields from libs.helper import TimestampField +from .raws import FilesContainedField + feedback_fields = {"rating": fields.String} retriever_resource_fields = { @@ -62,14 +64,15 @@ message_fields = { "id": fields.String, "conversation_id": fields.String, - "inputs": fields.Raw, + "parent_message_id": fields.String, + "inputs": FilesContainedField, "query": fields.String, "answer": fields.String(attribute="re_sign_file_url_answer"), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), - "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "message_files": fields.List(fields.Nested(message_file_fields)), "status": fields.String, "error": fields.String, } diff --git a/api/fields/raws.py b/api/fields/raws.py new file mode 100644 index 00000000000000..15ec16ab13e4a8 --- /dev/null +++ b/api/fields/raws.py @@ -0,0 +1,17 @@ +from flask_restful import fields + +from core.file import File + + +class FilesContainedField(fields.Raw): + def format(self, value): + return self._format_file_object(value) + + def _format_file_object(self, v): + if isinstance(v, File): + return v.model_dump() + if isinstance(v, dict): + return {k: self._format_file_object(vv) for k, vv in v.items()} + if isinstance(v, list): + return [self._format_file_object(vv) for vv in v] + return v diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 240b8f2eb03e79..0d860d6f406502 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restful import fields -from core.app.segments import SecretVariable, SegmentType, Variable from core.helper import encrypter +from core.variables import SecretVariable, SegmentType, Variable from fields.member_fields import simple_account_fields from libs.helper import TimestampField @@ -53,3 +53,11 @@ def format(self, value): "environment_variables": fields.List(EnvironmentVariableField()), "conversation_variables": 
fields.List(fields.Nested(conversation_variable_fields)),
 }
+
+workflow_partial_fields = {
+    "id": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "updated_by": fields.String,
+    "updated_at": TimestampField,
+}
diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py
index 2d306edb40c745..83f9c74e339e17 100644
--- a/api/libs/gmpy2_pkcs10aep_cipher.py
+++ b/api/libs/gmpy2_pkcs10aep_cipher.py
@@ -31,7 +31,7 @@
 from Crypto.Util.strxor import strxor
 
 
-class PKCS1OAEP_Cipher:
+class PKCS1OAepCipher:
     """Cipher object for PKCS#1 v1.5 OAEP.
     Do not create directly: use :func:`new` instead."""
 
@@ -204,7 +204,8 @@ def decrypt(self, ciphertext):
 
 
 def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
-    """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
+    """Return a cipher object :class:`PKCS1OAepCipher`
+    that can be used to perform PKCS#1 OAEP encryption or decryption.
 
     :param key:
       The key object to use to encrypt or decrypt the message.
@@ -237,4 +238,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
     if randfunc is None:
         randfunc = Random.get_random_bytes
-    return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc)
+    return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc)
diff --git a/api/libs/helper.py b/api/libs/helper.py
index 6b584ddcc2abf5..763879650856de 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -12,10 +12,12 @@
 from typing import Any, Optional, Union
 from zoneinfo import available_timezones
 
-from flask import Response, current_app, stream_with_context
+from flask import Response, stream_with_context
 from flask_restful import fields
 
+from configs import dify_config
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
+from core.file import helpers as file_helpers
 from extensions.ext_redis import redis_client
 from models.account import Account
 
@@ -24,6 +26,18 @@ def run(script):
     return subprocess.getstatusoutput("source /root/.bashrc && " + script)
 
 
+class AppIconUrlField(fields.Raw):
+    def output(self, key, obj):
+        if obj is None:
+            return None
+
+        from models.model import IconType
+
+        if obj.icon_type == IconType.IMAGE.value:
+            return file_helpers.get_signed_file_url(obj.icon)
+        return None
+
+
 class TimestampField(fields.Raw):
     def format(self, value) -> int:
         return int(value.timestamp())
@@ -71,7 +85,7 @@ def timestamp_value(timestamp):
         raise ValueError(error)
 
 
-class str_len:
-    """Restrict input to an integer in a range (inclusive)"""
+class StrLen:
+    """Restrict an input string to a maximum length (inclusive)"""
 
     def __init__(self, max_length, argument="argument"):
@@ -89,7 +103,7 @@ def __call__(self, value):
         return value
 
 
-class float_range:
-    """Restrict input to an float in a range (inclusive)"""
+class FloatRange:
+    """Restrict input to a float in a range (inclusive)"""
 
     def __init__(self, low, high, argument="argument"):
@@ -108,7 +122,7 @@ def __call__(self, value):
         return value
 
 
-class datetime_string:
+class DatetimeString:
     def __init__(self, format, argument="argument"):
         self.format = format
         self.argument = argument
@@ -149,7 +163,7 @@ def generate_string(n):
     return result
 
 
-def get_remote_ip(request) -> str:
+def extract_remote_ip(request) -> str:
     if request.headers.get("CF-Connecting-IP"):
         return request.headers.get("Cf-Connecting-Ip")
     elif request.headers.getlist("X-Forwarded-For"):
@@ -176,23 +190,39 @@ def generate() -> Generator:
 
 class TokenManager:
     @classmethod
-    def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
-        old_token =
cls._get_current_token_for_account(account.id, token_type)
-        if old_token:
-            if isinstance(old_token, bytes):
-                old_token = old_token.decode("utf-8")
-            cls.revoke_token(old_token, token_type)
+    def generate_token(
+        cls,
+        token_type: str,
+        account: Optional[Account] = None,
+        email: Optional[str] = None,
+        additional_data: Optional[dict] = None,
+    ) -> str:
+        if account is None and email is None:
+            raise ValueError("Account or email must be provided")
+
+        account_id = account.id if account else None
+        account_email = account.email if account else email
+
+        if account_id:
+            old_token = cls._get_current_token_for_account(account_id, token_type)
+            if old_token:
+                if isinstance(old_token, bytes):
+                    old_token = old_token.decode("utf-8")
+                cls.revoke_token(old_token, token_type)
 
         token = str(uuid.uuid4())
-        token_data = {"account_id": account.id, "email": account.email, "token_type": token_type}
+        token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
         if additional_data:
             token_data.update(additional_data)
 
-        expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"]
+        expiry_minutes = dify_config.model_dump().get(f"{token_type.upper()}_TOKEN_EXPIRY_MINUTES")
         token_key = cls._get_token_key(token, token_type)
-        redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
+        expiry_time = int(expiry_minutes * 60)
+        redis_client.setex(token_key, expiry_time, json.dumps(token_data))
+
+        if account_id:
+            cls._set_current_token_for_account(account_id, token, token_type, expiry_minutes)
 
-        cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
         return token
 
     @classmethod
@@ -221,9 +251,12 @@ def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Opt
         return current_token
 
     @classmethod
-    def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int):
+    def _set_current_token_for_account(
+        cls, account_id: str, token: str, token_type: str, expiry_minutes: Union[int, float]
+    ):
         key = cls._get_account_token_key(account_id, token_type)
-        redis_client.setex(key, expiry_hours * 60 * 60, token)
+        expiry_time = int(expiry_minutes * 60)
+        redis_client.setex(key, expiry_time, token)
 
     @classmethod
     def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
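Two behavioural notes on the `TokenManager` changes: tokens can now be issued from an email address alone, and expiry is configured in minutes (`*_TOKEN_EXPIRY_MINUTES`) rather than hours, so the per-account current-token key must be scaled by `minutes * 60` as well (the helper above previously multiplied the minutes value by 3600). A hedged usage sketch; the token type, config key, and Redis setup are illustrative assumptions, not taken from this diff:

```python
# Assumes dify_config defines RESET_PASSWORD_TOKEN_EXPIRY_MINUTES and that
# redis_client is reachable; "reset_password" is an illustrative token type.
token = TokenManager.generate_token(
    token_type="reset_password",
    email="user@example.com",  # email-only flow: no Account object required
    additional_data={"source": "forgot-password-form"},  # hypothetical payload
)
# Both the token payload and the account's current-token marker now expire
# after RESET_PASSWORD_TOKEN_EXPIRY_MINUTES * 60 seconds.
```

diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py
index 41d69058992f4b..41c5d20c4b08b9 100644
--- a/api/libs/json_in_md_parser.py
+++ b/api/libs/json_in_md_parser.py
@@ -1,28 +1,31 @@
 import json
 
-from core.llm_generator.output_parser.errors import OutputParserException
+from core.llm_generator.output_parser.errors import OutputParserError
 
 
 def parse_json_markdown(json_string: str) -> dict:
-    # Remove the triple backticks if present
+    # Get json from the backticks/braces
     json_string = json_string.strip()
-    start_index = json_string.find("```json")
-    end_index = json_string.find("```", start_index + len("```json"))
-
-    if start_index != -1 and end_index != -1:
-        extracted_content = json_string[start_index + len("```json") : end_index].strip()
-
-        # Parse the JSON string into a Python dictionary
-        parsed = json.loads(extracted_content)
-    elif start_index != -1 and end_index == -1 and json_string.endswith("``"):
-        end_index = json_string.find("``", start_index + len("```json"))
-        extracted_content = json_string[start_index + len("```json") : end_index].strip()
-
-        # Parse the JSON string into a Python dictionary
+    starts = ["```json", "```", "``", "`", "{"]
+    ends = ["```",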
"``", "`", "}"] + end_index = -1 + start_index = 0 + for s in starts: + start_index = json_string.find(s) + if start_index != -1: + if json_string[start_index] != "{": + start_index += len(s) + break + if start_index != -1: + for e in ends: + end_index = json_string.rfind(e, start_index) + if end_index != -1: + if json_string[end_index] == "}": + end_index += 1 + break + if start_index != -1 and end_index != -1 and start_index < end_index: + extracted_content = json_string[start_index:end_index].strip() parsed = json.loads(extracted_content) - elif json_string.startswith("{"): - # Parse the JSON string into a Python dictionary - parsed = json.loads(json_string) else: raise Exception("Could not find JSON block in the output.") @@ -33,10 +36,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserException(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"Got invalid JSON object. Error: {e}") for key in expected_keys: if key not in json_obj: - raise OutputParserException( - f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" + raise OutputParserError( + f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}" ) return json_obj diff --git a/api/libs/login.py b/api/libs/login.py index 7f05eb8404a0ba..0ea191a185785d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,4 +1,3 @@ -import os from functools import wraps from flask import current_app, g, has_request_context, request @@ -7,6 +6,7 @@ from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy +from configs import dify_config from extensions.ext_database import db from models.account import Account, Tenant, TenantAccountJoin @@ -52,8 +52,7 @@ def post(): @wraps(func) def decorated_view(*args, **kwargs): auth_header = request.headers.get("Authorization") - admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False") - if admin_api_key_enable.lower() == "true": + if dify_config.ADMIN_API_KEY_ENABLE: if auth_header: if " " not in auth_header: raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") @@ -61,10 +60,10 @@ def decorated_view(*args, **kwargs): auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") - admin_api_key = os.getenv("ADMIN_API_KEY") + admin_api_key = dify_config.ADMIN_API_KEY if admin_api_key: - if os.getenv("ADMIN_API_KEY") == auth_token: + if admin_api_key == auth_token: workspace_id = request.headers.get("X-WORKSPACE-ID") if workspace_id: tenant_account_join = ( @@ -82,7 +81,7 @@ def decorated_view(*args, **kwargs): account.current_tenant = tenant current_app.login_manager._update_request_context_with_user(account) user_logged_in.send(current_app._get_current_object(), user=_get_user()) - if request.method in EXEMPT_METHODS or current_app.config.get("LOGIN_DISABLED"): + if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif not current_user.is_authenticated: return current_app.login_manager.unauthorized() diff --git a/api/libs/oauth.py b/api/libs/oauth.py index d8ce1a1e6633e6..6b6919de24f90f 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,5 +1,6 @@ import urllib.parse from dataclasses import dataclass +from typing import Optional import requests @@ -40,12 +41,14 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self): + def get_authorization_url(self, invite_token: Optional[str] = None): params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, "scope": "user:email", # Request only basic user information } + if invite_token: + params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): @@ -90,13 +93,15 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self): + def get_authorization_url(self, invite_token: Optional[str] = None): params = { "client_id": self.client_id, "response_type": "code", "redirect_uri": self.redirect_uri, "scope": "openid email", } + if invite_token: + params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 6da1a6d39bfd57..e747ea97ada4b2 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,3 +1,4 @@ +import datetime import urllib.parse import requests @@ -69,6 +70,7 @@ def get_access_token(self, code: str): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False + data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -104,6 +106,7 @@ def save_internal_access_token(self, access_token: str): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False + data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -138,6 +141,7 @@ def sync_data_source(self, binding_id: str): } data_source_binding.source_info = new_source_info data_source_binding.disabled = False + data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() else: raise ValueError("Data source binding not found") @@ -158,7 +162,7 @@ def get_authorized_pages(self, access_token: str): page_icon = page_result["icon"] if page_icon: icon_type = 
page_icon["type"]
-                    if icon_type == "external" or icon_type == "file":
+                    if icon_type in {"external", "file"}:
                         url = page_icon[icon_type]["url"]
                         icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
                     else:
@@ -191,7 +195,7 @@ def get_authorized_pages(self, access_token: str):
                 page_icon = database_result["icon"]
                 if page_icon:
                     icon_type = page_icon["type"]
-                    if icon_type == "external" or icon_type == "file":
+                    if icon_type in {"external", "file"}:
                         url = page_icon[icon_type]["url"]
                         icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
                     else:
diff --git a/api/libs/password.py b/api/libs/password.py
index cfcc0db22dded6..cdf55c57e5bc6e 100644
--- a/api/libs/password.py
+++ b/api/libs/password.py
@@ -13,7 +13,7 @@ def valid_password(password):
     if re.match(pattern, password) is not None:
         return password
 
-    raise ValueError("Not a valid password.")
+    raise ValueError("Password must contain letters and numbers, and be at least 8 characters long.")
 
 
 def hash_password(password_str, salt_byte):
diff --git a/api/libs/rsa.py b/api/libs/rsa.py
index a578bf3e5617f6..637bcc4a1dda61 100644
--- a/api/libs/rsa.py
+++ b/api/libs/rsa.py
@@ -4,9 +4,9 @@
 from Crypto.PublicKey import RSA
 from Crypto.Random import get_random_bytes
 
-import libs.gmpy2_pkcs10aep_cipher as gmpy2_pkcs10aep_cipher
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
+from libs import gmpy2_pkcs10aep_cipher
 
 
 def generate_key_pair(tenant_id):
diff --git a/api/libs/smtp.py b/api/libs/smtp.py
index bd7de7dd689a7a..d57d99f3b72217 100644
--- a/api/libs/smtp.py
+++ b/api/libs/smtp.py
@@ -39,13 +39,13 @@ def send(self, mail: dict):
 
             smtp.sendmail(self._from, mail["to"], msg.as_string())
         except smtplib.SMTPException as e:
-            logging.error(f"SMTP error occurred: {str(e)}")
+            logging.exception(f"SMTP error occurred: {str(e)}")
             raise
         except TimeoutError as e:
-            logging.error(f"Timeout occurred while sending email: {str(e)}")
+            logging.exception(f"Timeout occurred while sending email: {str(e)}")
             raise
         except Exception as e:
-            logging.error(f"Unexpected error occurred while sending email: {str(e)}")
+            logging.exception(f"Unexpected error occurred while sending email: {str(e)}")
             raise
         finally:
             if smtp:
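The reworded `valid_password` message above now states the actual rule. A hedged sketch of the check it describes; the exact regex lives in api/libs/password.py and is an assumption here, since the diff does not show it:

```python
import re

# Assumed pattern: letters and digits both required, minimum 8 characters.
pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"
assert re.match(pattern, "abc12345")          # valid: letters + digits, 8 chars
assert re.match(pattern, "abcdefgh") is None  # rejected: no digit
assert re.match(pattern, "a1") is None        # rejected: shorter than 8
```

diff --git a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py
index 0fba6a87ebab58..8cd4ec552b4ea9 100644
--- a/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py
+++ b/api/migrations/versions/03f98355ba0e_add_workflow_tool_label_and_tool_.py
@@ -24,6 +24,7 @@ def upgrade():
     with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op:
         batch_op.add_column(sa.Column('label', sa.String(length=255), server_default='', nullable=False))
 
+
 def downgrade():
     with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op:
         batch_op.drop_column('label')
diff --git a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py
index be2c6155250262..153861a71a5994 100644
--- a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py
+++ b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py
@@ -8,7 +8,7 @@
 import sqlalchemy as sa
 from alembic import op
 
-import models as models
+import models.types
 
 # revision identifiers, used by Alembic.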
revision = '04c602f5dc9b'
@@ -20,24 +20,20 @@
 
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('tracing_app_configs',
-    sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
-    sa.Column('app_id', models.StringUUID(), nullable=False),
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('app_id', models.types.StringUUID(), nullable=False),
     sa.Column('tracing_provider', sa.String(length=255), nullable=True),
     sa.Column('tracing_config', sa.JSON(), nullable=True),
     sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
     sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False),
     sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey')
     )
-    with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
-        batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False)
 
     # ### end Alembic commands ###
 
 
 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
-        batch_op.drop_index('tracing_app_config_app_id_idx')
-
-    op.drop_table('tracing_app_configs')
+    # ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py
new file mode 100644
index 00000000000000..ca2e4104426275
--- /dev/null
+++ b/api/migrations/versions/2024_08_15_0956-0251a1c768cc_add_tidb_auth_binding.py
@@ -0,0 +1,51 @@
+"""add-tidb-auth-binding
+
+Revision ID: 0251a1c768cc
+Revises: bbadea11becb
+Create Date: 2024-08-15 09:56:59.012490
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = '0251a1c768cc'
+down_revision = 'bbadea11becb'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('tidb_auth_bindings',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
+    sa.Column('cluster_id', sa.String(length=255), nullable=False),
+    sa.Column('cluster_name', sa.String(length=255), nullable=False),
+    sa.Column('active', sa.Boolean(), server_default=sa.text('false'), nullable=False),
+    sa.Column('status', sa.String(length=255), server_default=sa.text("'CREATING'::character varying"), nullable=False),
+    sa.Column('account', sa.String(length=255), nullable=False),
+    sa.Column('password', sa.String(length=255), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='tidb_auth_bindings_pkey')
+    )
+    with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op:
+        batch_op.create_index('tidb_auth_bindings_active_idx', ['active'], unique=False)
+        batch_op.create_index('tidb_auth_bindings_status_idx', ['status'], unique=False)
+        batch_op.create_index('tidb_auth_bindings_created_at_idx', ['created_at'], unique=False)
+        batch_op.create_index('tidb_auth_bindings_tenant_idx', ['tenant_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust!
### + with op.batch_alter_table('tidb_auth_bindings', schema=None) as batch_op: + batch_op.drop_index('tidb_auth_bindings_tenant_idx') + batch_op.drop_index('tidb_auth_bindings_created_at_idx') + batch_op.drop_index('tidb_auth_bindings_active_idx') + batch_op.drop_index('tidb_auth_bindings_status_idx') + op.drop_table('tidb_auth_bindings') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py new file mode 100644 index 00000000000000..d814666eefd2f2 --- /dev/null +++ b/api/migrations/versions/2024_08_15_1001-a6be81136580_app_and_site_icon_type.py @@ -0,0 +1,39 @@ +"""app and site icon type + +Revision ID: a6be81136580 +Revises: 8782057ff0dc +Create Date: 2024-08-15 10:01:24.697888 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = 'a6be81136580' +down_revision = '8782057ff0dc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon_type', sa.String(length=255), nullable=True)) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.add_column(sa.Column('icon_type', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_column('icon_type') + + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_column('icon_type') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py new file mode 100644 index 00000000000000..3dc7fed818ea2b --- /dev/null +++ b/api/migrations/versions/2024_08_20_0455-2dbe42621d96_rename_workflow__conversation_variables_.py @@ -0,0 +1,28 @@ +"""rename workflow__conversation_variables to workflow_conversation_variables + +Revision ID: 2dbe42621d96 +Revises: a6be81136580 +Create Date: 2024-08-20 04:55:38.160010 + +""" +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '2dbe42621d96' +down_revision = 'a6be81136580' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table('workflow__conversation_variables', 'workflow_conversation_variables') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.rename_table('workflow_conversation_variables', 'workflow__conversation_variables') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py new file mode 100644 index 00000000000000..e0066a302cd5a2 --- /dev/null +++ b/api/migrations/versions/2024_08_25_0441-d0187d6a88dd_add_created_by_and_updated_by_to_app_.py @@ -0,0 +1,52 @@ +"""add created_by and updated_by to app, modelconfig, and site + +Revision ID: d0187d6a88dd +Revises: 2dbe42621d96 +Create Date: 2024-08-25 04:41:18.157397 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "d0187d6a88dd" +down_revision = "2dbe42621d96" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("app_model_configs", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.add_column(sa.Column("created_by", models.types.StringUUID(), nullable=True)) + batch_op.add_column(sa.Column("updated_by", models.types.StringUUID(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("sites", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + with op.batch_alter_table("apps", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + with op.batch_alter_table("app_model_configs", schema=None) as batch_op: + batch_op.drop_column("updated_by") + batch_op.drop_column("created_by") + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py new file mode 100644 index 00000000000000..4406d51ed07aa2 --- /dev/null +++ b/api/migrations/versions/2024_09_01_1255-030f4915f36a_add_use_icon_as_answer_icon_fields_for_.py @@ -0,0 +1,45 @@ +"""add use_icon_as_answer_icon fields for app and site + +Revision ID: 030f4915f36a +Revises: d0187d6a88dd +Create Date: 2024-09-01 12:55:45.129687 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "030f4915f36a" +down_revision = "d0187d6a88dd" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
###
+    with op.batch_alter_table("apps", schema=None) as batch_op:
+        batch_op.add_column(
+            sa.Column("use_icon_as_answer_icon", sa.Boolean(), server_default=sa.text("false"), nullable=False)
+        )
+
+    with op.batch_alter_table("sites", schema=None) as batch_op:
+        batch_op.add_column(
+            sa.Column("use_icon_as_answer_icon", sa.Boolean(), server_default=sa.text("false"), nullable=False)
+        )
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    with op.batch_alter_table("sites", schema=None) as batch_op:
+        batch_op.drop_column("use_icon_as_answer_icon")
+
+    with op.batch_alter_table("apps", schema=None) as batch_op:
+        batch_op.drop_column("use_icon_as_answer_icon")
+
+    # ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py b/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py
new file mode 100644
index 00000000000000..fd957eeafb2b6c
--- /dev/null
+++ b/api/migrations/versions/2024_09_11_1012-d57ba9ebb251_add_parent_message_id_to_messages.py
@@ -0,0 +1,36 @@
+"""add parent_message_id to messages
+
+Revision ID: d57ba9ebb251
+Revises: 675b5321501b
+Create Date: 2024-09-11 10:12:45.826265
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+import models as models
+
+# revision identifiers, used by Alembic.
+revision = 'd57ba9ebb251'
+down_revision = '675b5321501b'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('messages', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('parent_message_id', models.types.StringUUID(), nullable=True))
+
+    # Set parent_message_id for existing messages to uuid_nil() to distinguish them from new messages with actual parent IDs or NULLs
+    op.execute('UPDATE messages SET parent_message_id = uuid_nil() WHERE parent_message_id IS NULL')
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('messages', schema=None) as batch_op:
+        batch_op.drop_column('parent_message_id')
+
+    # ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py
new file mode 100644
index 00000000000000..5337b340db7690
--- /dev/null
+++ b/api/migrations/versions/2024_09_24_0922-6af6a521a53e_update_retrieval_resource.py
@@ -0,0 +1,48 @@
+"""update-retrieval-resource
+
+Revision ID: 6af6a521a53e
+Revises: d57ba9ebb251
+Create Date: 2024-09-24 09:22:43.570120
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '6af6a521a53e'
+down_revision = 'd57ba9ebb251'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
+        batch_op.alter_column('document_id',
+               existing_type=sa.UUID(),
+               nullable=True)
+        batch_op.alter_column('data_source_type',
+               existing_type=sa.TEXT(),
+               nullable=True)
+        batch_op.alter_column('segment_id',
+               existing_type=sa.UUID(),
+               nullable=True)
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust!
### + with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: + batch_op.alter_column('segment_id', + existing_type=sa.UUID(), + nullable=False) + batch_op.alter_column('data_source_type', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('document_id', + existing_type=sa.UUID(), + nullable=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py b/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py new file mode 100644 index 00000000000000..3cb76e72c1ebb2 --- /dev/null +++ b/api/migrations/versions/2024_09_25_0434-33f5fac87f29_external_knowledge_api.py @@ -0,0 +1,73 @@ +"""external_knowledge_api + +Revision ID: 33f5fac87f29 +Revises: 6af6a521a53e +Create Date: 2024-09-25 04:34:57.249436 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '33f5fac87f29' +down_revision = '6af6a521a53e' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('external_knowledge_apis', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.String(length=255), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('settings', sa.Text(), nullable=True), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='external_knowledge_apis_pkey') + ) + with op.batch_alter_table('external_knowledge_apis', schema=None) as batch_op: + batch_op.create_index('external_knowledge_apis_name_idx', ['name'], unique=False) + batch_op.create_index('external_knowledge_apis_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('external_knowledge_bindings', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('external_knowledge_api_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('external_knowledge_id', sa.Text(), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey') + ) + with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op: + batch_op.create_index('external_knowledge_bindings_dataset_idx', ['dataset_id'], unique=False) + batch_op.create_index('external_knowledge_bindings_external_knowledge_api_idx', ['external_knowledge_api_id'], unique=False) + batch_op.create_index('external_knowledge_bindings_external_knowledge_idx', 
['external_knowledge_id'], unique=False)
+        batch_op.create_index('external_knowledge_bindings_tenant_idx', ['tenant_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('external_knowledge_bindings', schema=None) as batch_op:
+        batch_op.drop_index('external_knowledge_bindings_tenant_idx')
+        batch_op.drop_index('external_knowledge_bindings_external_knowledge_idx')
+        batch_op.drop_index('external_knowledge_bindings_external_knowledge_api_idx')
+        batch_op.drop_index('external_knowledge_bindings_dataset_idx')
+
+    op.drop_table('external_knowledge_bindings')
+    with op.batch_alter_table('external_knowledge_apis', schema=None) as batch_op:
+        batch_op.drop_index('external_knowledge_apis_tenant_idx')
+        batch_op.drop_index('external_knowledge_apis_name_idx')
+
+    op.drop_table('external_knowledge_apis')
+    # ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py b/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py
new file mode 100644
index 00000000000000..38a5cdf8e5008f
--- /dev/null
+++ b/api/migrations/versions/2024_10_09_1329-d8e744d88ed6_fix_wrong_service_api_history.py
@@ -0,0 +1,48 @@
+"""fix wrong service-api history
+
+Revision ID: d8e744d88ed6
+Revises: 33f5fac87f29
+Create Date: 2024-10-09 13:29:23.548498
+
+"""
+from alembic import op
+from constants import UUID_NIL
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'd8e744d88ed6'
+down_revision = '33f5fac87f29'
+branch_labels = None
+depends_on = None
+
+# (UTC) release date of v0.9.0
+v0_9_0_release_date = '2024-09-29 12:00:00'
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    sql = f"""UPDATE
+    messages
+SET
+    parent_message_id = '{UUID_NIL}'
+WHERE
+    invoke_from = 'service-api'
+    AND parent_message_id IS NULL
+    AND created_at >= '{v0_9_0_release_date}';"""
+    op.execute(sql)
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    sql = f"""UPDATE
+    messages
+SET
+    parent_message_id = NULL
+WHERE
+    invoke_from = 'service-api'
+    AND parent_message_id = '{UUID_NIL}'
+    AND created_at >= '{v0_9_0_release_date}';"""
+    op.execute(sql)
+    # ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py
new file mode 100644
index 00000000000000..c17d1db77a96df
--- /dev/null
+++ b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py
@@ -0,0 +1,49 @@
+"""add name and size to tool_files
+
+Revision ID: bbadea11becb
+Revises: d8e744d88ed6
+Create Date: 2024-10-10 05:16:14.764268
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'bbadea11becb'
+down_revision = 'd8e744d88ed6'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust!
### + # Get the database connection + conn = op.get_bind() + + # Use SQLAlchemy inspector to get the columns of the 'tool_files' table + inspector = sa.inspect(conn) + columns = [col['name'] for col in inspector.get_columns('tool_files')] + + # If 'name' or 'size' columns already exist, exit the upgrade function + if 'name' in columns or 'size' in columns: + return + + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(), nullable=True)) + batch_op.add_column(sa.Column('size', sa.Integer(), nullable=True)) + op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL") + op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL") + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.alter_column('name', existing_type=sa.String(), nullable=False) + batch_op.alter_column('size', existing_type=sa.Integer(), nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_files', schema=None) as batch_op: + batch_op.drop_column('size') + batch_op.drop_column('name') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py new file mode 100644 index 00000000000000..9daf148bc4e881 --- /dev/null +++ b/api/migrations/versions/2024_10_22_0959-43fa78bc3b7d_add_white_list.py @@ -0,0 +1,42 @@ +"""add_white_list + +Revision ID: 43fa78bc3b7d +Revises: 0251a1c768cc +Create Date: 2024-10-22 09:59:23.713716 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '43fa78bc3b7d' +down_revision = '0251a1c768cc' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('whitelists', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=True), + sa.Column('category', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='whitelists_pkey') + ) + with op.batch_alter_table('whitelists', schema=None) as batch_op: + batch_op.create_index('whitelists_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + with op.batch_alter_table('whitelists', schema=None) as batch_op: + batch_op.drop_index('whitelists_tenant_idx') + + op.drop_table('whitelists') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py new file mode 100644 index 00000000000000..a749c8bddfee01 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py @@ -0,0 +1,31 @@ +"""Add upload_files.source_url + +Revision ID: d3f6769a94a3 +Revises: 43fa78bc3b7d +Create Date: 2024-11-01 04:34:23.816198 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = 'd3f6769a94a3' +down_revision = '43fa78bc3b7d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('source_url', sa.String(length=255), server_default='', nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.drop_column('source_url') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py new file mode 100644 index 00000000000000..81a7978f730a37 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py @@ -0,0 +1,52 @@ +"""rename conversation variables index name + +Revision ID: 93ad8c19c40b +Revises: d3f6769a94a3 +Create Date: 2024-11-01 04:49:53.100250 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '93ad8c19c40b' +down_revision = 'd3f6769a94a3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if conn.dialect.name == 'postgresql': + # Rename indexes for PostgreSQL + op.execute('ALTER INDEX workflow__conversation_variables_app_id_idx RENAME TO workflow_conversation_variables_app_id_idx') + op.execute('ALTER INDEX workflow__conversation_variables_created_at_idx RENAME TO workflow_conversation_variables_created_at_idx') + else: + # For other databases, use the original drop and create method + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index('workflow__conversation_variables_app_id_idx') + batch_op.drop_index('workflow__conversation_variables_created_at_idx') + batch_op.create_index(batch_op.f('workflow_conversation_variables_app_id_idx'), ['app_id'], unique=False) + batch_op.create_index(batch_op.f('workflow_conversation_variables_created_at_idx'), ['created_at'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + conn = op.get_bind() + if conn.dialect.name == 'postgresql': + # Rename indexes back for PostgreSQL + op.execute('ALTER INDEX workflow_conversation_variables_app_id_idx RENAME TO workflow__conversation_variables_app_id_idx') + op.execute('ALTER INDEX workflow_conversation_variables_created_at_idx RENAME TO workflow__conversation_variables_created_at_idx') + else: + # For other databases, use the original drop and create method + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow_conversation_variables_created_at_idx')) + batch_op.drop_index(batch_op.f('workflow_conversation_variables_app_id_idx')) + batch_op.create_index('workflow__conversation_variables_created_at_idx', ['created_at'], unique=False) + batch_op.create_index('workflow__conversation_variables_app_id_idx', ['app_id'], unique=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py new file mode 100644 index 00000000000000..222379a49021a6 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py @@ -0,0 +1,41 @@ +"""update upload_files.source_url + +Revision ID: f4d7ce70a7ca +Revises: 93ad8c19c40b +Create Date: 2024-11-01 05:40:03.531751 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'f4d7ce70a7ca' +down_revision = '93ad8c19c40b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.alter_column('source_url', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.alter_column('source_url', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py new file mode 100644 index 00000000000000..9a4ccf352df098 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py @@ -0,0 +1,67 @@ +"""update type of custom_disclaimer to TEXT + +Revision ID: d07474999927 +Revises: f4d7ce70a7ca +Create Date: 2024-11-01 06:22:27.981398 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd07474999927' +down_revision = 'f4d7ce70a7ca' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.execute("UPDATE recommended_apps SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py new file mode 100644 index 00000000000000..117a7351cd67e7 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py @@ -0,0 +1,73 @@ +"""update workflows graph, features and updated_at + +Revision ID: 09a8d1878d9b +Revises: d07474999927 +Create Date: 2024-11-01 06:23:59.579186 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '09a8d1878d9b' +down_revision = 'd07474999927' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL") + op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL") + op.execute("UPDATE workflows SET features = '' WHERE features IS NULL") + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('graph', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('features', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('updated_at', + existing_type=postgresql.TIMESTAMP(), + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
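
Note: the downgrade below only relaxes the constraints; it cannot restore the NULLs the upgrade's backfill overwrote (updated_at values copied from created_at, and empty-string graph/features stay empty strings), so this migration is effectively one-way for data. A hedged sanity check one could run inside upgrade(), after the backfill UPDATEs and before the NOT NULL changes (illustrative, not part of this PR):

import sqlalchemy as sa
from alembic import op

conn = op.get_bind()
remaining = conn.execute(sa.text(
    'SELECT count(*) FROM workflows '
    'WHERE graph IS NULL OR features IS NULL OR updated_at IS NULL'
)).scalar()
assert remaining == 0, 'workflows backfill incomplete; aborting migration'
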
### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('updated_at', + existing_type=postgresql.TIMESTAMP(), + nullable=True) + batch_op.alter_column('features', + existing_type=sa.TEXT(), + nullable=True) + batch_op.alter_column('graph', + existing_type=sa.TEXT(), + nullable=True) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py index 09ef5e186cd089..99b7010612aa0f 100644 --- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py +++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py @@ -22,17 +22,11 @@ def upgrade(): with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True)) - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False) - # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') - with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('tracing') diff --git a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py index db3119badfde7e..bf54c247ead19c 100644 --- a/api/migrations/versions/3b18fea55204_add_tool_label_bings.py +++ b/api/migrations/versions/3b18fea55204_add_tool_label_bings.py @@ -8,7 +8,7 @@ import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '3b18fea55204' @@ -20,7 +20,7 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('tool_label_bindings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('tool_id', sa.String(length=64), nullable=False), sa.Column('tool_type', sa.String(length=40), nullable=False), sa.Column('label_name', sa.String(length=40), nullable=False), diff --git a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py index 67d7b9fbf54875..3be4ba4f2a82e4 100644 --- a/api/migrations/versions/4e99a8df00ff_add_load_balancing.py +++ b/api/migrations/versions/4e99a8df00ff_add_load_balancing.py @@ -8,7 +8,7 @@ import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '4e99a8df00ff' @@ -20,8 +20,8 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! 
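
Note: the recurring change from "import models as models" plus models.StringUUID to "import models.types" plus models.types.StringUUID makes these migrations import the column type from its defining module directly, instead of relying on a re-export from the models package (whose __init__ is rewritten later in this diff and no longer exports StringUUID). For context, a StringUUID-style type is typically a TypeDecorator along these lines; this is a sketch, not necessarily the project's exact implementation:

from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.types import CHAR, TypeDecorator

class StringUUID(TypeDecorator):
    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Native UUID column on PostgreSQL, CHAR(36) elsewhere.
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(CHAR(36))

    def process_result_value(self, value, dialect):
        # Always hand plain strings back to Python code.
        return str(value) if value is not None else None
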
### op.create_table('load_balancing_model_configs', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('provider_name', sa.String(length=255), nullable=False), sa.Column('model_name', sa.String(length=255), nullable=False), sa.Column('model_type', sa.String(length=40), nullable=False), @@ -36,8 +36,8 @@ def upgrade(): batch_op.create_index('load_balancing_model_config_tenant_provider_model_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) op.create_table('provider_model_settings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('provider_name', sa.String(length=255), nullable=False), sa.Column('model_name', sa.String(length=255), nullable=False), sa.Column('model_type', sa.String(length=40), nullable=False), diff --git a/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py new file mode 100644 index 00000000000000..55824945da49b0 --- /dev/null +++ b/api/migrations/versions/675b5321501b_add_node_execution_id_into_node_.py @@ -0,0 +1,35 @@ +"""add node_execution_id into node_executions + +Revision ID: 675b5321501b +Revises: 030f4915f36a +Create Date: 2024-08-12 10:54:02.259331 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '675b5321501b' +down_revision = '030f4915f36a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True)) + batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_index('workflow_node_execution_id_idx') + batch_op.drop_column('node_execution_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py index f63bad93457d30..2ba0e13caa936a 100644 --- a/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py +++ b/api/migrations/versions/7b45942e39bb_add_api_key_auth_binding.py @@ -8,7 +8,7 @@ import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7b45942e39bb' @@ -20,8 +20,8 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! 
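
Note: the server_default=sa.text('uuid_generate_v4()') on the id columns below relies on PostgreSQL's uuid-ossp extension, which an earlier migration in the chain is expected to have enabled, e.g. (illustrative, not part of this diff):

from alembic import op

op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"')
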
### op.create_table('data_source_api_key_auth_bindings', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('category', sa.String(length=255), nullable=False), sa.Column('provider', sa.String(length=255), nullable=False), sa.Column('credentials', sa.Text(), nullable=True), diff --git a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py index 67b61e5c76d8ab..f09a682f285bd2 100644 --- a/api/migrations/versions/7bdef072e63a_add_workflow_tool.py +++ b/api/migrations/versions/7bdef072e63a_add_workflow_tool.py @@ -1,6 +1,6 @@ """add workflow tool -Revision ID: 7bdef072e63a +Revision ID: 7bdef072e63a Revises: 5fda94355fce Create Date: 2024-05-04 09:47:19.366961 @@ -8,7 +8,7 @@ import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7bdef072e63a' @@ -20,12 +20,12 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('tool_workflow_providers', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('name', sa.String(length=40), nullable=False), sa.Column('icon', sa.String(length=255), nullable=False), - sa.Column('app_id', models.StringUUID(), nullable=False), - sa.Column('user_id', models.StringUUID(), nullable=False), - sa.Column('tenant_id', models.StringUUID(), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), + sa.Column('user_id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('description', sa.Text(), nullable=False), sa.Column('parameter_configuration', sa.Text(), server_default='[]', nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), diff --git a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py index ff53eb65a6f56c..865572f3a75c71 100644 --- a/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py +++ b/api/migrations/versions/7e6a8693e07a_add_table_dataset_permissions.py @@ -8,7 +8,7 @@ import sqlalchemy as sa from alembic import op -import models as models +import models.types # revision identifiers, used by Alembic. revision = '7e6a8693e07a' @@ -20,9 +20,9 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('dataset_permissions', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('dataset_id', models.StringUUID(), nullable=False), - sa.Column('account_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('account_id', models.types.StringUUID(), nullable=False), sa.Column('has_permission', sa.Boolean(), server_default=sa.text('true'), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), sa.PrimaryKeyConstraint('id', name='dataset_permission_pkey') diff --git a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py index bfda7d619c5cca..92f41f0abd0d91 100644 --- a/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py +++ b/api/migrations/versions/9e98fbaffb88_add_workflow_tool_version.py @@ -21,6 +21,7 @@ def upgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.add_column(sa.Column('version', sa.String(length=255), server_default='', nullable=False)) + def downgrade(): with op.batch_alter_table('tool_workflow_providers', schema=None) as batch_op: batch_op.drop_column('version') diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index 1ac44d083aaf43..f87819c3672b85 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -9,7 +9,7 @@ from alembic import op from sqlalchemy.dialects import postgresql -import models as models +import models.types # revision identifiers, used by Alembic. revision = 'c031d46af369' @@ -21,8 +21,8 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('trace_app_config', - sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('app_id', models.StringUUID(), nullable=False), + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', models.types.StringUUID(), nullable=False), sa.Column('tracing_provider', sa.String(length=255), nullable=True), sa.Column('tracing_config', sa.JSON(), nullable=True), sa.Column('created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), @@ -30,30 +30,15 @@ def upgrade(): sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False), sa.PrimaryKeyConstraint('id', name='trace_app_config_pkey') ) + with op.batch_alter_table('trace_app_config', schema=None) as batch_op: batch_op.create_index('trace_app_config_app_id_idx', ['app_id'], unique=False) - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
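
Note: with the removals below, this downgrade becomes a no-op: it no longer recreates tracing_app_configs, nor does it drop trace_app_config, so stepping back past this revision leaves the trace_app_config table and its index in place. If a full revert were ever required, it would presumably include something like (illustrative only):

from alembic import op

with op.batch_alter_table('trace_app_config', schema=None) as batch_op:
    batch_op.drop_index('trace_app_config_app_id_idx')
op.drop_table('trace_app_config')
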
### - op.create_table('tracing_app_configs', - sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), - sa.Column('app_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('tracing_provider', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('tracing_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False), - sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('id', name='trace_app_config_pkey') - ) - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.create_index('trace_app_config_app_id_idx', ['app_id'], unique=False) - - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('trace_app_config_app_id_idx') - op.drop_table('trace_app_config') + # ### end Alembic commands ### diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py index 2365766837177b..fcca705d214597 100644 --- a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py +++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py @@ -99,7 +99,7 @@ def upgrade(): id=id, tenant_id=tenant_id, user_id=user_id, - provider='google', + provider='google', encrypted_credentials=encrypted_credentials, created_at=created_at, updated_at=updated_at diff --git a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py index 1f8250c3eb42af..52495be60a62ea 100644 --- a/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py +++ b/api/migrations/versions/fca025d3b60f_add_dataset_retrival_model.py @@ -1,4 +1,4 @@ -"""add-dataset-retrival-model +"""add-dataset-retrieval-model Revision ID: fca025d3b60f Revises: b3a09c049e8e diff --git a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py index 271b2490de1055..6f76a361d9c0eb 100644 --- a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py +++ b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py @@ -20,12 +20,10 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! 
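
Note: the drop_index('tracing_app_config_app_id_idx') block is removed below because the earlier revision 2a3aebbbf4bb in this diff no longer creates that index, so dropping it would fail on databases migrated from scratch. Deployments created before this change may still carry the stray index; a defensive PostgreSQL cleanup could look like (illustrative, not part of this PR):

from alembic import op

op.execute('DROP INDEX IF EXISTS tracing_app_config_app_id_idx')
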
### op.drop_table('tracing_app_configs') - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') - # idx_dataset_permissions_tenant_id with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: batch_op.create_index('idx_dataset_permissions_tenant_id', ['tenant_id']) + # ### end Alembic commands ### @@ -46,9 +44,7 @@ def downgrade(): sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey') ) - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.create_index('tracing_app_config_app_id_idx', ['app_id']) - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: batch_op.drop_index('idx_dataset_permissions_tenant_id') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 4012611471c337..cd6c7674da0847 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1,29 +1,53 @@ -from enum import Enum +from .account import Account, AccountIntegrate, InvitationCode, Tenant +from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment +from .model import ( + ApiToken, + App, + AppMode, + Conversation, + EndUser, + InstalledApp, + Message, + MessageAnnotation, + MessageFile, + RecommendedApp, + Site, + UploadFile, +) +from .source import DataSourceOauthBinding +from .tools import ToolFile +from .workflow import ( + ConversationVariable, + Workflow, + WorkflowAppLog, + WorkflowRun, +) -from .model import App, AppMode, Message -from .types import StringUUID -from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus - -__all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] - - -class CreatedByRole(Enum): - """ - Enum class for createdByRole - """ - - ACCOUNT = 'account' - END_USER = 'end_user' - - @classmethod - def value_of(cls, value: str) -> 'CreatedByRole': - """ - Get value of given mode. 
- - :param value: mode value - :return: mode - """ - for role in cls: - if role.value == value: - return role - raise ValueError(f'invalid createdByRole value {value}') +__all__ = [ + "ConversationVariable", + "Document", + "Dataset", + "DatasetProcessRule", + "DocumentSegment", + "DataSourceOauthBinding", + "AppMode", + "Workflow", + "App", + "Message", + "EndUser", + "MessageFile", + "UploadFile", + "Account", + "WorkflowAppLog", + "WorkflowRun", + "Site", + "InstalledApp", + "RecommendedApp", + "ApiToken", + "AccountIntegrate", + "InvitationCode", + "Tenant", + "Conversation", + "MessageAnnotation", + "ToolFile", +] diff --git a/api/models/account.py b/api/models/account.py index 67d940b7b7190e..60b4f11aad2bdc 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -9,21 +9,18 @@ class AccountStatus(str, enum.Enum): - PENDING = 'pending' - UNINITIALIZED = 'uninitialized' - ACTIVE = 'active' - BANNED = 'banned' - CLOSED = 'closed' + PENDING = "pending" + UNINITIALIZED = "uninitialized" + ACTIVE = "active" + BANNED = "banned" + CLOSED = "closed" class Account(UserMixin, db.Model): - __tablename__ = 'accounts' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='account_pkey'), - db.Index('account_email_idx', 'email') - ) + __tablename__ = "accounts" + __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=True) @@ -34,11 +31,11 @@ class Account(UserMixin, db.Model): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def is_password_set(self): @@ -65,11 +62,13 @@ def current_tenant_id(self): @current_tenant_id.setter def current_tenant_id(self, value: str): try: - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == value) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.account_id == self.id) \ + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == value) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.account_id == self.id) .one_or_none() + ) if tenant_account_join: tenant, ta = tenant_account_join @@ -91,20 +90,18 @@ def get_status(self) -> AccountStatus: @classmethod def get_by_openid(cls, provider: str, open_id: str) -> db.Model: - account_integrate = db.session.query(AccountIntegrate). 
\ - filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id). \ - one_or_none() + account_integrate = ( + db.session.query(AccountIntegrate) + .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + .one_or_none() + ) if account_integrate: - return db.session.query(Account). \ - filter(Account.id == account_integrate.account_id). \ - one_or_none() + return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() return None def get_integrates(self) -> list[db.Model]: ai = db.Model - return db.session.query(ai).filter( - ai.account_id == self.id - ).all() + return db.session.query(ai).filter(ai.account_id == self.id).all() # check current_user.current_tenant.current_role in ['admin', 'owner'] @property @@ -123,61 +120,75 @@ def is_dataset_editor(self): def is_dataset_operator(self): return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR + class TenantStatus(str, enum.Enum): - NORMAL = 'normal' - ARCHIVE = 'archive' + NORMAL = "normal" + ARCHIVE = "archive" class TenantAccountRole(str, enum.Enum): - OWNER = 'owner' - ADMIN = 'admin' - EDITOR = 'editor' - NORMAL = 'normal' - DATASET_OPERATOR = 'dataset_operator' + OWNER = "owner" + ADMIN = "admin" + EDITOR = "editor" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" @staticmethod def is_valid_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, - TenantAccountRole.NORMAL, TenantAccountRole.DATASET_OPERATOR} + return role and role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } @staticmethod def is_privileged_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} - + @staticmethod def is_non_owner_role(role: str) -> bool: - return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, - TenantAccountRole.DATASET_OPERATOR} - + return role and role in { + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } + @staticmethod def is_editing_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} @staticmethod def is_dataset_edit_role(role: str) -> bool: - return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, - TenantAccountRole.DATASET_OPERATOR} + return role and role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.DATASET_OPERATOR, + } + class Tenant(db.Model): - __tablename__ = 'tenants' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_pkey'), - ) + __tablename__ = "tenants" + __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def get_accounts(self) -> list[Account]: - return db.session.query(Account).filter( - Account.id == TenantAccountJoin.account_id, - TenantAccountJoin.tenant_id == self.id - ).all() + return ( + db.session.query(Account) + .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) + .all() + ) @property def custom_config_dict(self) -> dict: @@ -189,54 +200,54 @@ def custom_config_dict(self, value: dict): class TenantAccountJoinRole(enum.Enum): - OWNER = 'owner' - ADMIN = 'admin' - NORMAL = 'normal' - DATASET_OPERATOR = 'dataset_operator' + OWNER = "owner" + ADMIN = "admin" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" class TenantAccountJoin(db.Model): - __tablename__ = 'tenant_account_joins' + __tablename__ = "tenant_account_joins" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_account_join_pkey'), - db.Index('tenant_account_join_account_id_idx', 'account_id'), - db.Index('tenant_account_join_tenant_id_idx', 'tenant_id'), - db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') + db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), + db.Index("tenant_account_join_account_id_idx", "account_id"), + db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), + db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) - current = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - role = db.Column(db.String(16), nullable=False, server_default='normal') + current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class AccountIntegrate(db.Model): - __tablename__ = 'account_integrates' + __tablename__ = "account_integrates" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='account_integrate_pkey'), - db.UniqueConstraint('account_id', 'provider', name='unique_account_provider'), - db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') + db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), + db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), + db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) account_id = db.Column(StringUUID, nullable=False) provider = 
db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class InvitationCode(db.Model): - __tablename__ = 'invitation_codes' + __tablename__ = "invitation_codes" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='invitation_code_pkey'), - db.Index('invitation_codes_batch_idx', 'batch'), - db.Index('invitation_codes_code_idx', 'code', 'status') + db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), + db.Index("invitation_codes_batch_idx", "batch"), + db.Index("invitation_codes_code_idx", "code", "status"), ) id = db.Column(db.Integer, nullable=False) @@ -247,4 +258,4 @@ class InvitationCode(db.Model): used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 7f69323628a7cc..97173747afc4b1 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -6,22 +6,22 @@ class APIBasedExtensionPoint(enum.Enum): - APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' - PING = 'ping' - APP_MODERATION_INPUT = 'app.moderation.input' - APP_MODERATION_OUTPUT = 'app.moderation.output' + APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query" + PING = "ping" + APP_MODERATION_INPUT = "app.moderation.input" + APP_MODERATION_OUTPUT = "app.moderation.output" class APIBasedExtension(db.Model): - __tablename__ = 'api_based_extensions' + __tablename__ = "api_based_extensions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'), - db.Index('api_based_extension_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), + db.Index("api_based_extension_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/dataset.py b/api/models/dataset.py index 0d48177eb60409..a1a626d7e46b2c 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -1,4 +1,5 @@ import base64 +import enum import hashlib import hmac import json @@ -13,7 +14,7 @@ from sqlalchemy.dialects.postgresql import JSONB from configs import dify_config -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from 
extensions.ext_storage import storage @@ -22,33 +23,36 @@ from .types import StringUUID +class DatasetPermissionEnum(str, enum.Enum): + ONLY_ME = "only_me" + ALL_TEAM = "all_team_members" + PARTIAL_TEAM = "partial_members" + + class Dataset(db.Model): - __tablename__ = 'datasets' + __tablename__ = "datasets" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_pkey'), - db.Index('dataset_tenant_idx', 'tenant_id'), - db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') + db.PrimaryKeyConstraint("id", name="dataset_pkey"), + db.Index("dataset_tenant_idx", "tenant_id"), + db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) - INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] + INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] + PROVIDER_LIST = ["vendor", "external", None] - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=True) - provider = db.Column(db.String(255), nullable=False, - server_default=db.text("'vendor'::character varying")) - permission = db.Column(db.String(255), nullable=False, - server_default=db.text("'only_me'::character varying")) + provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) + permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) data_source_type = db.Column(db.String(255)) indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -56,8 +60,9 @@ class Dataset(db.Model): @property def dataset_keyword_table(self): - dataset_keyword_table = db.session.query(DatasetKeywordTable).filter( - DatasetKeywordTable.dataset_id == self.id).first() + dataset_keyword_table = ( + db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() + ) if dataset_keyword_table: return dataset_keyword_table @@ -67,19 +72,33 @@ def dataset_keyword_table(self): def index_struct_dict(self): return json.loads(self.index_struct) if self.index_struct else None + @property + def external_retrieval_model(self): + default_retrieval_model = { + "top_k": 2, + "score_threshold": 0.0, + } + return self.retrieval_model or default_retrieval_model + @property def created_by_account(self): return db.session.get(Account, self.created_by) @property def latest_process_rule(self): - return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \ - .order_by(DatasetProcessRule.created_at.desc()).first() + return ( + DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) + 
.order_by(DatasetProcessRule.created_at.desc()) + .first() + ) @property def app_count(self): - return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id, - App.id == AppDatasetJoin.app_id).scalar() + return ( + db.session.query(func.count(AppDatasetJoin.id)) + .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) + .scalar() + ) @property def document_count(self): @@ -87,30 +106,40 @@ def document_count(self): @property def available_document_count(self): - return db.session.query(func.count(Document.id)).filter( - Document.dataset_id == self.id, - Document.indexing_status == 'completed', - Document.enabled == True, - Document.archived == False - ).scalar() + return ( + db.session.query(func.count(Document.id)) + .filter( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) + .scalar() + ) @property def available_segment_count(self): - return db.session.query(func.count(DocumentSegment.id)).filter( - DocumentSegment.dataset_id == self.id, - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True - ).scalar() + return ( + db.session.query(func.count(DocumentSegment.id)) + .filter( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) + .scalar() + ) @property def word_count(self): - return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ - .filter(Document.dataset_id == self.id).scalar() + return ( + Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) + .filter(Document.dataset_id == self.id) + .scalar() + ) @property def doc_form(self): - document = db.session.query(Document).filter( - Document.dataset_id == self.id).first() + document = db.session.query(Document).filter(Document.dataset_id == self.id).first() if document: return document.doc_form return None @@ -118,76 +147,91 @@ def doc_form(self): @property def retrieval_model_dict(self): default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } - return self.retrieval_model if self.retrieval_model else default_retrieval_model + return self.retrieval_model or default_retrieval_model @property def tags(self): - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == self.id, - TagBinding.tenant_id == self.tenant_id, - Tag.tenant_id == self.tenant_id, - Tag.type == 'knowledge' - ).all() - - return tags if tags else [] + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "knowledge", + ) + .all() + ) + + return tags or [] + + @property + def external_knowledge_info(self): + if self.provider != "external": + return None + external_knowledge_binding = ( + db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first() + ) + if not external_knowledge_binding: + return None + 
external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis) + .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) + .first() + ) + if not external_knowledge_api: + return None + return { + "external_knowledge_id": external_knowledge_binding.external_knowledge_id, + "external_knowledge_api_id": external_knowledge_api.id, + "external_knowledge_api_name": external_knowledge_api.name, + "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""), + } @staticmethod def gen_collection_name_by_id(dataset_id: str) -> str: normalized_dataset_id = dataset_id.replace("-", "_") - return f'Vector_index_{normalized_dataset_id}_Node' + return f"Vector_index_{normalized_dataset_id}_Node" class DatasetProcessRule(db.Model): - __tablename__ = 'dataset_process_rules' + __tablename__ = "dataset_process_rules" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'), - db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), + db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) - mode = db.Column(db.String(255), nullable=False, - server_default=db.text("'automatic'::character varying")) + mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - MODES = ['automatic', 'custom'] - PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails'] + MODES = ["automatic", "custom"] + PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] AUTOMATIC_RULES = { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, } def to_dict(self): return { - 'id': self.id, - 'dataset_id': self.dataset_id, - 'mode': self.mode, - 'rules': self.rules_dict, - 'created_by': self.created_by, - 'created_at': self.created_at, + "id": self.id, + "dataset_id": self.dataset_id, + "mode": self.mode, + "rules": self.rules_dict, + "created_by": self.created_by, + "created_at": self.created_at, } @property @@ -199,17 +243,16 @@ def rules_dict(self): class Document(db.Model): - __tablename__ = 'documents' + __tablename__ = "documents" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='document_pkey'), - db.Index('document_dataset_id_idx', 'dataset_id'), - db.Index('document_is_paused_idx', 'is_paused'), - db.Index('document_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="document_pkey"), + db.Index("document_dataset_id_idx", "dataset_id"), + db.Index("document_is_paused_idx", "is_paused"), + db.Index("document_tenant_idx", 
"tenant_id"), ) # initial fields - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) @@ -221,8 +264,7 @@ class Document(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -244,7 +286,7 @@ class Document(db.Model): completed_at = db.Column(db.DateTime, nullable=True) # pause - is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) paused_by = db.Column(StringUUID, nullable=True) paused_at = db.Column(db.DateTime, nullable=True) @@ -253,44 +295,39 @@ class Document(db.Model): stopped_at = db.Column(db.DateTime, nullable=True) # basic fields - indexing_status = db.Column(db.String( - 255), nullable=False, server_default=db.text("'waiting'::character varying")) - enabled = db.Column(db.Boolean, nullable=False, - server_default=db.text('true')) + indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - archived = db.Column(db.Boolean, nullable=False, - server_default=db.text('false')) + archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) - doc_form = db.Column(db.String( - 255), nullable=False, server_default=db.text("'text_model'::character varying")) + doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) doc_language = db.Column(db.String(255), nullable=True) - DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl'] + DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] @property def display_status(self): status = None - if self.indexing_status == 'waiting': - status = 'queuing' - elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused: - status = 'paused' - elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']: - status = 'indexing' - elif self.indexing_status == 'error': - status = 'error' - elif self.indexing_status == 'completed' and not self.archived and self.enabled: - status = 'available' - elif self.indexing_status == 'completed' and not self.archived and not self.enabled: - status = 'disabled' - elif self.indexing_status == 
'completed' and self.archived: - status = 'archived' + if self.indexing_status == "waiting": + status = "queuing" + elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused: + status = "paused" + elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}: + status = "indexing" + elif self.indexing_status == "error": + status = "error" + elif self.indexing_status == "completed" and not self.archived and self.enabled: + status = "available" + elif self.indexing_status == "completed" and not self.archived and not self.enabled: + status = "disabled" + elif self.indexing_status == "completed" and self.archived: + status = "archived" return status @property @@ -307,24 +344,26 @@ def data_source_info_dict(self): @property def data_source_detail_dict(self): if self.data_source_info: - if self.data_source_type == 'upload_file': + if self.data_source_type == "upload_file": data_source_info_dict = json.loads(self.data_source_info) - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == data_source_info_dict['upload_file_id']). \ - one_or_none() + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.id == data_source_info_dict["upload_file_id"]) + .one_or_none() + ) if file_detail: return { - 'upload_file': { - 'id': file_detail.id, - 'name': file_detail.name, - 'size': file_detail.size, - 'extension': file_detail.extension, - 'mime_type': file_detail.mime_type, - 'created_by': file_detail.created_by, - 'created_at': file_detail.created_at.timestamp() + "upload_file": { + "id": file_detail.id, + "name": file_detail.name, + "size": file_detail.size, + "extension": file_detail.extension, + "mime_type": file_detail.mime_type, + "created_by": file_detail.created_by, + "created_at": file_detail.created_at.timestamp(), } } - elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl': + elif self.data_source_type in {"notion_import", "website_crawl"}: return json.loads(self.data_source_info) return {} @@ -350,120 +389,123 @@ def segment_count(self): @property def hit_count(self): - return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \ - .filter(DocumentSegment.document_id == self.id).scalar() + return ( + DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) + .filter(DocumentSegment.document_id == self.id) + .scalar() + ) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'dataset_id': self.dataset_id, - 'position': self.position, - 'data_source_type': self.data_source_type, - 'data_source_info': self.data_source_info, - 'dataset_process_rule_id': self.dataset_process_rule_id, - 'batch': self.batch, - 'name': self.name, - 'created_from': self.created_from, - 'created_by': self.created_by, - 'created_api_request_id': self.created_api_request_id, - 'created_at': self.created_at, - 'processing_started_at': self.processing_started_at, - 'file_id': self.file_id, - 'word_count': self.word_count, - 'parsing_completed_at': self.parsing_completed_at, - 'cleaning_completed_at': self.cleaning_completed_at, - 'splitting_completed_at': self.splitting_completed_at, - 'tokens': self.tokens, - 'indexing_latency': self.indexing_latency, - 'completed_at': self.completed_at, - 'is_paused': self.is_paused, - 'paused_by': self.paused_by, - 'paused_at': self.paused_at, - 'error': self.error, - 'stopped_at': self.stopped_at, - 'indexing_status': self.indexing_status, - 'enabled': self.enabled, - 'disabled_at': 
self.disabled_at, - 'disabled_by': self.disabled_by, - 'archived': self.archived, - 'archived_reason': self.archived_reason, - 'archived_by': self.archived_by, - 'archived_at': self.archived_at, - 'updated_at': self.updated_at, - 'doc_type': self.doc_type, - 'doc_metadata': self.doc_metadata, - 'doc_form': self.doc_form, - 'doc_language': self.doc_language, - 'display_status': self.display_status, - 'data_source_info_dict': self.data_source_info_dict, - 'average_segment_length': self.average_segment_length, - 'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, - 'dataset': self.dataset.to_dict() if self.dataset else None, - 'segment_count': self.segment_count, - 'hit_count': self.hit_count + "id": self.id, + "tenant_id": self.tenant_id, + "dataset_id": self.dataset_id, + "position": self.position, + "data_source_type": self.data_source_type, + "data_source_info": self.data_source_info, + "dataset_process_rule_id": self.dataset_process_rule_id, + "batch": self.batch, + "name": self.name, + "created_from": self.created_from, + "created_by": self.created_by, + "created_api_request_id": self.created_api_request_id, + "created_at": self.created_at, + "processing_started_at": self.processing_started_at, + "file_id": self.file_id, + "word_count": self.word_count, + "parsing_completed_at": self.parsing_completed_at, + "cleaning_completed_at": self.cleaning_completed_at, + "splitting_completed_at": self.splitting_completed_at, + "tokens": self.tokens, + "indexing_latency": self.indexing_latency, + "completed_at": self.completed_at, + "is_paused": self.is_paused, + "paused_by": self.paused_by, + "paused_at": self.paused_at, + "error": self.error, + "stopped_at": self.stopped_at, + "indexing_status": self.indexing_status, + "enabled": self.enabled, + "disabled_at": self.disabled_at, + "disabled_by": self.disabled_by, + "archived": self.archived, + "archived_reason": self.archived_reason, + "archived_by": self.archived_by, + "archived_at": self.archived_at, + "updated_at": self.updated_at, + "doc_type": self.doc_type, + "doc_metadata": self.doc_metadata, + "doc_form": self.doc_form, + "doc_language": self.doc_language, + "display_status": self.display_status, + "data_source_info_dict": self.data_source_info_dict, + "average_segment_length": self.average_segment_length, + "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, + "dataset": self.dataset.to_dict() if self.dataset else None, + "segment_count": self.segment_count, + "hit_count": self.hit_count, } @classmethod def from_dict(cls, data: dict): return cls( - id=data.get('id'), - tenant_id=data.get('tenant_id'), - dataset_id=data.get('dataset_id'), - position=data.get('position'), - data_source_type=data.get('data_source_type'), - data_source_info=data.get('data_source_info'), - dataset_process_rule_id=data.get('dataset_process_rule_id'), - batch=data.get('batch'), - name=data.get('name'), - created_from=data.get('created_from'), - created_by=data.get('created_by'), - created_api_request_id=data.get('created_api_request_id'), - created_at=data.get('created_at'), - processing_started_at=data.get('processing_started_at'), - file_id=data.get('file_id'), - word_count=data.get('word_count'), - parsing_completed_at=data.get('parsing_completed_at'), - cleaning_completed_at=data.get('cleaning_completed_at'), - splitting_completed_at=data.get('splitting_completed_at'), - tokens=data.get('tokens'), - indexing_latency=data.get('indexing_latency'), - 
completed_at=data.get('completed_at'), - is_paused=data.get('is_paused'), - paused_by=data.get('paused_by'), - paused_at=data.get('paused_at'), - error=data.get('error'), - stopped_at=data.get('stopped_at'), - indexing_status=data.get('indexing_status'), - enabled=data.get('enabled'), - disabled_at=data.get('disabled_at'), - disabled_by=data.get('disabled_by'), - archived=data.get('archived'), - archived_reason=data.get('archived_reason'), - archived_by=data.get('archived_by'), - archived_at=data.get('archived_at'), - updated_at=data.get('updated_at'), - doc_type=data.get('doc_type'), - doc_metadata=data.get('doc_metadata'), - doc_form=data.get('doc_form'), - doc_language=data.get('doc_language') + id=data.get("id"), + tenant_id=data.get("tenant_id"), + dataset_id=data.get("dataset_id"), + position=data.get("position"), + data_source_type=data.get("data_source_type"), + data_source_info=data.get("data_source_info"), + dataset_process_rule_id=data.get("dataset_process_rule_id"), + batch=data.get("batch"), + name=data.get("name"), + created_from=data.get("created_from"), + created_by=data.get("created_by"), + created_api_request_id=data.get("created_api_request_id"), + created_at=data.get("created_at"), + processing_started_at=data.get("processing_started_at"), + file_id=data.get("file_id"), + word_count=data.get("word_count"), + parsing_completed_at=data.get("parsing_completed_at"), + cleaning_completed_at=data.get("cleaning_completed_at"), + splitting_completed_at=data.get("splitting_completed_at"), + tokens=data.get("tokens"), + indexing_latency=data.get("indexing_latency"), + completed_at=data.get("completed_at"), + is_paused=data.get("is_paused"), + paused_by=data.get("paused_by"), + paused_at=data.get("paused_at"), + error=data.get("error"), + stopped_at=data.get("stopped_at"), + indexing_status=data.get("indexing_status"), + enabled=data.get("enabled"), + disabled_at=data.get("disabled_at"), + disabled_by=data.get("disabled_by"), + archived=data.get("archived"), + archived_reason=data.get("archived_reason"), + archived_by=data.get("archived_by"), + archived_at=data.get("archived_at"), + updated_at=data.get("updated_at"), + doc_type=data.get("doc_type"), + doc_metadata=data.get("doc_metadata"), + doc_form=data.get("doc_form"), + doc_language=data.get("doc_language"), ) + class DocumentSegment(db.Model): - __tablename__ = 'document_segments' + __tablename__ = "document_segments" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='document_segment_pkey'), - db.Index('document_segment_dataset_id_idx', 'dataset_id'), - db.Index('document_segment_document_id_idx', 'document_id'), - db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'), - db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'), - db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'), - db.Index('document_segment_tenant_idx', 'tenant_id'), + db.PrimaryKeyConstraint("id", name="document_segment_pkey"), + db.Index("document_segment_dataset_id_idx", "dataset_id"), + db.Index("document_segment_document_id_idx", "document_id"), + db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), + db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), + db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"), + db.Index("document_segment_tenant_idx", "tenant_id"), ) # initial fields - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, 
nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) document_id = db.Column(StringUUID, nullable=False) @@ -480,18 +522,14 @@ class DocumentSegment(db.Model): # basic fields hit_count = db.Column(db.Integer, nullable=False, default=0) - enabled = db.Column(db.Boolean, nullable=False, - server_default=db.text('true')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) disabled_at = db.Column(db.DateTime, nullable=True) disabled_by = db.Column(StringUUID, nullable=True) - status = db.Column(db.String(255), nullable=False, - server_default=db.text("'waiting'::character varying")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -507,29 +545,49 @@ def document(self): @property def previous_segment(self): - return db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == self.document_id, - DocumentSegment.position == self.position - 1 - ).first() + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) + .first() + ) @property def next_segment(self): - return db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == self.document_id, - DocumentSegment.position == self.position + 1 - ).first() + return ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) + .first() + ) def get_sign_content(self): - pattern = r"/files/([a-f0-9\-]+)/image-preview" + signed_urls = [] text = self.content + + # For data before v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/image-preview" matches = re.finditer(pattern, text) - signed_urls = [] for match in matches: upload_file_id = match.group(1) nonce = os.urandom(16).hex() timestamp = str(int(time.time())) data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + signed_url = f"{match.group(0)}?{params}" + signed_urls.append((match.start(), match.end(), signed_url)) + + # For data after v0.10.0 + pattern = r"/files/([a-f0-9\-]+)/file-preview" + matches = re.finditer(pattern, text) + for match in matches: + upload_file_id = match.group(1) + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if 
dify_config.SECRET_KEY else b"" sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() @@ -540,21 +598,20 @@ def get_sign_content(self): # Reconstruct the text with signed URLs offset = 0 for start, end, signed_url in signed_urls: - text = text[:start + offset] + signed_url + text[end + offset:] + text = text[: start + offset] + signed_url + text[end + offset :] offset += len(signed_url) - (end - start) return text - class AppDatasetJoin(db.Model): - __tablename__ = 'app_dataset_joins' + __tablename__ = "app_dataset_joins" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'), - db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'), + db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), + db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) @@ -565,13 +622,13 @@ def app(self): class DatasetQuery(db.Model): - __tablename__ = 'dataset_queries' + __tablename__ = "dataset_queries" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_query_pkey'), - db.Index('dataset_query_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), + db.Index("dataset_query_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False) content = db.Column(db.Text, nullable=False) source = db.Column(db.String(255), nullable=False) @@ -582,17 +639,18 @@ class DatasetQuery(db.Model): class DatasetKeywordTable(db.Model): - __tablename__ = 'dataset_keyword_tables' + __tablename__ = "dataset_keyword_tables" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'), - db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'), + db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), + db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) dataset_id = db.Column(StringUUID, nullable=False, unique=True) keyword_table = db.Column(db.Text, nullable=False) - data_source_type = db.Column(db.String(255), nullable=False, - server_default=db.text("'database'::character varying")) + data_source_type = db.Column( + db.String(255), nullable=False, server_default=db.text("'database'::character varying") + ) @property def keyword_table_dict(self): @@ -608,19 +666,17 @@ def object_hook(self, dct): return dct # get dataset - dataset = Dataset.query.filter_by( - id=self.dataset_id - ).first() + dataset = Dataset.query.filter_by(id=self.dataset_id).first() if not dataset: return None - if self.data_source_type == 'database': + if self.data_source_type == "database": return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None else: - file_key = 
'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt' + file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt" try: keyword_table_text = storage.load_once(file_key) if keyword_table_text: - return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder) + return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder) return None except Exception as e: logging.exception(str(e)) @@ -628,21 +684,21 @@ def object_hook(self, dct): class Embedding(db.Model): - __tablename__ = 'embeddings' + __tablename__ = "embeddings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='embedding_pkey'), - db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx'), - db.Index('created_at_idx', 'created_at') + db.PrimaryKeyConstraint("id", name="embedding_pkey"), + db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), + db.Index("created_at_idx", "created_at"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) - model_name = db.Column(db.String(255), nullable=False, - server_default=db.text("'text-embedding-ada-002'::character varying")) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + model_name = db.Column( + db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") + ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - provider_name = db.Column(db.String(255), nullable=False, - server_default=db.text("''::character varying")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) @@ -652,33 +708,138 @@ def get_embedding(self) -> list[float]: class DatasetCollectionBinding(db.Model): - __tablename__ = 'dataset_collection_bindings' + __tablename__ = "dataset_collection_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'), - db.Index('provider_model_name_idx', 'provider_name', 'model_name') - + db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), + db.Index("provider_model_name_idx", "provider_name", "model_name"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) provider_name = db.Column(db.String(40), nullable=False) model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + +class TidbAuthBinding(db.Model): + __tablename__ = "tidb_auth_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), + db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"), + db.Index("tidb_auth_bindings_active_idx", "active"), + 
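# A minimal sketch of the Embedding pickle round-trip defined above; the
# vector is hypothetical. set_embedding() stores the list with
# pickle.HIGHEST_PROTOCOL in the LargeBinary column, and get_embedding() is
# the matching pickle.loads.
import pickle

vector = [0.12, -0.34, 0.56]
blob = pickle.dumps(vector, protocol=pickle.HIGHEST_PROTOCOL)  # set_embedding
assert pickle.loads(blob) == vector  # get_embedding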
db.Index("tidb_auth_bindings_created_at_idx", "created_at"), + db.Index("tidb_auth_bindings_status_idx", "status"), + ) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=True) + cluster_id = db.Column(db.String(255), nullable=False) + cluster_name = db.Column(db.String(255), nullable=False) + active = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) + account = db.Column(db.String(255), nullable=False) + password = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + +class Whitelist(db.Model): + __tablename__ = "whitelists" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="whitelists_pkey"), + db.Index("whitelists_tenant_idx", "tenant_id"), + ) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=True) + category = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class DatasetPermission(db.Model): - __tablename__ = 'dataset_permissions' + __tablename__ = "dataset_permissions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_permission_pkey'), - db.Index('idx_dataset_permissions_dataset_id', 'dataset_id'), - db.Index('idx_dataset_permissions_account_id', 'account_id'), - db.Index('idx_dataset_permissions_tenant_id', 'tenant_id') + db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), + db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), + db.Index("idx_dataset_permissions_account_id", "account_id"), + db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'), primary_key=True) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) dataset_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) - has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + +class ExternalKnowledgeApis(db.Model): + __tablename__ = "external_knowledge_apis" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), + db.Index("external_knowledge_apis_tenant_idx", "tenant_id"), + db.Index("external_knowledge_apis_name_idx", "name"), + ) + + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + name = db.Column(db.String(255), nullable=False) + description = db.Column(db.String(255), nullable=False) + tenant_id = db.Column(StringUUID, nullable=False) + settings = db.Column(db.Text, nullable=True) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + def 
to_dict(self): + return { + "id": self.id, + "tenant_id": self.tenant_id, + "name": self.name, + "description": self.description, + "settings": self.settings_dict, + "dataset_bindings": self.dataset_bindings, + "created_by": self.created_by, + "created_at": self.created_at.isoformat(), + } + + @property + def settings_dict(self): + try: + return json.loads(self.settings) if self.settings else None + except JSONDecodeError: + return None + + @property + def dataset_bindings(self): + external_knowledge_bindings = ( + db.session.query(ExternalKnowledgeBindings) + .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) + .all() + ) + dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] + datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all() + dataset_bindings = [] + for dataset in datasets: + dataset_bindings.append({"id": dataset.id, "name": dataset.name}) + + return dataset_bindings + + +class ExternalKnowledgeBindings(db.Model): + __tablename__ = "external_knowledge_bindings" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), + db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), + db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), + db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), + db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"), + ) + + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + external_knowledge_api_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + external_knowledge_id = db.Column(db.Text, nullable=False) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/enums.py b/api/models/enums.py new file mode 100644 index 00000000000000..a83d35e04245b7 --- /dev/null +++ b/api/models/enums.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class CreatedByRole(str, Enum): + ACCOUNT = "account" + END_USER = "end_user" + + +class UserFrom(str, Enum): + ACCOUNT = "account" + END_USER = "end-user" + + +class WorkflowRunTriggeredFrom(str, Enum): + DEBUGGING = "debugging" + APP_RUN = "app-run" diff --git a/api/models/model.py b/api/models/model.py index 5426d3bc83e020..e909d53e3e29d3 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,44 +1,47 @@ import json import re import uuid +from collections.abc import Mapping +from datetime import datetime from enum import Enum -from typing import Optional +from typing import Any, Literal, Optional +import sqlalchemy as sa from flask import request from flask_login import UserMixin from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser -from core.file.upload_file_parser import UploadFileParser from extensions.ext_database import db from libs.helper import generate_string +from models.enums import CreatedByRole from .account import Account, Tenant from .types import 
StringUUID class DifySetup(db.Model): - __tablename__ = 'dify_setups' - __table_args__ = ( - db.PrimaryKeyConstraint('version', name='dify_setup_pkey'), - ) + __tablename__ = "dify_setups" + __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class AppMode(Enum): - COMPLETION = 'completion' - WORKFLOW = 'workflow' - CHAT = 'chat' - ADVANCED_CHAT = 'advanced-chat' - AGENT_CHAT = 'agent-chat' - CHANNEL = 'channel' +class AppMode(str, Enum): + COMPLETION = "completion" + WORKFLOW = "workflow" + CHAT = "chat" + ADVANCED_CHAT = "advanced-chat" + AGENT_CHAT = "agent-chat" + CHANNEL = "channel" @classmethod - def value_of(cls, value: str) -> 'AppMode': + def value_of(cls, value: str) -> "AppMode": """ Get value of given mode. @@ -48,21 +51,24 @@ def value_of(cls, value: str) -> 'AppMode': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid mode value {value}') + raise ValueError(f"invalid mode value {value}") + + +class IconType(Enum): + IMAGE = "image" + EMOJI = "emoji" class App(db.Model): - __tablename__ = 'apps' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_pkey'), - db.Index('app_tenant_id_idx', 'tenant_id') - ) + __tablename__ = "apps" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(StringUUID, nullable=False) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) + icon_type = db.Column(db.String(255), nullable=True) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(StringUUID, nullable=True) @@ -70,15 +76,18 @@ class App(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) enable_site = db.Column(db.Boolean, nullable=False) enable_api = db.Column(db.Boolean, nullable=False) - api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - api_rph = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + api_rph = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = 
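# A quick sketch of what the (str, Enum) change above buys: members now
# compare equal to their raw string values, so they can be matched against
# VARCHAR columns directly, while value_of() still rejects unknown values.
assert AppMode.ADVANCED_CHAT == "advanced-chat"
assert AppMode.value_of("workflow") is AppMode.WORKFLOW
try:
    AppMode.value_of("no-such-mode")
except ValueError:
    pass  # raises "invalid mode value no-such-mode", as defined above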
db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property def desc_or_prompt(self): @@ -89,7 +98,7 @@ def desc_or_prompt(self): if app_model_config: return app_model_config.pre_prompt else: - return '' + return "" @property def site(self): @@ -97,24 +106,24 @@ def site(self): return site @property - def app_model_config(self) -> Optional['AppModelConfig']: + def app_model_config(self): if self.app_model_config_id: return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() return None @property - def workflow(self) -> Optional['Workflow']: + def workflow(self) -> Optional["Workflow"]: if self.workflow_id: from .workflow import Workflow + return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() return None @property def api_base_url(self): - return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' + return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" @property def tenant(self): @@ -128,8 +137,9 @@ def is_agent(self) -> bool: return False if not app_model_config.agent_mode: return False - if self.app_model_config.agent_mode_dict.get('enabled', False) \ - and self.app_model_config.agent_mode_dict.get('strategy', '') in ['function_call', 'react']: + if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get( + "strategy", "" + ) in {"function_call", "react"}: self.mode = AppMode.AGENT_CHAT.value db.session.commit() return True @@ -151,16 +161,16 @@ def deleted_tools(self) -> list: if not app_model_config.agent_mode: return [] agent_mode = app_model_config.agent_mode_dict - tools = agent_mode.get('tools', []) + tools = agent_mode.get("tools", []) provider_ids = [] for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: - provider_type = tool.get('provider_type', '') - provider_id = tool.get('provider_id', '') - if provider_type == 'api': + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api": # check if provider id is a uuid string, if not, skip try: uuid.UUID(provider_id) @@ -172,8 +182,7 @@ def deleted_tools(self) -> list: return [] api_providers = db.session.execute( - text('SELECT id FROM tool_api_providers WHERE id IN :provider_ids'), - {'provider_ids': tuple(provider_ids)} + text("SELECT id FROM tool_api_providers WHERE id IN :provider_ids"), {"provider_ids": tuple(provider_ids)} ).fetchall() deleted_tools = [] @@ -182,42 +191,43 @@ def deleted_tools(self) -> list: for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: - provider_type = tool.get('provider_type', '') - provider_id = tool.get('provider_id', '') - if provider_type == 'api' and provider_id not in current_api_provider_ids: - deleted_tools.append(tool['tool_name']) + provider_type = tool.get("provider_type", "") + provider_id = tool.get("provider_id", "") + if provider_type == "api" and provider_id not in current_api_provider_ids: + deleted_tools.append(tool["tool_name"]) return deleted_tools @property def tags(self): 
- tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == self.id, - TagBinding.tenant_id == self.tenant_id, - Tag.tenant_id == self.tenant_id, - Tag.type == 'app' - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == self.id, + TagBinding.tenant_id == self.tenant_id, + Tag.tenant_id == self.tenant_id, + Tag.type == "app", + ) + .all() + ) - return tags if tags else [] + return tags or [] class AppModelConfig(db.Model): - __tablename__ = 'app_model_configs' - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_model_config_pkey'), - db.Index('app_app_id_idx', 'app_id') - ) + __tablename__ = "app_model_configs" + __table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id")) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -253,28 +263,29 @@ def suggested_questions_list(self) -> list: @property def suggested_questions_after_answer_dict(self) -> dict: - return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \ + return ( + json.loads(self.suggested_questions_after_answer) + if self.suggested_questions_after_answer else {"enabled": False} + ) @property def speech_to_text_dict(self) -> dict: - return json.loads(self.speech_to_text) if self.speech_to_text \ - else {"enabled": False} + return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} @property def text_to_speech_dict(self) -> dict: - return json.loads(self.text_to_speech) if self.text_to_speech \ - else {"enabled": False} + return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} @property def retriever_resource_dict(self) -> dict: - return json.loads(self.retriever_resource) if self.retriever_resource \ - else {"enabled": True} + return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} @property def annotation_reply_dict(self) -> dict: - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == self.app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -283,8 +294,8 @@ def annotation_reply_dict(self) -> dict: "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": 
collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } else: @@ -296,13 +307,15 @@ def more_like_this_dict(self) -> dict: @property def sensitive_word_avoidance_dict(self) -> dict: - return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \ + return ( + json.loads(self.sensitive_word_avoidance) + if self.sensitive_word_avoidance else {"enabled": False, "type": "", "configs": []} + ) @property def external_data_tools_list(self) -> list[dict]: - return json.loads(self.external_data_tools) if self.external_data_tools \ - else [] + return json.loads(self.external_data_tools) if self.external_data_tools else [] @property def user_input_form_list(self) -> dict: @@ -310,8 +323,11 @@ def user_input_form_list(self) -> dict: @property def agent_mode_dict(self) -> dict: - return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": [], - "prompt": None} + return ( + json.loads(self.agent_mode) + if self.agent_mode + else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + ) @property def chat_prompt_config_dict(self) -> dict: @@ -325,19 +341,28 @@ def completion_prompt_config_dict(self) -> dict: def dataset_configs_dict(self) -> dict: if self.dataset_configs: dataset_configs = json.loads(self.dataset_configs) - if 'retrieval_model' not in dataset_configs: - return {'retrieval_model': 'single'} + if "retrieval_model" not in dataset_configs: + return {"retrieval_model": "single"} else: return dataset_configs return { - 'retrieval_model': 'multiple', - } + "retrieval_model": "multiple", + } @property def file_upload_dict(self) -> dict: - return json.loads(self.file_upload) if self.file_upload else { - "image": {"enabled": False, "number_limits": 3, "detail": "high", - "transfer_methods": ["remote_url", "local_file"]}} + return ( + json.loads(self.file_upload) + if self.file_upload + else { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + } + ) def to_dict(self) -> dict: return { @@ -360,44 +385,53 @@ def to_dict(self) -> dict: "chat_prompt_config": self.chat_prompt_config_dict, "completion_prompt_config": self.completion_prompt_config_dict, "dataset_configs": self.dataset_configs_dict, - "file_upload": self.file_upload_dict + "file_upload": self.file_upload_dict, } - def from_model_config_dict(self, model_config: dict): - self.opening_statement = model_config.get('opening_statement') - self.suggested_questions = json.dumps(model_config['suggested_questions']) \ - if model_config.get('suggested_questions') else None - self.suggested_questions_after_answer = json.dumps(model_config['suggested_questions_after_answer']) \ - if model_config.get('suggested_questions_after_answer') else None - self.speech_to_text = json.dumps(model_config['speech_to_text']) \ - if model_config.get('speech_to_text') else None - self.text_to_speech = json.dumps(model_config['text_to_speech']) \ - if model_config.get('text_to_speech') else None - self.more_like_this = json.dumps(model_config['more_like_this']) \ - if model_config.get('more_like_this') else None - self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \ - if model_config.get('sensitive_word_avoidance') else None - self.external_data_tools = json.dumps(model_config['external_data_tools']) \ - if model_config.get('external_data_tools') else None - 
self.model = json.dumps(model_config['model']) \ - if model_config.get('model') else None - self.user_input_form = json.dumps(model_config['user_input_form']) \ - if model_config.get('user_input_form') else None - self.dataset_query_variable = model_config.get('dataset_query_variable') - self.pre_prompt = model_config['pre_prompt'] - self.agent_mode = json.dumps(model_config['agent_mode']) \ - if model_config.get('agent_mode') else None - self.retriever_resource = json.dumps(model_config['retriever_resource']) \ - if model_config.get('retriever_resource') else None - self.prompt_type = model_config.get('prompt_type', 'simple') - self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \ - if model_config.get('chat_prompt_config') else None - self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \ - if model_config.get('completion_prompt_config') else None - self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \ - if model_config.get('dataset_configs') else None - self.file_upload = json.dumps(model_config.get('file_upload')) \ - if model_config.get('file_upload') else None + def from_model_config_dict(self, model_config: Mapping[str, Any]): + self.opening_statement = model_config.get("opening_statement") + self.suggested_questions = ( + json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + ) + self.suggested_questions_after_answer = ( + json.dumps(model_config["suggested_questions_after_answer"]) + if model_config.get("suggested_questions_after_answer") + else None + ) + self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None + self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None + self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.sensitive_word_avoidance = ( + json.dumps(model_config["sensitive_word_avoidance"]) + if model_config.get("sensitive_word_avoidance") + else None + ) + self.external_data_tools = ( + json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + ) + self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.user_input_form = ( + json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + ) + self.dataset_query_variable = model_config.get("dataset_query_variable") + self.pre_prompt = model_config["pre_prompt"] + self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.retriever_resource = ( + json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + ) + self.prompt_type = model_config.get("prompt_type", "simple") + self.chat_prompt_config = ( + json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None + ) + self.completion_prompt_config = ( + json.dumps(model_config.get("completion_prompt_config")) + if model_config.get("completion_prompt_config") + else None + ) + self.dataset_configs = ( + json.dumps(model_config.get("dataset_configs")) if model_config.get("dataset_configs") else None + ) + self.file_upload = json.dumps(model_config.get("file_upload")) if model_config.get("file_upload") else None return self def copy(self): @@ -422,33 +456,33 @@ def copy(self): chat_prompt_config=self.chat_prompt_config, 
completion_prompt_config=self.completion_prompt_config, dataset_configs=self.dataset_configs, - file_upload=self.file_upload + file_upload=self.file_upload, ) return new_app_model_config class RecommendedApp(db.Model): - __tablename__ = 'recommended_apps' + __tablename__ = "recommended_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='recommended_app_pkey'), - db.Index('recommended_app_app_id_idx', 'app_id'), - db.Index('recommended_app_is_listed_idx', 'is_listed', 'language') + db.PrimaryKeyConstraint("id", name="recommended_app_pkey"), + db.Index("recommended_app_app_id_idx", "app_id"), + db.Index("recommended_app_is_listed_idx", "is_listed", "language"), ) - id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False) - custom_disclaimer = db.Column(db.String(255), nullable=True) + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") category = db.Column(db.String(255), nullable=False) position = db.Column(db.Integer, nullable=False, default=0) is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def app(self): @@ -457,22 +491,22 @@ def app(self): class InstalledApp(db.Model): - __tablename__ = 'installed_apps' + __tablename__ = "installed_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='installed_app_pkey'), - db.Index('installed_app_tenant_id_idx', 'tenant_id'), - db.Index('installed_app_app_id_idx', 'app_id'), - db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') + db.PrimaryKeyConstraint("id", name="installed_app_pkey"), + db.Index("installed_app_tenant_id_idx", "tenant_id"), + db.Index("installed_app_app_id_idx", "app_id"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) app_owner_tenant_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False, default=0) - is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def app(self): @@ -485,15 +519,14 @@ def tenant(self): return tenant - class Conversation(db.Model): - __tablename__ = 
'conversations' + __tablename__ = "conversations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='conversation_pkey'), - db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') + db.PrimaryKeyConstraint("id", name="conversation_pkey"), + db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) app_model_config_id = db.Column(StringUUID, nullable=True) model_provider = db.Column(db.String(255), nullable=True) @@ -502,10 +535,10 @@ class Conversation(db.Model): mode = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False) summary = db.Column(db.Text) - inputs = db.Column(db.JSON) + _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) introduction = db.Column(db.Text) system_instruction = db.Column(db.Text) - system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) status = db.Column(db.String(255), nullable=False) invoke_from = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) @@ -514,13 +547,37 @@ class Conversation(db.Model): read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + + messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") + message_annotations = db.relationship( + "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all" + ) + + is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") - message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") + @property + def inputs(self): + inputs = self._inputs.copy() + for key, value in inputs.items(): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + inputs[key] = File.model_validate(value) + elif isinstance(value, list) and all( + isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + ): + inputs[key] = [File.model_validate(item) for item in value] + return inputs - is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + @inputs.setter + def inputs(self, value: Mapping[str, Any]): + inputs = dict(value) + for k, v in inputs.items(): + if isinstance(v, File): + inputs[k] = v.model_dump() + elif isinstance(v, list) and all(isinstance(item, File) for item in v): + inputs[k] = [item.model_dump() for item in v] + self._inputs = inputs @property def model_config(self): @@ -533,20 +590,21 @@ def model_config(self): if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - 
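# A sketch of the inputs round-trip added above, assuming the inputs
# setter/getter pair as defined; the File constructor arguments are
# hypothetical stand-ins for a valid core.file.File. The setter stores
# model_dump() dicts tagged with dify_model_identity, and the getter
# re-validates them back into File objects.
conversation = Conversation()
conversation.inputs = {
    "question": "hello",
    "attachment": File(  # hypothetical field values
        tenant_id="tenant-1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.REMOTE_URL,
        remote_url="https://example.com/cat.png",
    ),
}
assert isinstance(conversation.inputs["attachment"], File)
assert conversation.inputs["question"] == "hello"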
if 'model' in override_model_configs: + if "model" in override_model_configs: app_model_config = AppModelConfig() app_model_config = app_model_config.from_model_config_dict(override_model_configs) model_config = app_model_config.to_dict() else: - model_config['configs'] = override_model_configs + model_config["configs"] = override_model_configs else: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == self.app_model_config_id).first() + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() + ) model_config = app_model_config.to_dict() - model_config['model_id'] = self.model_id - model_config['provider'] = self.model_provider + model_config["model_id"] = self.model_id + model_config["provider"] = self.model_provider return model_config @@ -559,7 +617,7 @@ def summary_or_query(self): if first_message: return first_message.query else: - return '' + return "" @property def annotated(self): @@ -575,31 +633,51 @@ def message_count(self): @property def user_feedback_stats(self): - like = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'user', - MessageFeedback.rating == 'like').count() + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "like", + ) + .count() + ) - dislike = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'user', - MessageFeedback.rating == 'dislike').count() + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == "dislike", + ) + .count() + ) - return {'like': like, 'dislike': dislike} + return {"like": like, "dislike": dislike} @property def admin_feedback_stats(self): - like = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'admin', - MessageFeedback.rating == 'like').count() + like = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "like", + ) + .count() + ) - dislike = db.session.query(MessageFeedback) \ - .filter(MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == 'admin', - MessageFeedback.rating == 'dislike').count() + dislike = ( + db.session.query(MessageFeedback) + .filter( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == "dislike", + ) + .count() + ) - return {'like': like, 'dislike': dislike} + return {"like": like, "dislike": dislike} @property def first_message(self): @@ -618,59 +696,91 @@ def from_end_user_session_id(self): return None + @property + def from_account_name(self): + if self.from_account_id: + account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + if account: + return account.name + + return None + @property def in_debug_mode(self): return self.override_model_configs is not None class Message(db.Model): - __tablename__ = 'messages' + __tablename__ = "messages" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_pkey'), - db.Index('message_app_id_idx', 'app_id', 'created_at'), - db.Index('message_conversation_id_idx', 'conversation_id'), - 
db.Index('message_end_user_idx', 'app_id', 'from_source', 'from_end_user_id'), - db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'), - db.Index('message_workflow_run_id_idx', 'conversation_id', 'workflow_run_id') + db.PrimaryKeyConstraint("id", name="message_pkey"), + db.Index("message_app_id_idx", "app_id", "created_at"), + db.Index("message_conversation_id_idx", "conversation_id"), + db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"), + db.Index("message_account_idx", "app_id", "from_source", "from_account_id"), + db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) model_provider = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True) override_model_configs = db.Column(db.Text) - conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False) - inputs = db.Column(db.JSON) - query = db.Column(db.Text, nullable=False) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False) + _inputs: Mapped[dict] = mapped_column("inputs", db.JSON) + query: Mapped[str] = db.Column(db.Text, nullable=False) message = db.Column(db.JSON, nullable=False) - message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - answer = db.Column(db.Text, nullable=False) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + answer: Mapped[str] = db.Column(db.Text, nullable=False) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + parent_message_id = db.Column(StringUUID, nullable=True) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) error = db.Column(db.Text) message_metadata = db.Column(db.Text) - invoke_from = db.Column(db.String(255), nullable=True) + invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) from_source = db.Column(db.String(255), nullable=False) - from_end_user_id = db.Column(StringUUID) - from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + from_end_user_id: Mapped[Optional[str]] = 
db.Column(StringUUID) + from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) + @property + def inputs(self): + inputs = self._inputs.copy() + for key, value in inputs.items(): + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + inputs[key] = File.model_validate(value) + elif isinstance(value, list) and all( + isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value + ): + inputs[key] = [File.model_validate(item) for item in value] + return inputs + + @inputs.setter + def inputs(self, value: Mapping[str, Any]): + inputs = dict(value) + for k, v in inputs.items(): + if isinstance(v, File): + inputs[k] = v.model_dump() + elif isinstance(v, list) and all(isinstance(item, File) for item in v): + inputs[k] = [item.model_dump() for item in v] + self._inputs = inputs + @property def re_sign_file_url_answer(self) -> str: if not self.answer: return self.answer - pattern = r'\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)' + pattern = r"\[!?.*?\]\((((http|https):\/\/.+)?\/files\/(tools\/)?[\w-]+.*?timestamp=.*&nonce=.*&sign=.*)\)" matches = re.findall(pattern, self.answer) if not matches: @@ -686,9 +796,9 @@ def re_sign_file_url_answer(self) -> str: re_sign_file_url_answer = self.answer for url in urls: - if 'files/tools' in url: + if "files/tools" in url: # get tool file id - tool_file_id_pattern = r'\/files\/tools\/([\.\w-]+)?\?timestamp=' + tool_file_id_pattern = r"\/files\/tools\/([\.\w-]+)?\?timestamp=" result = re.search(tool_file_id_pattern, url) if not result: continue @@ -696,35 +806,44 @@ def re_sign_file_url_answer(self) -> str: tool_file_id = result.group(1) # get extension - if '.' in tool_file_id: - split_result = tool_file_id.split('.') - extension = f'.{split_result[-1]}' + if "." 
in tool_file_id: + split_result = tool_file_id.split(".") + extension = f".{split_result[-1]}" if len(extension) > 10: - extension = '.bin' + extension = ".bin" tool_file_id = split_result[0] else: - extension = '.bin' + extension = ".bin" if not tool_file_id: continue sign_url = ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=tool_file_id, - extension=extension + tool_file_id=tool_file_id, extension=extension ) - else: + elif "file-preview" in url: # get upload file id - upload_file_id_pattern = r'\/files\/([\w-]+)\/image-preview?\?timestamp=' + upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp=" result = re.search(upload_file_id_pattern, url) if not result: continue upload_file_id = result.group(1) - if not upload_file_id: continue - - sign_url = UploadFileParser.get_signed_temp_image_url(upload_file_id) + sign_url = file_helpers.get_signed_file_url(upload_file_id) + elif "image-preview" in url: + # image-preview is deprecated, use file-preview instead + upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp=" + result = re.search(upload_file_id_pattern, url) + if not result: + continue + upload_file_id = result.group(1) + if not upload_file_id: + continue + sign_url = file_helpers.get_signed_file_url(upload_file_id) + else: + continue re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url) @@ -732,14 +851,20 @@ def re_sign_file_url_answer(self) -> str: @property def user_feedback(self): - feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, - MessageFeedback.from_source == 'user').first() + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") + .first() + ) return feedback @property def admin_feedback(self): - feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, - MessageFeedback.from_source == 'admin').first() + feedback = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") + .first() + ) return feedback @property @@ -754,11 +879,15 @@ def annotation(self): @property def annotation_hit_history(self): - annotation_history = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.message_id == self.id).first()) + annotation_history = ( + db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first() + ) if annotation_history: - annotation = (db.session.query(MessageAnnotation). 
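# A small sketch of the extension handling inside re_sign_file_url_answer()
# above: a tool file id such as "abc123.png" splits into id and extension,
# and anything missing or longer than 10 characters falls back to ".bin".
# Inputs are hypothetical.
def split_tool_file_id(tool_file_id: str) -> tuple[str, str]:
    if "." in tool_file_id:
        parts = tool_file_id.split(".")
        extension = f".{parts[-1]}"
        if len(extension) > 10:
            extension = ".bin"
        return parts[0], extension
    return tool_file_id, ".bin"

assert split_tool_file_id("abc123.png") == ("abc123", ".png")
assert split_tool_file_id("abc123") == ("abc123", ".bin")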
- filter(MessageAnnotation.id == annotation_history.annotation_id).first()) + annotation = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.id == annotation_history.annotation_id) + .first() + ) return annotation return None @@ -766,8 +895,9 @@ def annotation_hit_history(self): def app_model_config(self): conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() if conversation: - return db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id).first() + return ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first() + ) return None @@ -781,123 +911,148 @@ def message_metadata_dict(self) -> dict: @property def agent_thoughts(self): - return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \ - .order_by(MessageAgentThought.position.asc()).all() + return ( + db.session.query(MessageAgentThought) + .filter(MessageAgentThought.message_id == self.id) + .order_by(MessageAgentThought.position.asc()) + .all() + ) @property def retriever_resources(self): - return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \ - .order_by(DatasetRetrieverResource.position.asc()).all() + return ( + db.session.query(DatasetRetrieverResource) + .filter(DatasetRetrieverResource.message_id == self.id) + .order_by(DatasetRetrieverResource.position.asc()) + .all() + ) @property def message_files(self): - return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() + from factories import file_factory - @property - def files(self): - message_files = self.message_files + message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() + current_app = db.session.query(App).filter(App.id == self.app_id).first() + if not current_app: + raise ValueError(f"App {self.app_id} not found") - files = [] + files: list[File] = [] for message_file in message_files: - url = message_file.url - if message_file.type == 'image': - if message_file.transfer_method == 'local_file': - upload_file = (db.session.query(UploadFile) - .filter( - UploadFile.id == message_file.upload_file_id - ).first()) - - url = UploadFileParser.get_image_data( - upload_file=upload_file, - force_url=True - ) - if message_file.transfer_method == 'tool_file': - # get tool file id - tool_file_id = message_file.url.split('/')[-1] - # trim extension - tool_file_id = tool_file_id.split('.')[0] - - # get extension - if '.' 
in message_file.url: - extension = f'.{message_file.url.split(".")[-1]}' - if len(extension) > 10: - extension = '.bin' - else: - extension = '.bin' - # add sign url - url = ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=tool_file_id, extension=extension) + if message_file.transfer_method == "local_file": + if message_file.upload_file_id is None: + raise ValueError(f"MessageFile {message_file.id} is a local file but has no upload_file_id") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "upload_file_id": message_file.upload_file_id, + "transfer_method": message_file.transfer_method, + "type": message_file.type, + }, + tenant_id=current_app.tenant_id, + ) + elif message_file.transfer_method == "remote_url": + if message_file.url is None: + raise ValueError(f"MessageFile {message_file.id} is a remote url but has no url") + file = file_factory.build_from_mapping( + mapping={ + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "url": message_file.url, + }, + tenant_id=current_app.tenant_id, + ) + elif message_file.transfer_method == "tool_file": + if message_file.upload_file_id is None: + assert message_file.url is not None + message_file.upload_file_id = message_file.url.split("/")[-1].split(".")[0] + mapping = { + "id": message_file.id, + "type": message_file.type, + "transfer_method": message_file.transfer_method, + "tool_file_id": message_file.upload_file_id, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=current_app.tenant_id, + ) + else: + raise ValueError( + f"MessageFile {message_file.id} has an invalid transfer_method {message_file.transfer_method}" + ) + files.append(file) - files.append({ - 'id': message_file.id, - 'type': message_file.type, - 'url': url, - 'belongs_to': message_file.belongs_to if message_file.belongs_to else 'user' - }) + result = [ + {"belongs_to": message_file.belongs_to, **file.to_dict()} + for (file, message_file) in zip(files, message_files) + ] - return files + db.session.commit() + return result @property def workflow_run(self): if self.workflow_run_id: from .workflow import WorkflowRun + return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() return None def to_dict(self) -> dict: return { - 'id': self.id, - 'app_id': self.app_id, - 'conversation_id': self.conversation_id, - 'inputs': self.inputs, - 'query': self.query, - 'message': self.message, - 'answer': self.answer, - 'status': self.status, - 'error': self.error, - 'message_metadata': self.message_metadata_dict, - 'from_source': self.from_source, - 'from_end_user_id': self.from_end_user_id, - 'from_account_id': self.from_account_id, - 'created_at': self.created_at.isoformat(), - 'updated_at': self.updated_at.isoformat(), - 'agent_based': self.agent_based, - 'workflow_run_id': self.workflow_run_id + "id": self.id, + "app_id": self.app_id, + "conversation_id": self.conversation_id, + "inputs": self.inputs, + "query": self.query, + "message": self.message, + "answer": self.answer, + "status": self.status, + "error": self.error, + "message_metadata": self.message_metadata_dict, + "from_source": self.from_source, + "from_end_user_id": self.from_end_user_id, + "from_account_id": self.from_account_id, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + "agent_based": self.agent_based, + "workflow_run_id": self.workflow_run_id, } @classmethod def from_dict(cls, data: dict): return cls( - id=data['id'], - 
app_id=data['app_id'], - conversation_id=data['conversation_id'], - inputs=data['inputs'], - query=data['query'], - message=data['message'], - answer=data['answer'], - status=data['status'], - error=data['error'], - message_metadata=json.dumps(data['message_metadata']), - from_source=data['from_source'], - from_end_user_id=data['from_end_user_id'], - from_account_id=data['from_account_id'], - created_at=data['created_at'], - updated_at=data['updated_at'], - agent_based=data['agent_based'], - workflow_run_id=data['workflow_run_id'] + id=data["id"], + app_id=data["app_id"], + conversation_id=data["conversation_id"], + inputs=data["inputs"], + query=data["query"], + message=data["message"], + answer=data["answer"], + status=data["status"], + error=data["error"], + message_metadata=json.dumps(data["message_metadata"]), + from_source=data["from_source"], + from_end_user_id=data["from_end_user_id"], + from_account_id=data["from_account_id"], + created_at=data["created_at"], + updated_at=data["updated_at"], + agent_based=data["agent_based"], + workflow_run_id=data["workflow_run_id"], ) class MessageFeedback(db.Model): - __tablename__ = 'message_feedbacks' + __tablename__ = "message_feedbacks" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_feedback_pkey'), - db.Index('message_feedback_app_idx', 'app_id'), - db.Index('message_feedback_message_idx', 'message_id', 'from_source'), - db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating') + db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), + db.Index("message_feedback_app_idx", "app_id"), + db.Index("message_feedback_message_idx", "message_id", "from_source"), + db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) @@ -906,8 +1061,8 @@ class MessageFeedback(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def from_account(self): @@ -916,44 +1071,67 @@ def from_account(self): class MessageFile(db.Model): - __tablename__ = 'message_files' + __tablename__ = "message_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_file_pkey'), - db.Index('message_file_message_idx', 'message_id'), - db.Index('message_file_created_by_idx', 'created_by') + db.PrimaryKeyConstraint("id", name="message_file_pkey"), + db.Index("message_file_message_idx", "message_id"), + db.Index("message_file_created_by_idx", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - message_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - transfer_method = db.Column(db.String(255), nullable=False) - url = db.Column(db.Text, nullable=True) - belongs_to = db.Column(db.String(255), 
nullable=True) - upload_file_id = db.Column(StringUUID, nullable=True) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + def __init__( + self, + *, + message_id: str, + type: FileType, + transfer_method: FileTransferMethod, + url: str | None = None, + belongs_to: Literal["user", "assistant"] | None = None, + upload_file_id: str | None = None, + created_by_role: CreatedByRole, + created_by: str, + ): + self.message_id = message_id + self.type = type + self.transfer_method = transfer_method + self.url = url + self.belongs_to = belongs_to + self.upload_file_id = upload_file_id + self.created_by_role = created_by_role.value + self.created_by = created_by + + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + message_id: Mapped[str] = db.Column(StringUUID, nullable=False) + type: Mapped[str] = db.Column(db.String(255), nullable=False) + transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False) + url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True) + belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True) + upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) + created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) class MessageAnnotation(db.Model): - __tablename__ = 'message_annotations' + __tablename__ = "message_annotations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_annotation_pkey'), - db.Index('message_annotation_app_idx', 'app_id'), - db.Index('message_annotation_conversation_idx', 'conversation_id'), - db.Index('message_annotation_message_idx', 'message_id') + db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), + db.Index("message_annotation_app_idx", "app_id"), + db.Index("message_annotation_conversation_idx", "conversation_id"), + db.Index("message_annotation_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True) + conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True) message_id = db.Column(StringUUID, nullable=True) question = db.Column(db.Text, nullable=True) content = db.Column(db.Text, nullable=False) - hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def account(self): @@ -967,32 +1145,35 @@ def annotation_create_account(self): class AppAnnotationHitHistory(db.Model): - __tablename__ = 
'app_annotation_hit_histories' + __tablename__ = "app_annotation_hit_histories" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'), - db.Index('app_annotation_hit_histories_app_idx', 'app_id'), - db.Index('app_annotation_hit_histories_account_idx', 'account_id'), - db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'), - db.Index('app_annotation_hit_histories_message_idx', 'message_id'), + db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), + db.Index("app_annotation_hit_histories_app_idx", "app_id"), + db.Index("app_annotation_hit_histories_account_idx", "account_id"), + db.Index("app_annotation_hit_histories_annotation_idx", "annotation_id"), + db.Index("app_annotation_hit_histories_message_idx", "message_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) annotation_id = db.Column(StringUUID, nullable=False) source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - score = db.Column(Float, nullable=False, server_default=db.text('0')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) annotation_content = db.Column(db.Text, nullable=False) @property def account(self): - account = (db.session.query(Account) - .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) - .filter(MessageAnnotation.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) + .filter(MessageAnnotation.id == self.annotation_id) + .first() + ) return account @property @@ -1002,109 +1183,133 @@ def annotation_create_account(self): class AppAnnotationSetting(db.Model): - __tablename__ = 'app_annotation_settings' + __tablename__ = "app_annotation_settings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'), - db.Index('app_annotation_settings_app_idx', 'app_id') + db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), + db.Index("app_annotation_settings_app_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - score_threshold = db.Column(Float, nullable=False, server_default=db.text('0')) + score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def 
created_account(self): - account = (db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) return account @property def updated_account(self): - account = (db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id).first()) + account = ( + db.session.query(Account) + .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) + .filter(AppAnnotationSetting.id == self.annotation_id) + .first() + ) return account @property def collection_binding_detail(self): from .dataset import DatasetCollectionBinding - collection_binding_detail = (db.session.query(DatasetCollectionBinding) - .filter(DatasetCollectionBinding.id == self.collection_binding_id).first()) + + collection_binding_detail = ( + db.session.query(DatasetCollectionBinding) + .filter(DatasetCollectionBinding.id == self.collection_binding_id) + .first() + ) return collection_binding_detail class OperationLog(db.Model): - __tablename__ = 'operation_logs' + __tablename__ = "operation_logs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='operation_log_pkey'), - db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action') + db.PrimaryKeyConstraint("id", name="operation_log_pkey"), + db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class EndUser(UserMixin, db.Model): - __tablename__ = 'end_users' + __tablename__ = "end_users" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='end_user_pkey'), - db.Index('end_user_session_id_idx', 'session_id', 'type'), - db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'), + db.PrimaryKeyConstraint("id", name="end_user_pkey"), + db.Index("end_user_session_id_idx", "session_id", "type"), + db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) name = db.Column(db.String(255)) - is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + is_anonymous = db.Column(db.Boolean, nullable=False, 
server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class Site(db.Model): - __tablename__ = 'sites' + __tablename__ = "sites" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='site_pkey'), - db.Index('site_app_id_idx', 'app_id'), - db.Index('site_code_idx', 'code', 'status') + db.PrimaryKeyConstraint("id", name="site_pkey"), + db.Index("site_app_id_idx", "app_id"), + db.Index("site_code_idx", "code", "status"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) title = db.Column(db.String(255), nullable=False) + icon_type = db.Column(db.String(255), nullable=True) icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) description = db.Column(db.Text) default_language = db.Column(db.String(255), nullable=False) chat_color_theme = db.Column(db.String(255)) - chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) copyright = db.Column(db.String(255)) privacy_policy = db.Column(db.String(255)) - show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - custom_disclaimer = db.Column(db.String(255), nullable=True) + show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) - prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_by = db.Column(StringUUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) code = db.Column(db.String(255)) + @property + def custom_disclaimer(self): + return self._custom_disclaimer + + @custom_disclaimer.setter + def custom_disclaimer(self, value: str): + if len(value) > 512: + raise ValueError("Custom disclaimer cannot exceed 512 characters.") + self._custom_disclaimer = value + @staticmethod def generate_code(n): while True: @@ -1116,26 +1321,25 @@ def generate_code(n): @property def app_base_url(self): - return ( - dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else 
request.url_root.rstrip('/')) + return dify_config.APP_WEB_URL or request.url_root.rstrip("/") class ApiToken(db.Model): - __tablename__ = 'api_tokens' + __tablename__ = "api_tokens" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_token_pkey'), - db.Index('api_token_app_id_type_idx', 'app_id', 'type'), - db.Index('api_token_token_idx', 'token', 'type'), - db.Index('api_token_tenant_idx', 'tenant_id', 'type') + db.PrimaryKeyConstraint("id", name="api_token_pkey"), + db.Index("api_token_app_id_type_idx", "app_id", "type"), + db.Index("api_token_token_idx", "token", "type"), + db.Index("api_token_tenant_idx", "tenant_id", "type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=True) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @staticmethod def generate_api_key(prefix, n): @@ -1148,54 +1352,94 @@ def generate_api_key(prefix, n): class UploadFile(db.Model): - __tablename__ = 'upload_files' + __tablename__ = "upload_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='upload_file_pkey'), - db.Index('upload_file_tenant_idx', 'tenant_id') + db.PrimaryKeyConstraint("id", name="upload_file_pkey"), + db.Index("upload_file_tenant_idx", "tenant_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(StringUUID, nullable=False) - storage_type = db.Column(db.String(255), nullable=False) - key = db.Column(db.String(255), nullable=False) - name = db.Column(db.String(255), nullable=False) - size = db.Column(db.Integer, nullable=False) - extension = db.Column(db.String(255), nullable=False) - mime_type = db.Column(db.String(255), nullable=True) - created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying")) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - used = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - used_by = db.Column(StringUUID, nullable=True) - used_at = db.Column(db.DateTime, nullable=True) - hash = db.Column(db.String(255), nullable=True) + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + storage_type: Mapped[str] = db.Column(db.String(255), nullable=False) + key: Mapped[str] = db.Column(db.String(255), nullable=False) + name: Mapped[str] = db.Column(db.String(255), nullable=False) + size: Mapped[int] = db.Column(db.Integer, nullable=False) + extension: Mapped[str] = db.Column(db.String(255), nullable=False) + mime_type: Mapped[str] = db.Column(db.String(255), nullable=True) + created_by_role: Mapped[str] = db.Column( + db.String(255), nullable=False, server_default=db.text("'account'::character varying") + ) + created_by: Mapped[str] = db.Column(StringUUID, nullable=False) + created_at: Mapped[datetime] = db.Column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + used: Mapped[bool] = 
db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) + used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) + hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) + source_url: Mapped[str] = mapped_column(sa.TEXT, default="") + + def __init__( + self, + *, + tenant_id: str, + storage_type: str, + key: str, + name: str, + size: int, + extension: str, + mime_type: str, + created_by_role: CreatedByRole, + created_by: str, + created_at: datetime, + used: bool, + used_by: str | None = None, + used_at: datetime | None = None, + hash: str | None = None, + source_url: str = "", + ): + self.tenant_id = tenant_id + self.storage_type = storage_type + self.key = key + self.name = name + self.size = size + self.extension = extension + self.mime_type = mime_type + self.created_by_role = created_by_role.value + self.created_by = created_by + self.created_at = created_at + self.used = used + self.used_by = used_by + self.used_at = used_at + self.hash = hash + self.source_url = source_url class ApiRequest(db.Model): - __tablename__ = 'api_requests' + __tablename__ = "api_requests" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='api_request_pkey'), - db.Index('api_request_token_idx', 'tenant_id', 'api_token_id') + db.PrimaryKeyConstraint("id", name="api_request_pkey"), + db.Index("api_request_token_idx", "tenant_id", "api_token_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) api_token_id = db.Column(StringUUID, nullable=False) path = db.Column(db.String(255), nullable=False) request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class MessageChain(db.Model): - __tablename__ = 'message_chains' + __tablename__ = "message_chains" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_chain_pkey'), - db.Index('message_chain_message_id_idx', 'message_id') + db.PrimaryKeyConstraint("id", name="message_chain_pkey"), + db.Index("message_chain_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) type = db.Column(db.String(255), nullable=False) input = db.Column(db.Text, nullable=True) @@ -1204,14 +1448,14 @@ class MessageChain(db.Model): class MessageAgentThought(db.Model): - __tablename__ = 'message_agent_thoughts' + __tablename__ = "message_agent_thoughts" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='message_agent_thought_pkey'), - db.Index('message_agent_thought_message_id_idx', 'message_id'), - db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'), + db.PrimaryKeyConstraint("id", name="message_agent_thought_pkey"), + db.Index("message_agent_thought_message_id_idx", "message_id"), + db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = 
db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) message_chain_id = db.Column(StringUUID, nullable=True) position = db.Column(db.Integer, nullable=False) @@ -1226,12 +1470,12 @@ class MessageAgentThought(db.Model): message = db.Column(db.Text, nullable=True) message_token = db.Column(db.Integer, nullable=True) message_unit_price = db.Column(db.Numeric, nullable=True) - message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_files = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True) answer_token = db.Column(db.Integer, nullable=True) answer_unit_price = db.Column(db.Numeric, nullable=True) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) tokens = db.Column(db.Integer, nullable=True) total_price = db.Column(db.Numeric, nullable=True) currency = db.Column(db.String, nullable=True) @@ -1288,9 +1532,7 @@ def tool_inputs_dict(self) -> dict: result[tool] = {} return result else: - return { - tool: {} for tool in tools - } + return {tool: {} for tool in tools} except Exception as e: return {} @@ -1311,32 +1553,28 @@ def tool_outputs_dict(self) -> dict: result[tool] = {} return result else: - return { - tool: {} for tool in tools - } + return {tool: {} for tool in tools} except Exception as e: if self.observation: - return { - tool: self.observation for tool in tools - } + return dict.fromkeys(tools, self.observation) class DatasetRetrieverResource(db.Model): - __tablename__ = 'dataset_retriever_resources' + __tablename__ = "dataset_retriever_resources" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'), - db.Index('dataset_retriever_resource_message_id_idx', 'message_id'), + db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"), + db.Index("dataset_retriever_resource_message_id_idx", "message_id"), ) - id = db.Column(StringUUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) message_id = db.Column(StringUUID, nullable=False) position = db.Column(db.Integer, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) dataset_name = db.Column(db.Text, nullable=False) - document_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=True) document_name = db.Column(db.Text, nullable=False) - data_source_type = db.Column(db.Text, nullable=False) - segment_id = db.Column(StringUUID, nullable=False) + data_source_type = db.Column(db.Text, nullable=True) + segment_id = db.Column(StringUUID, nullable=True) score = db.Column(db.Float, nullable=True) content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=True) @@ -1349,57 +1587,57 @@ class DatasetRetrieverResource(db.Model): class Tag(db.Model): - __tablename__ = 'tags' + __tablename__ = "tags" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tag_pkey'), - db.Index('tag_type_idx', 'type'), - db.Index('tag_name_idx', 'name'), + db.PrimaryKeyConstraint("id", name="tag_pkey"), + db.Index("tag_type_idx", "type"), + db.Index("tag_name_idx", "name"), ) - TAG_TYPE_LIST = ['knowledge', 'app'] + TAG_TYPE_LIST = 
["knowledge", "app"] - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TagBinding(db.Model): - __tablename__ = 'tag_bindings' + __tablename__ = "tag_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tag_binding_pkey'), - db.Index('tag_bind_target_id_idx', 'target_id'), - db.Index('tag_bind_tag_id_idx', 'tag_id'), + db.PrimaryKeyConstraint("id", name="tag_binding_pkey"), + db.Index("tag_bind_target_id_idx", "target_id"), + db.Index("tag_bind_tag_id_idx", "tag_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TraceAppConfig(db.Model): - __tablename__ = 'trace_app_config' + __tablename__ = "trace_app_config" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tracing_app_config_pkey'), - db.Index('trace_app_config_app_id_idx', 'app_id'), + db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"), + db.Index("trace_app_config_app_id_idx", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) - is_active = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property def tracing_config_dict(self): - return self.tracing_config if self.tracing_config else {} + return self.tracing_config or {} @property def tracing_config_str(self): @@ -1407,11 +1645,11 @@ def tracing_config_str(self): def to_dict(self): return { - 'id': self.id, - 'app_id': self.app_id, - 'tracing_provider': self.tracing_provider, - 'tracing_config': self.tracing_config_dict, + "id": self.id, + "app_id": self.app_id, + "tracing_provider": self.tracing_provider, + "tracing_config": self.tracing_config_dict, "is_active": self.is_active, - "created_at": self.created_at.__str__() if self.created_at else None, - 'updated_at': self.updated_at.__str__() if self.updated_at else None, + "created_at": str(self.created_at) if self.created_at else None, + "updated_at": str(self.updated_at) if self.updated_at else None, } diff --git a/api/models/provider.py b/api/models/provider.py index 5d92ee6eb60d18..644915e781084b 100644 --- a/api/models/provider.py +++ 
b/api/models/provider.py @@ -6,8 +6,8 @@ class ProviderType(Enum): - CUSTOM = 'custom' - SYSTEM = 'system' + CUSTOM = "custom" + SYSTEM = "system" @staticmethod def value_of(value): @@ -18,13 +18,13 @@ def value_of(value): class ProviderQuotaType(Enum): - PAID = 'paid' + PAID = "paid" """hosted paid quota""" - FREE = 'free' + FREE = "free" """third-party free quota""" - TRIAL = 'trial' + TRIAL = "trial" """hosted trial quota""" @staticmethod @@ -39,36 +39,42 @@ class Provider(db.Model): """ Provider model representing the API providers and their configurations. """ - __tablename__ = 'providers' + + __tablename__ = "providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_pkey'), - db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') + db.PrimaryKeyConstraint("id", name="provider_pkey"), + db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used = db.Column(db.DateTime, nullable=True) quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def __repr__(self): - return f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}', provider_type='{self.provider_type}')>" + return ( + f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}'," + f" provider_type='{self.provider_type}')>" + ) @property def token_is_set(self): """ - Returns True if the encrypted_config is not None, indicating that the token is set. - """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ return self.encrypted_config is not None @property @@ -86,118 +92,123 @@ class ProviderModel(db.Model): """ Provider model representing the API provider_models and their configurations. 
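
The value_of helpers on ProviderType and ProviderQuotaType above have their bodies elided by the hunk context. A minimal runnable sketch of the string-to-enum lookup they perform, using the member values from this patch (the loop body and error message are assumptions, not quoted from the repo):

    from enum import Enum

    class ProviderQuotaTypeSketch(Enum):
        PAID = "paid"    # hosted paid quota
        FREE = "free"    # third-party free quota
        TRIAL = "trial"  # hosted trial quota

        @staticmethod
        def value_of(value):
            # Linear scan over members; unknown strings raise instead of returning None.
            for member in ProviderQuotaTypeSketch:
                if member.value == value:
                    return member
            raise ValueError(f"invalid quota type value {value}")

    assert ProviderQuotaTypeSketch.value_of("trial") is ProviderQuotaTypeSketch.TRIAL
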
""" - __tablename__ = 'provider_models' + + __tablename__ = "provider_models" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_pkey'), - db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'), - db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') + db.PrimaryKeyConstraint("id", name="provider_model_pkey"), + db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), + db.UniqueConstraint( + "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantDefaultModel(db.Model): - __tablename__ = 'tenant_default_models' + __tablename__ = "tenant_default_models" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'), - db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), + db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class TenantPreferredModelProvider(db.Model): - __tablename__ = 'tenant_preferred_model_providers' + __tablename__ = "tenant_preferred_model_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'), - db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), + db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), + db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, 
server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class ProviderOrder(db.Model): - __tablename__ = 'provider_orders' + __tablename__ = "provider_orders" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_order_pkey'), - db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), + db.PrimaryKeyConstraint("id", name="provider_order_pkey"), + db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) account_id = db.Column(StringUUID, nullable=False) payment_product_id = db.Column(db.String(191), nullable=False) payment_id = db.Column(db.String(191)) transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1')) + quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) currency = db.Column(db.String(40)) total_amount = db.Column(db.Integer) payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class ProviderModelSetting(db.Model): """ Provider model settings for record the model enabled status and load balancing status. 
""" - __tablename__ = 'provider_model_settings' + + __tablename__ = "provider_model_settings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'), - db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), + db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) class LoadBalancingModelConfig(db.Model): """ Configurations for load balancing models. """ - __tablename__ = 'load_balancing_model_configs' + + __tablename__ = "load_balancing_model_configs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'), - db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), + db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), + db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/source.py b/api/models/source.py index adc00028bee43b..07695f06e6cf00 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -8,48 +8,48 @@ class DataSourceOauthBinding(db.Model): - __tablename__ = 'data_source_oauth_bindings' + __tablename__ = "data_source_oauth_bindings" __table_args__ 
= ( - db.PrimaryKeyConstraint('id', name='source_binding_pkey'), - db.Index('source_binding_tenant_id_idx', 'tenant_id'), - db.Index('source_info_idx', "source_info", postgresql_using='gin') + db.PrimaryKeyConstraint("id", name="source_binding_pkey"), + db.Index("source_binding_tenant_id_idx", "tenant_id"), + db.Index("source_info_idx", "source_info", postgresql_using="gin"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) class DataSourceApiKeyAuthBinding(db.Model): - __tablename__ = 'data_source_api_key_auth_bindings' + __tablename__ = "data_source_api_key_auth_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'), - db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'), - db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'), + db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), + db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), + db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'category': self.category, - 'provider': self.provider, - 'credentials': json.loads(self.credentials), - 'created_at': self.created_at.timestamp(), - 'updated_at': self.updated_at.timestamp(), - 'disabled': self.disabled + "id": self.id, + "tenant_id": self.tenant_id, + "category": self.category, + "provider": self.provider, + "credentials": json.loads(self.credentials), + "created_at": self.created_at.timestamp(), + "updated_at": self.updated_at.timestamp(), + "disabled": self.disabled, } diff --git a/api/models/task.py b/api/models/task.py index 618d831d8ed4e9..57b147c78db110 100644 --- 
a/api/models/task.py +++ b/api/models/task.py @@ -8,15 +8,18 @@ class CeleryTask(db.Model): """Task result/status.""" - __tablename__ = 'celery_taskmeta' + __tablename__ = "celery_taskmeta" - id = db.Column(db.Integer, db.Sequence('task_id_sequence'), - primary_key=True, autoincrement=True) + id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) task_id = db.Column(db.String(155), unique=True) status = db.Column(db.String(50), default=states.PENDING) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) + date_done = db.Column( + db.DateTime, + default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + nullable=True, + ) traceback = db.Column(db.Text, nullable=True) name = db.Column(db.String(155), nullable=True) args = db.Column(db.LargeBinary, nullable=True) @@ -29,11 +32,9 @@ class CeleryTask(db.Model): class CeleryTaskSet(db.Model): """TaskSet result.""" - __tablename__ = 'celery_tasksetmeta' + __tablename__ = "celery_tasksetmeta" - id = db.Column(db.Integer, db.Sequence('taskset_id_sequence'), - autoincrement=True, primary_key=True) + id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) diff --git a/api/models/tool.py b/api/models/tool.py index 79a70c6b1f2d22..a81bb65174a724 100644 --- a/api/models/tool.py +++ b/api/models/tool.py @@ -7,7 +7,7 @@ class ToolProviderName(Enum): - SERPAPI = 'serpapi' + SERPAPI = "serpapi" @staticmethod def value_of(value): @@ -18,25 +18,25 @@ def value_of(value): class ToolProvider(db.Model): - __tablename__ = 'tool_providers' + __tablename__ = "tool_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_provider_pkey'), - db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') + db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), + db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) tool_name = db.Column(db.String(40), nullable=False) encrypted_credentials = db.Column(db.Text, nullable=True) - is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def credentials_is_set(self): """ - Returns True if the encrypted_config is not None, indicating that the token is set. 
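
The date_done defaults in the celery_taskmeta hunk above use an idiom worth noting: build an aware UTC timestamp, then strip tzinfo so the DateTime column stores a naive datetime that represents UTC. A tiny sketch of just that expression:

    from datetime import datetime, timezone

    # Aware "now" in UTC, then drop tzinfo: what gets stored is a naive
    # datetime whose value is the UTC wall-clock time.
    naive_utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
    assert naive_utc_now.tzinfo is None
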
- """ + Returns True if the encrypted_config is not None, indicating that the token is set. + """ return self.encrypted_credentials is not None @property diff --git a/api/models/tools.py b/api/models/tools.py index 069dc5bad083c8..4040339e026474 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,9 @@ import json +from typing import Optional +import sqlalchemy as sa from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle @@ -15,15 +18,16 @@ class BuiltinToolProvider(db.Model): """ This table stores the tool provider information for built-in tools for each tenant. """ - __tablename__ = 'tool_builtin_providers' + + __tablename__ = "tool_builtin_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_builtin_provider_pkey'), + db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), # one tenant can only have one tool provider with the same name - db.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_tool_provider') + db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), ) # id of the tool provider - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the tenant tenant_id = db.Column(StringUUID, nullable=True) # who created this tool provider @@ -32,34 +36,37 @@ class BuiltinToolProvider(db.Model): provider = db.Column(db.String(40), nullable=False) # credential of the tool provider encrypted_credentials = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def credentials(self) -> dict: return json.loads(self.encrypted_credentials) + class PublishedAppTool(db.Model): """ The table stores the apps published as a tool for each person. 
""" - __tablename__ = 'tool_published_apps' + + __tablename__ = "tool_published_apps" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='published_app_tool_pkey'), - db.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool') + db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), + db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), ) # id of the tool provider - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # id of the app - app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False) + app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) # who published this tool user_id = db.Column(StringUUID, nullable=False) # description of the tool, stored in i18n format, for human description = db.Column(db.Text, nullable=False) # llm_description of the tool, for LLM llm_description = db.Column(db.Text, nullable=False) - # query description, query will be seem as a parameter of the tool, to describe this parameter to llm, we need this field + # query description, query will be seem as a parameter of the tool, + # to describe this parameter to llm, we need this field query_description = db.Column(db.Text, nullable=False) # query name, the name of the query parameter query_name = db.Column(db.String(40), nullable=False) @@ -67,35 +74,37 @@ class PublishedAppTool(db.Model): tool_name = db.Column(db.String(40), nullable=False) # author author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def description_i18n(self) -> I18nObject: return I18nObject(**json.loads(self.description)) - + @property def app(self) -> App: return db.session.query(App).filter(App.id == self.app_id).first() + class ApiToolProvider(db.Model): """ The table stores the api providers. 
""" - __tablename__ = 'tool_api_providers' + + __tablename__ = "tool_api_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_api_provider_pkey'), - db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider') + db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider name = db.Column(db.String(40), nullable=False) # icon icon = db.Column(db.String(255), nullable=False) # original schema schema = db.Column(db.Text, nullable=False) - schema_type_str = db.Column(db.String(40), nullable=False) + schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False) # who created this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -109,42 +118,44 @@ class ApiToolProvider(db.Model): # privacy policy privacy_policy = db.Column(db.String(255), nullable=True) # custom_disclaimer - custom_disclaimer = db.Column(db.String(255), nullable=True) + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) - + @property def tools(self) -> list[ApiToolBundle]: return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] - + @property def credentials(self) -> dict: return json.loads(self.credentials_str) - + @property - def user(self) -> Account: + def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() @property - def tenant(self) -> Tenant: + def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + class ToolLabelBinding(db.Model): """ The table stores the labels for tools. """ - __tablename__ = 'tool_label_bindings' + + __tablename__ = "tool_label_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_label_bind_pkey'), - db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'), + db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), + db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # tool id tool_id = db.Column(db.String(64), nullable=False) # tool type @@ -152,28 +163,30 @@ class ToolLabelBinding(db.Model): # label name label_name = db.Column(db.String(40), nullable=False) + class WorkflowToolProvider(db.Model): """ The table stores the workflow providers. 
""" - __tablename__ = 'tool_workflow_providers' + + __tablename__ = "tool_workflow_providers" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), - db.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), - db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'), + db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), + db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the workflow provider name = db.Column(db.String(40), nullable=False) # label of the workflow provider - label = db.Column(db.String(255), nullable=False, server_default='') + label = db.Column(db.String(255), nullable=False, server_default="") # icon icon = db.Column(db.String(255), nullable=False) # app id of the workflow provider app_id = db.Column(StringUUID, nullable=False) # version of the workflow provider - version = db.Column(db.String(255), nullable=False, server_default='') + version = db.Column(db.String(255), nullable=False, server_default="") # who created this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -181,46 +194,43 @@ class WorkflowToolProvider(db.Model): # description of the provider description = db.Column(db.Text, nullable=False) # parameter configuration - parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]') + parameter_configuration = db.Column(db.Text, nullable=False, server_default="[]") # privacy policy - privacy_policy = db.Column(db.String(255), nullable=True, server_default='') + privacy_policy = db.Column(db.String(255), nullable=True, server_default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def schema_type(self) -> ApiProviderSchemaType: return ApiProviderSchemaType.value_of(self.schema_type_str) - + @property - def user(self) -> Account: + def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() @property - def tenant(self) -> Tenant: + def tenant(self) -> Tenant | None: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() - + @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: - return [ - WorkflowToolParameterConfiguration(**config) - for config in json.loads(self.parameter_configuration) - ] - + return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] + @property - def app(self) -> App: + def app(self) -> App | None: return db.session.query(App).filter(App.id == self.app_id).first() + class ToolModelInvoke(db.Model): """ store the invoke logs from tool invoke """ + __tablename__ = "tool_model_invokes" - __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'), - ) + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) - id = db.Column(StringUUID, 
server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # who invoke this tool user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -238,29 +248,31 @@ class ToolModelInvoke(db.Model): # invoke response model_response = db.Column(db.Text, nullable=False) - prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) - answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) - provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + class ToolConversationVariables(db.Model): """ store the conversation variables from tool invoke """ + __tablename__ = "tool_conversation_variables" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey'), + db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), # add index for user_id and conversation_id - db.Index('user_id_idx', 'user_id'), - db.Index('conversation_id_idx', 'conversation_id'), + db.Index("user_id_idx", "user_id"), + db.Index("conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # conversation user id user_id = db.Column(StringUUID, nullable=False) # tenant id @@ -270,34 +282,48 @@ class ToolConversationVariables(db.Model): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def variables(self) -> dict: return json.loads(self.variables_str) - + + class ToolFile(db.Model): - """ - store the file created by agent - """ __tablename__ = "tool_files" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='tool_file_pkey'), - # add index for conversation_id - db.Index('tool_file_conversation_id_idx', 'conversation_id'), + db.PrimaryKeyConstraint("id", name="tool_file_pkey"), + db.Index("tool_file_conversation_id_idx", "conversation_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - # conversation user id - user_id = 
db.Column(StringUUID, nullable=False) - # tenant id - tenant_id = db.Column(StringUUID, nullable=False) - # conversation id - conversation_id = db.Column(StringUUID, nullable=True) - # file key - file_key = db.Column(db.String(255), nullable=False) - # mime type - mimetype = db.Column(db.String(255), nullable=False) - # original url - original_url = db.Column(db.String(2048), nullable=True) \ No newline at end of file + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + user_id: Mapped[str] = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) + file_key: Mapped[str] = db.Column(db.String(255), nullable=False) + mimetype: Mapped[str] = db.Column(db.String(255), nullable=False) + original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True) + name: Mapped[str] = mapped_column(default="") + size: Mapped[int] = mapped_column(default=-1) + + def __init__( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str] = None, + file_key: str, + mimetype: str, + original_url: Optional[str] = None, + name: str, + size: int, + ): + self.user_id = user_id + self.tenant_id = tenant_id + self.conversation_id = conversation_id + self.file_key = file_key + self.mimetype = mimetype + self.original_url = original_url + self.name = name + self.size = size diff --git a/api/models/types.py b/api/models/types.py index 1614ec20188541..cb6773e70cdd5f 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -9,13 +9,13 @@ class StringUUID(TypeDecorator): def process_bind_param(self, value, dialect): if value is None: return value - elif dialect.name == 'postgresql': + elif dialect.name == "postgresql": return str(value) else: return value.hex def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': + if dialect.name == "postgresql": return dialect.type_descriptor(UUID()) else: return dialect.type_descriptor(CHAR(36)) @@ -23,4 +23,4 @@ def load_dialect_impl(self, dialect): def process_result_value(self, value, dialect): if value is None: return value - return str(value) \ No newline at end of file + return str(value) diff --git a/api/models/web.py b/api/models/web.py index 0e901d5f842691..bc088c185d5a8b 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,4 +1,3 @@ - from extensions.ext_database import db from .model import Message @@ -6,18 +5,18 @@ class SavedMessage(db.Model): - __tablename__ = 'saved_messages' + __tablename__ = "saved_messages" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='saved_message_pkey'), - db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'), + db.PrimaryKeyConstraint("id", name="saved_message_pkey"), + db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property 
def message(self): @@ -25,15 +24,15 @@ def message(self): class PinnedConversation(db.Model): - __tablename__ = 'pinned_conversations' + __tablename__ = "pinned_conversations" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='pinned_conversation_pkey'), - db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'), + db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), + db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) conversation_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/workflow.py b/api/models/workflow.py index 759e07c7154e0d..4f0e9a5e03705f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,52 +1,36 @@ import json from collections.abc import Mapping, Sequence +from datetime import datetime, timezone from enum import Enum from typing import Any, Optional, Union +import sqlalchemy as sa from sqlalchemy import func -from sqlalchemy.orm import Mapped +from sqlalchemy.orm import Mapped, mapped_column import contexts from constants import HIDDEN_VALUE -from core.app.segments import SecretVariable, Variable, factory from core.helper import encrypter +from core.variables import SecretVariable, Variable from extensions.ext_database import db +from factories import variable_factory from libs import helper +from models.enums import CreatedByRole from .account import Account from .types import StringUUID -class CreatedByRole(Enum): - """ - Created By Role Enum - """ - ACCOUNT = 'account' - END_USER = 'end_user' - - @classmethod - def value_of(cls, value: str) -> 'CreatedByRole': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid created by role value {value}') - - class WorkflowType(Enum): """ Workflow Type Enum """ - WORKFLOW = 'workflow' - CHAT = 'chat' + + WORKFLOW = "workflow" + CHAT = "chat" @classmethod - def value_of(cls, value: str) -> 'WorkflowType': + def value_of(cls, value: str) -> "WorkflowType": """ Get value of given mode. @@ -56,10 +40,10 @@ def value_of(cls, value: str) -> 'WorkflowType': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow type value {value}') + raise ValueError(f"invalid workflow type value {value}") @classmethod - def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': + def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": """ Get workflow type from app mode. 
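# Reviewer note: a runnable sketch of the WorkflowType enum reformatted above.
# value_of is the hand-rolled string lookup this file repeats for each enum;
# the built-in constructor call WorkflowType("chat") behaves the same way,
# raising ValueError for unknown values, just with a different message.
from enum import Enum

class WorkflowType(Enum):
    WORKFLOW = "workflow"
    CHAT = "chat"

    @classmethod
    def value_of(cls, value: str) -> "WorkflowType":
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid workflow type value {value}")

assert WorkflowType.value_of("chat") is WorkflowType.CHAT
assert WorkflowType("chat") is WorkflowType.CHAT  # built-in equivalent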
@@ -67,6 +51,7 @@ def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': :return: workflow type """ from models.model import AppMode + app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT @@ -104,25 +89,56 @@ class Workflow(db.Model): - updated_at (timestamp) `optional` Last update time """ - __tablename__ = 'workflows' + __tablename__ = "workflows" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_pkey'), - db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), + db.PrimaryKeyConstraint("id", name="workflow_pkey"), + db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) - tenant_id = db.Column(StringUUID, nullable=False) - app_id = db.Column(StringUUID, nullable=False) - type = db.Column(db.String(255), nullable=False) - version = db.Column(db.String(255), nullable=False) - graph = db.Column(db.Text) - features = db.Column(db.Text) - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_by = db.Column(StringUUID) - updated_at = db.Column(db.DateTime) - _environment_variables = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') - _conversation_variables = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + type: Mapped[str] = mapped_column(db.String(255), nullable=False) + version: Mapped[str] = mapped_column(db.String(255), nullable=False) + graph: Mapped[str] = mapped_column(sa.Text) + _features: Mapped[str] = mapped_column("features", sa.TEXT) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") + ) + updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, default=datetime.now(tz=timezone.utc), server_onupdate=func.current_timestamp() + ) + _environment_variables: Mapped[str] = mapped_column( + "environment_variables", db.Text, nullable=False, server_default="{}" + ) + _conversation_variables: Mapped[str] = mapped_column( + "conversation_variables", db.Text, nullable=False, server_default="{}" + ) + + def __init__( + self, + *, + tenant_id: str, + app_id: str, + type: str, + version: str, + graph: str, + features: str, + created_by: str, + environment_variables: Sequence[Variable], + conversation_variables: Sequence[Variable], + ): + self.tenant_id = tenant_id + self.app_id = app_id + self.type = type + self.version = version + self.graph = graph + self.features = features + self.created_by = created_by + self.environment_variables = environment_variables or [] + self.conversation_variables = conversation_variables or [] @property def created_by_account(self): @@ -136,6 +152,34 @@ def updated_by_account(self): def graph_dict(self) -> Mapping[str, Any]: return json.loads(self.graph) if self.graph else {} + @property + def features(self) -> str: + """ + Convert old features structure to new features structure. 
+ """ + if not self._features: + return self._features + + features = json.loads(self._features) + if features.get("file_upload", {}).get("image", {}).get("enabled", False): + image_enabled = True + image_number_limits = int(features["file_upload"]["image"].get("number_limits", 1)) + image_transfer_methods = features["file_upload"]["image"].get( + "transfer_methods", ["remote_url", "local_file"] + ) + features["file_upload"]["enabled"] = image_enabled + features["file_upload"]["number_limits"] = image_number_limits + features["file_upload"]["allowed_upload_methods"] = image_transfer_methods + features["file_upload"]["allowed_file_types"] = ["image"] + features["file_upload"]["allowed_extensions"] = [] + del features["file_upload"]["image"] + self._features = json.dumps(features) + return self._features + + @features.setter + def features(self, value: str) -> None: + self._features = value + @property def features_dict(self) -> Mapping[str, Any]: return json.loads(self.features) if self.features else {} @@ -146,22 +190,20 @@ def user_input_form(self, to_old_structure: bool = False) -> list: return [] graph_dict = self.graph_dict - if 'nodes' not in graph_dict: + if "nodes" not in graph_dict: return [] - start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) + start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None) if not start_node: return [] # get user_input_form from start node - variables = start_node.get('data', {}).get('variables', []) + variables = start_node.get("data", {}).get("variables", []) if to_old_structure: old_structure_variables = [] for variable in variables: - old_structure_variables.append({ - variable['type']: variable - }) + old_structure_variables.append({variable["type"]: variable}) return old_structure_variables @@ -174,36 +216,33 @@ def unique_hash(self) -> str: :return: hash """ - entity = { - 'graph': self.graph_dict, - 'features': self.features_dict - } + entity = {"graph": self.graph_dict, "features": self.features_dict} return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) @property def tool_published(self) -> bool: from models.tools import WorkflowToolProvider - return db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.app_id == self.app_id - ).first() is not None + + return ( + db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.app_id == self.app_id).first() + is not None + ) @property def environment_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._environment_variables` when instance created. 
if self._environment_variables is None: - self._environment_variables = '{}' + self._environment_variables = "{}" tenant_id = contexts.tenant_id.get() environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) - results = [factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()] + results = [variable_factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()] # decrypt secret variables value decrypt_func = ( - lambda var: var.model_copy( - update={'value': encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)} - ) + lambda var: var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) if isinstance(var, SecretVariable) else var ) @@ -212,23 +251,26 @@ def environment_variables(self) -> Sequence[Variable]: @environment_variables.setter def environment_variables(self, value: Sequence[Variable]): + if not value: + self._environment_variables = "{}" + return + tenant_id = contexts.tenant_id.get() value = list(value) if any(var for var in value if not var.id): - raise ValueError('environment variable require a unique id') + raise ValueError("environment variable requires a unique id") - # Compare inputs and origin variables, if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). + # Compare inputs and origin variables, + # if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). origin_variables_dictionary = {var.id: var for var in self.environment_variables} for i, variable in enumerate(value): if variable.id in origin_variables_dictionary and variable.value == HIDDEN_VALUE: - value[i] = origin_variables_dictionary[variable.id].model_copy(update={'name': variable.name}) + value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) # encrypt secret variables value encrypt_func = ( - lambda var: var.model_copy( - update={'value': encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)} - ) + lambda var: var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) if isinstance(var, SecretVariable) else var ) @@ -242,15 +284,15 @@ def environment_variables(self, value: Sequence[Variable]): def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: environment_variables = list(self.environment_variables) environment_variables = [ - v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={'value': ''}) + v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""}) for v in environment_variables ] result = { - 'graph': self.graph_dict, - 'features': self.features_dict, - 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], - 'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables], + "graph": self.graph_dict, + "features": self.features_dict, + "environment_variables": [var.model_dump(mode="json") for var in environment_variables], + "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], } return result @@ -258,10 +300,10 @@ def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: def conversation_variables(self) -> Sequence[Variable]: # TODO: find some way to init `self._conversation_variables` when instance created.
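# Reviewer note: a sketch of the HIDDEN_VALUE merge in the environment_variables
# setter above. Secrets echoed back by the client arrive masked; when a value
# equals the sentinel, the stored variable's value is kept and only the name is
# taken from the input. Plain dicts stand in for the project's Variable models,
# and the sentinel literal below is an assumption for illustration.
HIDDEN_VALUE = "[__HIDDEN__]"  # assumption: illustrative sentinel

def merge_variables(incoming: list[dict], stored: list[dict]) -> list[dict]:
    stored_by_id = {var["id"]: var for var in stored}
    merged = []
    for var in incoming:
        if var["id"] in stored_by_id and var["value"] == HIDDEN_VALUE:
            # Keep the real stored value; accept only the renamed `name`.
            merged.append({**stored_by_id[var["id"]], "name": var["name"]})
        else:
            merged.append(var)
    return merged

stored = [{"id": "1", "name": "API_KEY", "value": "sk-real"}]
incoming = [{"id": "1", "name": "API_KEY_RENAMED", "value": HIDDEN_VALUE}]
print(merge_variables(incoming, stored))
# -> the real secret survives; only the name changes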
if self._conversation_variables is None: - self._conversation_variables = '{}' + self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) - results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] + results = [variable_factory.build_variable_from_mapping(v) for v in variables_dict.values()] return results @conversation_variables.setter @@ -272,38 +314,18 @@ def conversation_variables(self, value: Sequence[Variable]) -> None: ) -class WorkflowRunTriggeredFrom(Enum): - """ - Workflow Run Triggered From Enum - """ - DEBUGGING = 'debugging' - APP_RUN = 'app-run' - - @classmethod - def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid workflow run triggered from value {value}') - - class WorkflowRunStatus(Enum): """ Workflow Run Status Enum """ - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' - STOPPED = 'stopped' + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + STOPPED = "stopped" @classmethod - def value_of(cls, value: str) -> 'WorkflowRunStatus': + def value_of(cls, value: str) -> "WorkflowRunStatus": """ Get value of given mode. @@ -313,7 +335,7 @@ def value_of(cls, value: str) -> 'WorkflowRunStatus': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow run status value {value}') + raise ValueError(f"invalid workflow run status value {value}") class WorkflowRun(db.Model): @@ -354,14 +376,14 @@ class WorkflowRun(db.Model): - finished_at (timestamp) End time """ - __tablename__ = 'workflow_runs' + __tablename__ = "workflow_runs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), - db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), - db.Index('workflow_run_tenant_app_sequence_idx', 'tenant_id', 'app_id', 'sequence_number'), + db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), + db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), + db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) sequence_number = db.Column(db.Integer, nullable=False) @@ -372,48 +394,47 @@ class WorkflowRun(db.Model): graph = db.Column(db.Text) inputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) - outputs = db.Column(db.Text) + outputs: Mapped[str] = db.Column(db.Text) error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) - total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) - total_steps = db.Column(db.Integer, server_default=db.text('0')) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) + total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) + total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = 
db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime) @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property def graph_dict(self): - return json.loads(self.graph) if self.graph else None + return json.loads(self.graph) if self.graph else {} @property - def inputs_dict(self): - return json.loads(self.inputs) if self.inputs else None + def inputs_dict(self) -> Mapping[str, Any]: + return json.loads(self.inputs) if self.inputs else {} @property - def outputs_dict(self): - return json.loads(self.outputs) if self.outputs else None + def outputs_dict(self) -> Mapping[str, Any]: + return json.loads(self.outputs) if self.outputs else {} @property - def message(self) -> Optional['Message']: + def message(self) -> Optional["Message"]: from models.model import Message - return db.session.query(Message).filter( - Message.app_id == self.app_id, - Message.workflow_run_id == self.id - ).first() + + return ( + db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + ) @property def workflow(self): @@ -421,51 +442,51 @@ def workflow(self): def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'app_id': self.app_id, - 'sequence_number': self.sequence_number, - 'workflow_id': self.workflow_id, - 'type': self.type, - 'triggered_from': self.triggered_from, - 'version': self.version, - 'graph': self.graph_dict, - 'inputs': self.inputs_dict, - 'status': self.status, - 'outputs': self.outputs_dict, - 'error': self.error, - 'elapsed_time': self.elapsed_time, - 'total_tokens': self.total_tokens, - 'total_steps': self.total_steps, - 'created_by_role': self.created_by_role, - 'created_by': self.created_by, - 'created_at': self.created_at, - 'finished_at': self.finished_at, + "id": self.id, + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "sequence_number": self.sequence_number, + "workflow_id": self.workflow_id, + "type": self.type, + "triggered_from": self.triggered_from, + "version": self.version, + "graph": self.graph_dict, + "inputs": self.inputs_dict, + "status": self.status, + "outputs": self.outputs_dict, + "error": self.error, + "elapsed_time": self.elapsed_time, + "total_tokens": self.total_tokens, + "total_steps": self.total_steps, + "created_by_role": self.created_by_role, + "created_by": self.created_by, + "created_at": self.created_at, + "finished_at": self.finished_at, } @classmethod - def from_dict(cls, data: dict) -> 'WorkflowRun': + def from_dict(cls, data: dict) -> "WorkflowRun": return cls( - id=data.get('id'), - tenant_id=data.get('tenant_id'), - app_id=data.get('app_id'), - sequence_number=data.get('sequence_number'), - workflow_id=data.get('workflow_id'), - type=data.get('type'), - 
triggered_from=data.get('triggered_from'), - version=data.get('version'), - graph=json.dumps(data.get('graph')), - inputs=json.dumps(data.get('inputs')), - status=data.get('status'), - outputs=json.dumps(data.get('outputs')), - error=data.get('error'), - elapsed_time=data.get('elapsed_time'), - total_tokens=data.get('total_tokens'), - total_steps=data.get('total_steps'), - created_by_role=data.get('created_by_role'), - created_by=data.get('created_by'), - created_at=data.get('created_at'), - finished_at=data.get('finished_at'), + id=data.get("id"), + tenant_id=data.get("tenant_id"), + app_id=data.get("app_id"), + sequence_number=data.get("sequence_number"), + workflow_id=data.get("workflow_id"), + type=data.get("type"), + triggered_from=data.get("triggered_from"), + version=data.get("version"), + graph=json.dumps(data.get("graph")), + inputs=json.dumps(data.get("inputs")), + status=data.get("status"), + outputs=json.dumps(data.get("outputs")), + error=data.get("error"), + elapsed_time=data.get("elapsed_time"), + total_tokens=data.get("total_tokens"), + total_steps=data.get("total_steps"), + created_by_role=data.get("created_by_role"), + created_by=data.get("created_by"), + created_at=data.get("created_at"), + finished_at=data.get("finished_at"), ) @@ -473,11 +494,12 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): """ Workflow Node Execution Triggered From Enum """ - SINGLE_STEP = 'single-step' - WORKFLOW_RUN = 'workflow-run' + + SINGLE_STEP = "single-step" + WORKFLOW_RUN = "workflow-run" @classmethod - def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': + def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom": """ Get value of given mode. @@ -487,19 +509,20 @@ def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow node execution triggered from value {value}') + raise ValueError(f"invalid workflow node execution triggered from value {value}") class WorkflowNodeExecutionStatus(Enum): """ Workflow Node Execution Status Enum """ - RUNNING = 'running' - SUCCEEDED = 'succeeded' - FAILED = 'failed' + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" @classmethod - def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': + def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": """ Get value of given mode. 
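# Reviewer note: the created_by_account / created_by_end_user properties in this
# diff switch from CreatedByRole.value_of(...) to the plain constructor call
# CreatedByRole(...), now imported from models.enums. A minimal sketch showing
# the constructor form also yields member-or-ValueError, so it is a drop-in
# replacement for the deleted helper.
from enum import Enum

class CreatedByRole(Enum):  # stand-in mirroring models.enums.CreatedByRole
    ACCOUNT = "account"
    END_USER = "end_user"

assert CreatedByRole("account") is CreatedByRole.ACCOUNT
try:
    CreatedByRole("bot")
except ValueError as exc:
    print(exc)  # 'bot' is not a valid CreatedByRole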
@@ -509,7 +532,7 @@ def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow node execution status value {value}') + raise ValueError(f"invalid workflow node execution status value {value}") class WorkflowNodeExecution(db.Model): @@ -560,16 +583,31 @@ class WorkflowNodeExecution(db.Model): - finished_at (timestamp) End time """ - __tablename__ = 'workflow_node_executions' + __tablename__ = "workflow_node_executions" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey'), - db.Index('workflow_node_execution_workflow_run_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'workflow_run_id'), - db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', - 'triggered_from', 'node_id'), + db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), + db.Index( + "workflow_node_execution_workflow_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "workflow_run_id", + ), + db.Index( + "workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id" + ), + db.Index( + "workflow_node_execution_id_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_execution_id", + ), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False) @@ -577,6 +615,7 @@ class WorkflowNodeExecution(db.Model): workflow_run_id = db.Column(StringUUID) index = db.Column(db.Integer, nullable=False) predecessor_node_id = db.Column(db.String(255)) + node_execution_id = db.Column(db.String(255), nullable=True) node_id = db.Column(db.String(255), nullable=False) node_type = db.Column(db.String(255), nullable=False) title = db.Column(db.String(255), nullable=False) @@ -585,25 +624,24 @@ class WorkflowNodeExecution(db.Model): outputs = db.Column(db.Text) status = db.Column(db.String(255), nullable=False) error = db.Column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) + elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else None + + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER 
else None @property def inputs_dict(self): @@ -624,15 +662,17 @@ def execution_metadata_dict(self): @property def extras(self): from core.tools.tool_manager import ToolManager + extras = {} if self.execution_metadata_dict: - from core.workflow.entities.node_entities import NodeType - if self.node_type == NodeType.TOOL.value and 'tool_info' in self.execution_metadata_dict: - tool_info = self.execution_metadata_dict['tool_info'] - extras['icon'] = ToolManager.get_tool_icon( + from core.workflow.nodes import NodeType + + if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: + tool_info = self.execution_metadata_dict["tool_info"] + extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, - provider_type=tool_info['provider_type'], - provider_id=tool_info['provider_id'] + provider_type=tool_info["provider_type"], + provider_id=tool_info["provider_id"], ) return extras @@ -642,12 +682,13 @@ class WorkflowAppLogCreatedFrom(Enum): """ Workflow App Log Created From Enum """ - SERVICE_API = 'service-api' - WEB_APP = 'web-app' - INSTALLED_APP = 'installed-app' + + SERVICE_API = "service-api" + WEB_APP = "web-app" + INSTALLED_APP = "installed-app" @classmethod - def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': + def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": """ Get value of given mode. @@ -657,7 +698,7 @@ def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': for mode in cls: if mode.value == value: return mode - raise ValueError(f'invalid workflow app log created from value {value}') + raise ValueError(f"invalid workflow app log created from value {value}") class WorkflowAppLog(db.Model): @@ -689,13 +730,13 @@ class WorkflowAppLog(db.Model): - created_at (timestamp) Creation time """ - __tablename__ = 'workflow_app_logs' + __tablename__ = "workflow_app_logs" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='workflow_app_log_pkey'), - db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'), + db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), + db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) - id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False) @@ -703,7 +744,7 @@ class WorkflowAppLog(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @property def workflow_run(self): @@ -711,27 +752,28 @@ def workflow_run(self): @property def created_by_account(self): - created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(Account, self.created_by) \ - if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole.value_of(self.created_by_role) - return db.session.get(EndUser, self.created_by) \ - if created_by_role == CreatedByRole.END_USER else 
None + + created_by_role = CreatedByRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None class ConversationVariable(db.Model): - __tablename__ = 'workflow__conversation_variables' + __tablename__ = "workflow_conversation_variables" id: Mapped[str] = db.Column(StringUUID, primary_key=True) conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)')) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: self.id = id @@ -740,7 +782,7 @@ def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> self.data = data @classmethod - def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable': + def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": obj = cls( id=variable.id, app_id=app_id, @@ -751,4 +793,4 @@ def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) def to_variable(self) -> Variable: mapping = json.loads(self.data) - return factory.build_variable_from_mapping(mapping) + return variable_factory.build_variable_from_mapping(mapping) diff --git a/api/poetry.lock b/api/poetry.lock index 358f9f8510c724..259ede68980a40 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2,98 +2,113 @@ [[package]] name = "aiohappyeyeballs" -version = "2.3.4" +version = "2.4.3" description = "Happy Eyeballs for asyncio" optional = false -python-versions = "<4.0,>=3.8" +python-versions = ">=3.8" files = [ - {file = "aiohappyeyeballs-2.3.4-py3-none-any.whl", hash = "sha256:40a16ceffcf1fc9e142fd488123b2e218abc4188cf12ac20c67200e1579baa42"}, - {file = "aiohappyeyeballs-2.3.4.tar.gz", hash = "sha256:7e1ae8399c320a8adec76f6c919ed5ceae6edd4c3672f4d9eae2b27e37c80ff6"}, + {file = "aiohappyeyeballs-2.4.3-py3-none-any.whl", hash = "sha256:8a7a83727b2756f394ab2895ea0765a0a8c475e3c71e98d43d76f22b4b435572"}, + {file = "aiohappyeyeballs-2.4.3.tar.gz", hash = "sha256:75cf88a15106a5002a8eb1dab212525c00d1f4c0fa96e551c9fbe6f09a621586"}, ] [[package]] name = "aiohttp" -version = "3.10.1" +version = "3.10.5" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.8" files = [ - {file = "aiohttp-3.10.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:47b4c2412960e64d97258f40616efddaebcb34ff664c8a972119ed38fac2a62c"}, - {file = "aiohttp-3.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e7dbf637f87dd315fa1f36aaed8afa929ee2c607454fb7791e74c88a0d94da59"}, - {file = "aiohttp-3.10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c8fb76214b5b739ce59e2236a6489d9dc3483649cfd6f563dbf5d8e40dbdd57d"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c577cdcf8f92862363b3d598d971c6a84ed8f0bf824d4cc1ce70c2fb02acb4a"}, - {file = 
"aiohttp-3.10.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:777e23609899cb230ad2642b4bdf1008890f84968be78de29099a8a86f10b261"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b07286a1090483799599a2f72f76ac396993da31f6e08efedb59f40876c144fa"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b9db600a86414a9a653e3c1c7f6a2f6a1894ab8f83d11505247bd1b90ad57157"}, - {file = "aiohttp-3.10.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01c3f1eb280008e51965a8d160a108c333136f4a39d46f516c64d2aa2e6a53f2"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f5dd109a925fee4c9ac3f6a094900461a2712df41745f5d04782ebcbe6479ccb"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:8c81ff4afffef9b1186639506d70ea90888218f5ddfff03870e74ec80bb59970"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:2a384dfbe8bfebd203b778a30a712886d147c61943675f4719b56725a8bbe803"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:b9fb6508893dc31cfcbb8191ef35abd79751db1d6871b3e2caee83959b4d91eb"}, - {file = "aiohttp-3.10.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:88596384c3bec644a96ae46287bb646d6a23fa6014afe3799156aef42669c6bd"}, - {file = "aiohttp-3.10.1-cp310-cp310-win32.whl", hash = "sha256:68164d43c580c2e8bf8e0eb4960142919d304052ccab92be10250a3a33b53268"}, - {file = "aiohttp-3.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:d6bbe2c90c10382ca96df33b56e2060404a4f0f88673e1e84b44c8952517e5f3"}, - {file = "aiohttp-3.10.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f6979b4f20d3e557a867da9d9227de4c156fcdcb348a5848e3e6190fd7feb972"}, - {file = "aiohttp-3.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03c0c380c83f8a8d4416224aafb88d378376d6f4cadebb56b060688251055cd4"}, - {file = "aiohttp-3.10.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c2b104e81b3c3deba7e6f5bc1a9a0e9161c380530479970766a6655b8b77c7c"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b023b68c61ab0cd48bd38416b421464a62c381e32b9dc7b4bdfa2905807452a4"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a07c76a82390506ca0eabf57c0540cf5a60c993c442928fe4928472c4c6e5e6"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41d8dab8c64ded1edf117d2a64f353efa096c52b853ef461aebd49abae979f16"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:615348fab1a9ef7d0960a905e83ad39051ae9cb0d2837da739b5d3a7671e497a"}, - {file = "aiohttp-3.10.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:256ee6044214ee9d66d531bb374f065ee94e60667d6bbeaa25ca111fc3997158"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b7d5bb926805022508b7ddeaad957f1fce7a8d77532068d7bdb431056dc630cd"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:028faf71b338f069077af6315ad54281612705d68889f5d914318cbc2aab0d50"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:5c12310d153b27aa630750be44e79313acc4e864c421eb7d2bc6fa3429c41bf8"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = 
"sha256:de1a91d5faded9054957ed0a9e01b9d632109341942fc123947ced358c5d9009"}, - {file = "aiohttp-3.10.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9c186b270979fb1dee3ababe2d12fb243ed7da08b30abc83ebac3a928a4ddb15"}, - {file = "aiohttp-3.10.1-cp311-cp311-win32.whl", hash = "sha256:4a9ce70f5e00380377aac0e568abd075266ff992be2e271765f7b35d228a990c"}, - {file = "aiohttp-3.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:a77c79bac8d908d839d32c212aef2354d2246eb9deb3e2cb01ffa83fb7a6ea5d"}, - {file = "aiohttp-3.10.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:2212296cdb63b092e295c3e4b4b442e7b7eb41e8a30d0f53c16d5962efed395d"}, - {file = "aiohttp-3.10.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4dcb127ca3eb0a61205818a606393cbb60d93b7afb9accd2fd1e9081cc533144"}, - {file = "aiohttp-3.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cb8b79a65332e1a426ccb6290ce0409e1dc16b4daac1cc5761e059127fa3d134"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68cc24f707ed9cb961f6ee04020ca01de2c89b2811f3cf3361dc7c96a14bfbcc"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cb54f5725b4b37af12edf6c9e834df59258c82c15a244daa521a065fbb11717"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:51d03e948e53b3639ce4d438f3d1d8202898ec6655cadcc09ec99229d4adc2a9"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:786299d719eb5d868f161aeec56d589396b053925b7e0ce36e983d30d0a3e55c"}, - {file = "aiohttp-3.10.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abda4009a30d51d3f06f36bc7411a62b3e647fa6cc935ef667e3e3d3a7dd09b1"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:67f7639424c313125213954e93a6229d3a1d386855d70c292a12628f600c7150"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8e5a26d7aac4c0d8414a347da162696eea0629fdce939ada6aedf951abb1d745"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:120548d89f14b76a041088b582454d89389370632ee12bf39d919cc5c561d1ca"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:f5293726943bdcea24715b121d8c4ae12581441d22623b0e6ab12d07ce85f9c4"}, - {file = "aiohttp-3.10.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1f8605e573ed6c44ec689d94544b2c4bb1390aaa723a8b5a2cc0a5a485987a68"}, - {file = "aiohttp-3.10.1-cp312-cp312-win32.whl", hash = "sha256:e7168782621be4448d90169a60c8b37e9b0926b3b79b6097bc180c0a8a119e73"}, - {file = "aiohttp-3.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:8fbf8c0ded367c5c8eaf585f85ca8dd85ff4d5b73fb8fe1e6ac9e1b5e62e11f7"}, - {file = "aiohttp-3.10.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:54b7f4a20d7cc6bfa4438abbde069d417bb7a119f870975f78a2b99890226d55"}, - {file = "aiohttp-3.10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2fa643ca990323db68911b92f3f7a0ca9ae300ae340d0235de87c523601e58d9"}, - {file = "aiohttp-3.10.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d8311d0d690487359fe2247ec5d2cac9946e70d50dced8c01ce9e72341c21151"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222821c60b8f6a64c5908cb43d69c0ee978a1188f6a8433d4757d39231b42cdb"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:e7b55d9ede66af7feb6de87ff277e0ccf6d51c7db74cc39337fe3a0e31b5872d"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a95151a5567b3b00368e99e9c5334a919514f60888a6b6d2054fea5e66e527e"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e9e9171d2fe6bfd9d3838a6fe63b1e91b55e0bf726c16edf265536e4eafed19"}, - {file = "aiohttp-3.10.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a57e73f9523e980f6101dc9a83adcd7ac0006ea8bf7937ca3870391c7bb4f8ff"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0df51a3d70a2bfbb9c921619f68d6d02591f24f10e9c76de6f3388c89ed01de6"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:b0de63ff0307eac3961b4af74382d30220d4813f36b7aaaf57f063a1243b4214"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:8db9b749f589b5af8e4993623dbda6716b2b7a5fcb0fa2277bf3ce4b278c7059"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:6b14c19172eb53b63931d3e62a9749d6519f7c121149493e6eefca055fcdb352"}, - {file = "aiohttp-3.10.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:5cd57ad998e3038aa87c38fe85c99ed728001bf5dde8eca121cadee06ee3f637"}, - {file = "aiohttp-3.10.1-cp38-cp38-win32.whl", hash = "sha256:df31641e3f02b77eb3c5fb63c0508bee0fc067cf153da0e002ebbb0db0b6d91a"}, - {file = "aiohttp-3.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:93094eba50bc2ad4c40ff4997ead1fdcd41536116f2e7d6cfec9596a8ecb3615"}, - {file = "aiohttp-3.10.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:440954ddc6b77257e67170d57b1026aa9545275c33312357472504eef7b4cc0b"}, - {file = "aiohttp-3.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f9f8beed277488a52ee2b459b23c4135e54d6a819eaba2e120e57311015b58e9"}, - {file = "aiohttp-3.10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d8a8221a63602008550022aa3a4152ca357e1dde7ab3dd1da7e1925050b56863"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a702bd3663b5cbf3916e84bf332400d24cdb18399f0877ca6b313ce6c08bfb43"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1988b370536eb14f0ce7f3a4a5b422ab64c4e255b3f5d7752c5f583dc8c967fc"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7ccf1f0a304352c891d124ac1a9dea59b14b2abed1704aaa7689fc90ef9c5be1"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc3ea6ef2a83edad84bbdb5d96e22f587b67c68922cd7b6f9d8f24865e655bcf"}, - {file = "aiohttp-3.10.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89b47c125ab07f0831803b88aeb12b04c564d5f07a1c1a225d4eb4d2f26e8b5e"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:21778552ef3d44aac3278cc6f6d13a6423504fa5f09f2df34bfe489ed9ded7f5"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:bde0693073fd5e542e46ea100aa6c1a5d36282dbdbad85b1c3365d5421490a92"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:bf66149bb348d8e713f3a8e0b4f5b952094c2948c408e1cfef03b49e86745d60"}, - {file = "aiohttp-3.10.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:587237571a85716d6f71f60d103416c9df7d5acb55d96d3d3ced65f39bff9c0c"}, - {file = 
"aiohttp-3.10.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:bfe33cba6e127d0b5b417623c9aa621f0a69f304742acdca929a9fdab4593693"}, - {file = "aiohttp-3.10.1-cp39-cp39-win32.whl", hash = "sha256:9fbff00646cf8211b330690eb2fd64b23e1ce5b63a342436c1d1d6951d53d8dd"}, - {file = "aiohttp-3.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:5951c328f9ac42d7bce7a6ded535879bc9ae13032818d036749631fa27777905"}, - {file = "aiohttp-3.10.1.tar.gz", hash = "sha256:8b0d058e4e425d3b45e8ec70d49b402f4d6b21041e674798b1f91ba027c73f28"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:18a01eba2574fb9edd5f6e5fb25f66e6ce061da5dab5db75e13fe1558142e0a3"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94fac7c6e77ccb1ca91e9eb4cb0ac0270b9fb9b289738654120ba8cebb1189c6"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f1f1c75c395991ce9c94d3e4aa96e5c59c8356a15b1c9231e783865e2772699"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7acae3cf1a2a2361ec4c8e787eaaa86a94171d2417aae53c0cca6ca3118ff6"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:94c4381ffba9cc508b37d2e536b418d5ea9cfdc2848b9a7fea6aebad4ec6aac1"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c31ad0c0c507894e3eaa843415841995bf8de4d6b2d24c6e33099f4bc9fc0d4f"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0912b8a8fadeb32ff67a3ed44249448c20148397c1ed905d5dac185b4ca547bb"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d93400c18596b7dc4794d48a63fb361b01a0d8eb39f28800dc900c8fbdaca91"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3c5e0d764a5c9aa5a62d99728c56d455310bcc288a79cab10157b3af426f"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d742c36ed44f2798c8d3f4bc511f479b9ceef2b93f348671184139e7d708042c"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:814375093edae5f1cb31e3407997cf3eacefb9010f96df10d64829362ae2df69"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8224f98be68a84b19f48e0bdc14224b5a71339aff3a27df69989fa47d01296f3"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9a487ef090aea982d748b1b0d74fe7c3950b109df967630a20584f9a99c0683"}, + {file = "aiohttp-3.10.5-cp310-cp310-win32.whl", hash = "sha256:d9ef084e3dc690ad50137cc05831c52b6ca428096e6deb3c43e95827f531d5ef"}, + {file = "aiohttp-3.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:66bf9234e08fe561dccd62083bf67400bdbf1c67ba9efdc3dac03650e97c6088"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8c6a4e5e40156d72a40241a25cc226051c0a8d816610097a8e8f517aeacd59a2"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c634a3207a5445be65536d38c13791904fda0748b9eabf908d3fe86a52941cf"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4aff049b5e629ef9b3e9e617fa6e2dfeda1bf87e01bcfecaf3949af9e210105e"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1942244f00baaacaa8155eca94dbd9e8cc7017deb69b75ef67c78e89fdad3c77"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", 
hash = "sha256:e04a1f2a65ad2f93aa20f9ff9f1b672bf912413e5547f60749fa2ef8a644e061"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f2bfc0032a00405d4af2ba27f3c429e851d04fad1e5ceee4080a1c570476697"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:424ae21498790e12eb759040bbb504e5e280cab64693d14775c54269fd1d2bb7"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:975218eee0e6d24eb336d0328c768ebc5d617609affaca5dbbd6dd1984f16ed0"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4120d7fefa1e2d8fb6f650b11489710091788de554e2b6f8347c7a20ceb003f5"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b90078989ef3fc45cf9221d3859acd1108af7560c52397ff4ace8ad7052a132e"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ba5a8b74c2a8af7d862399cdedce1533642fa727def0b8c3e3e02fcb52dca1b1"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:02594361128f780eecc2a29939d9dfc870e17b45178a867bf61a11b2a4367277"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8fb4fc029e135859f533025bc82047334e24b0d489e75513144f25408ecaf058"}, + {file = "aiohttp-3.10.5-cp311-cp311-win32.whl", hash = "sha256:e1ca1ef5ba129718a8fc827b0867f6aa4e893c56eb00003b7367f8a733a9b072"}, + {file = "aiohttp-3.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:349ef8a73a7c5665cca65c88ab24abe75447e28aa3bc4c93ea5093474dfdf0ff"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:305be5ff2081fa1d283a76113b8df7a14c10d75602a38d9f012935df20731487"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3a1c32a19ee6bbde02f1cb189e13a71b321256cc1d431196a9f824050b160d5a"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:61645818edd40cc6f455b851277a21bf420ce347baa0b86eaa41d51ef58ba23d"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c225286f2b13bab5987425558baa5cbdb2bc925b2998038fa028245ef421e75"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ba01ebc6175e1e6b7275c907a3a36be48a2d487549b656aa90c8a910d9f3178"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8eaf44ccbc4e35762683078b72bf293f476561d8b68ec8a64f98cf32811c323e"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c43eb1ab7cbf411b8e387dc169acb31f0ca0d8c09ba63f9eac67829585b44f"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de7a5299827253023c55ea549444e058c0eb496931fa05d693b95140a947cb73"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4790f0e15f00058f7599dab2b206d3049d7ac464dc2e5eae0e93fa18aee9e7bf"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:44b324a6b8376a23e6ba25d368726ee3bc281e6ab306db80b5819999c737d820"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0d277cfb304118079e7044aad0b76685d30ecb86f83a0711fc5fb257ffe832ca"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:54d9ddea424cd19d3ff6128601a4a4d23d54a421f9b4c0fff740505813739a91"}, + {file 
= "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4f1c9866ccf48a6df2b06823e6ae80573529f2af3a0992ec4fe75b1a510df8a6"}, + {file = "aiohttp-3.10.5-cp312-cp312-win32.whl", hash = "sha256:dc4826823121783dccc0871e3f405417ac116055bf184ac04c36f98b75aacd12"}, + {file = "aiohttp-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:22c0a23a3b3138a6bf76fc553789cb1a703836da86b0f306b6f0dc1617398abc"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7f6b639c36734eaa80a6c152a238242bedcee9b953f23bb887e9102976343092"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f29930bc2921cef955ba39a3ff87d2c4398a0394ae217f41cb02d5c26c8b1b77"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f489a2c9e6455d87eabf907ac0b7d230a9786be43fbe884ad184ddf9e9c1e385"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:123dd5b16b75b2962d0fff566effb7a065e33cd4538c1692fb31c3bda2bfb972"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b98e698dc34966e5976e10bbca6d26d6724e6bdea853c7c10162a3235aba6e16"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3b9162bab7e42f21243effc822652dc5bb5e8ff42a4eb62fe7782bcbcdfacf6"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1923a5c44061bffd5eebeef58cecf68096e35003907d8201a4d0d6f6e387ccaa"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55f011da0a843c3d3df2c2cf4e537b8070a419f891c930245f05d329c4b0689"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:afe16a84498441d05e9189a15900640a2d2b5e76cf4efe8cbb088ab4f112ee57"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e72589da4c90337837fdfe2026ae1952c0f4a6e793adbbfbdd40efed7c63599"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4d46c7b4173415d8e583045fbc4daa48b40e31b19ce595b8d92cf639396c15d5"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:33e6bc4bab477c772a541f76cd91e11ccb6d2efa2b8d7d7883591dfb523e5987"}, + {file = "aiohttp-3.10.5-cp313-cp313-win32.whl", hash = "sha256:c58c6837a2c2a7cf3133983e64173aec11f9c2cd8e87ec2fdc16ce727bcf1a04"}, + {file = "aiohttp-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:38172a70005252b6893088c0f5e8a47d173df7cc2b2bd88650957eb84fcf5022"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f6f18898ace4bcd2d41a122916475344a87f1dfdec626ecde9ee802a711bc569"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5ede29d91a40ba22ac1b922ef510aab871652f6c88ef60b9dcdf773c6d32ad7a"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:673f988370f5954df96cc31fd99c7312a3af0a97f09e407399f61583f30da9bc"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58718e181c56a3c02d25b09d4115eb02aafe1a732ce5714ab70326d9776457c3"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b38b1570242fbab8d86a84128fb5b5234a2f70c2e32f3070143a6d94bc854cf"}, + {file = 
"aiohttp-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:074d1bff0163e107e97bd48cad9f928fa5a3eb4b9d33366137ffce08a63e37fe"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7384d0b87d4635ec38db9263e6a3f1eb609e2e06087f0aa7f63b76833737b471"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8989f46f3d7ef79585e98fa991e6ded55d2f48ae56d2c9fa5e491a6e4effb589"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c83f7a107abb89a227d6c454c613e7606c12a42b9a4ca9c5d7dad25d47c776ae"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cde98f323d6bf161041e7627a5fd763f9fd829bcfcd089804a5fdce7bb6e1b7d"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:676f94c5480d8eefd97c0c7e3953315e4d8c2b71f3b49539beb2aa676c58272f"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2d21ac12dc943c68135ff858c3a989f2194a709e6e10b4c8977d7fcd67dfd511"}, + {file = "aiohttp-3.10.5-cp38-cp38-win32.whl", hash = "sha256:17e997105bd1a260850272bfb50e2a328e029c941c2708170d9d978d5a30ad9a"}, + {file = "aiohttp-3.10.5-cp38-cp38-win_amd64.whl", hash = "sha256:1c19de68896747a2aa6257ae4cf6ef59d73917a36a35ee9d0a6f48cff0f94db8"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7e2fe37ac654032db1f3499fe56e77190282534810e2a8e833141a021faaab0e"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5bf3ead3cb66ab990ee2561373b009db5bc0e857549b6c9ba84b20bc462e172"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b2c16a919d936ca87a3c5f0e43af12a89a3ce7ccbce59a2d6784caba945b68b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad146dae5977c4dd435eb31373b3fe9b0b1bf26858c6fc452bf6af394067e10b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c5c6fa16412b35999320f5c9690c0f554392dc222c04e559217e0f9ae244b92"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95c4dc6f61d610bc0ee1edc6f29d993f10febfe5b76bb470b486d90bbece6b22"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da452c2c322e9ce0cfef392e469a26d63d42860f829026a63374fde6b5c5876f"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:898715cf566ec2869d5cb4d5fb4be408964704c46c96b4be267442d265390f32"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:391cc3a9c1527e424c6865e087897e766a917f15dddb360174a70467572ac6ce"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:380f926b51b92d02a34119d072f178d80bbda334d1a7e10fa22d467a66e494db"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce91db90dbf37bb6fa0997f26574107e1b9d5ff939315247b7e615baa8ec313b"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9093a81e18c45227eebe4c16124ebf3e0d893830c6aca7cc310bfca8fe59d857"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ee40b40aa753d844162dcc80d0fe256b87cba48ca0054f64e68000453caead11"}, 
+ {file = "aiohttp-3.10.5-cp39-cp39-win32.whl", hash = "sha256:03f2645adbe17f274444953bdea69f8327e9d278d961d85657cb0d06864814c1"}, + {file = "aiohttp-3.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:d17920f18e6ee090bdd3d0bfffd769d9f2cb4c8ffde3eb203777a3895c128862"}, + {file = "aiohttp-3.10.5.tar.gz", hash = "sha256:f071854b47d39591ce9a17981c46790acb30518e2f83dfca8db2dfa091178691"}, ] [package.dependencies] @@ -110,13 +125,13 @@ speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] [[package]] name = "aiohttp-retry" -version = "2.8.3" +version = "2.9.0" description = "Simple retry client for aiohttp" optional = false python-versions = ">=3.7" files = [ - {file = "aiohttp_retry-2.8.3-py3-none-any.whl", hash = "sha256:3aeeead8f6afe48272db93ced9440cf4eda8b6fd7ee2abb25357b7eb28525b45"}, - {file = "aiohttp_retry-2.8.3.tar.gz", hash = "sha256:9a8e637e31682ad36e1ff9f8bcba912fcfc7d7041722bc901a4b948da4d71ea9"}, + {file = "aiohttp_retry-2.9.0-py3-none-any.whl", hash = "sha256:7661af92471e9a96c69d9b8f32021360272073397e6a15bc44c1726b12f46056"}, + {file = "aiohttp_retry-2.9.0.tar.gz", hash = "sha256:92c47f1580040208bac95d9a8389a87227ef22758530f2e3f4683395e42c41b5"}, ] [package.dependencies] @@ -138,13 +153,13 @@ frozenlist = ">=1.1.0" [[package]] name = "alembic" -version = "1.13.2" +version = "1.13.3" description = "A database migration tool for SQLAlchemy." optional = false python-versions = ">=3.8" files = [ - {file = "alembic-1.13.2-py3-none-any.whl", hash = "sha256:6b8733129a6224a9a711e17c99b08462dbf7cc9670ba8f2e2ae9af860ceb1953"}, - {file = "alembic-1.13.2.tar.gz", hash = "sha256:1ff0ae32975f4fd96028c39ed9bb3c867fe3af956bd7bb37343b54c9fe7445ef"}, + {file = "alembic-1.13.3-py3-none-any.whl", hash = "sha256:908e905976d15235fae59c9ac42c4c5b75cfcefe3d27c0fbf7ae15a37715d80e"}, + {file = "alembic-1.13.3.tar.gz", hash = "sha256:203503117415561e203aa14541740643a611f641517f0209fcae63e9fa09f1a2"}, ] [package.dependencies] @@ -157,12 +172,12 @@ tz = ["backports.zoneinfo"] [[package]] name = "alibabacloud-credentials" -version = "0.3.5" +version = "0.3.6" description = "The alibabacloud credentials module of alibabaCloud Python SDK." optional = false python-versions = ">=3.6" files = [ - {file = "alibabacloud_credentials-0.3.5.tar.gz", hash = "sha256:ad065ec95921eaf51939195485d0e5cc9e0ea050282059c7d8bf74bdb5496177"}, + {file = "alibabacloud_credentials-0.3.6.tar.gz", hash = "sha256:caa82cf258648dcbe1ca14aeba50ba21845567d6ac3cd48d318e0a445fff7f96"}, ] [package.dependencies] @@ -278,13 +293,13 @@ alibabacloud-tea = "*" [[package]] name = "alibabacloud-tea" -version = "0.3.9" +version = "0.4.0" description = "The tea module of alibabaCloud Python SDK." 
optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "alibabacloud-tea-0.3.9.tar.gz", hash = "sha256:a9689770003fa9313d1995812f9fe36a2be315e5cdfc8d58de0d96808219ced9"}, - {file = "alibabacloud_tea-0.3.9-py3-none-any.whl", hash = "sha256:402fd2a92e6729f228d8c0300b182f80019edce19d83afa497aeb15fd7947f9a"}, + {file = "alibabacloud-tea-0.4.0.tar.gz", hash = "sha256:bdf72d747723bab190331b3c8593109fe2807504469bc0147f78c8c4945ed396"}, + {file = "alibabacloud_tea-0.4.0-py3-none-any.whl", hash = "sha256:59fae5765e6654f884e130233df6fb61ca0fbe01a29ed0755a1cf099a3d4d863"}, ] [package.dependencies] @@ -306,17 +321,17 @@ alibabacloud-tea = ">=0.0.1" [[package]] name = "alibabacloud-tea-openapi" -version = "0.3.11" +version = "0.3.12" description = "Alibaba Cloud openapi SDK Library for Python" optional = false python-versions = ">=3.6" files = [ - {file = "alibabacloud_tea_openapi-0.3.11.tar.gz", hash = "sha256:3f5cace1b1aeb8a64587574097403cfd066b86ee4c3c9abde587f9abfcad38de"}, + {file = "alibabacloud_tea_openapi-0.3.12.tar.gz", hash = "sha256:2e14809f357438e62c1ef4976a7655110dd54a75bbfa7d905fa3798355cfd974"}, ] [package.dependencies] -alibabacloud_credentials = ">=0.3.1,<1.0.0" -alibabacloud_gateway_spi = ">=0.0.1,<1.0.0" +alibabacloud_credentials = ">=0.3.5,<1.0.0" +alibabacloud_gateway_spi = ">=0.0.2,<1.0.0" alibabacloud_openapi_util = ">=0.2.1,<1.0.0" alibabacloud_tea_util = ">=0.3.13,<1.0.0" alibabacloud_tea_xml = ">=0.0.2,<1.0.0" @@ -349,27 +364,27 @@ alibabacloud-tea = ">=0.0.1" [[package]] name = "aliyun-python-sdk-core" -version = "2.15.1" +version = "2.16.0" description = "The core module of Aliyun Python SDK." optional = false -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "aliyun-python-sdk-core-2.15.1.tar.gz", hash = "sha256:518550d07f537cd3afac3b6c93b5c997ce3440e4d0c054e3acbdaa8261e90adf"}, + {file = "aliyun-python-sdk-core-2.16.0.tar.gz", hash = "sha256:651caad597eb39d4fad6cf85133dffe92837d53bdf62db9d8f37dab6508bb8f9"}, ] [package.dependencies] -cryptography = ">=2.6.0" +cryptography = ">=3.0.0" jmespath = ">=0.9.3,<1.0.0" [[package]] name = "aliyun-python-sdk-kms" -version = "2.16.3" +version = "2.16.5" description = "The kms module of Aliyun Python sdk." 
optional = false python-versions = "*" files = [ - {file = "aliyun-python-sdk-kms-2.16.3.tar.gz", hash = "sha256:c31b7d24e153271a3043e801e7b6b6b3f0db47e95a83c8d10cdab8c11662fc39"}, - {file = "aliyun_python_sdk_kms-2.16.3-py2.py3-none-any.whl", hash = "sha256:8bb8c293be94e0cc9114a5286a503d2ec215eaf8a1fb51de5d6c8bcac209d4a1"}, + {file = "aliyun-python-sdk-kms-2.16.5.tar.gz", hash = "sha256:f328a8a19d83ecbb965ffce0ec1e9930755216d104638cd95ecd362753b813b3"}, + {file = "aliyun_python_sdk_kms-2.16.5-py2.py3-none-any.whl", hash = "sha256:24b6cdc4fd161d2942619479c8d050c63ea9cd22b044fe33b60bbb60153786f0"}, ] [package.dependencies] @@ -440,13 +455,13 @@ vertex = ["google-auth (>=2,<3)"] [[package]] name = "anyio" -version = "4.4.0" +version = "4.6.2.post1" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"}, - {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, + {file = "anyio-4.6.2.post1-py3-none-any.whl", hash = "sha256:6d170c36fba3bdd840c73d3868c1e777e33676a69c3a72cf0a0d5d6d8009b61d"}, + {file = "anyio-4.6.2.post1.tar.gz", hash = "sha256:4c8bc31ccdb51c7f7bd251f51c609e038d63e34219b44aa86e47576389880b4c"}, ] [package.dependencies] @@ -456,9 +471,9 @@ sniffio = ">=1.1" typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (>=0.23)"] +doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] +trio = ["trio (>=0.26.1)"] [[package]] name = "arxiv" @@ -505,22 +520,22 @@ files = [ [[package]] name = "attrs" -version = "24.2.0" +version = "23.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, - {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, ] [package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier 
(<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] +dev = ["attrs[tests]", "pre-commit"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] +tests = ["attrs[tests-no-zope]", "zope-interface"] +tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] +tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] [[package]] name = "authlib" @@ -536,15 +551,81 @@ files = [ [package.dependencies] cryptography = "*" +[[package]] +name = "azure-ai-inference" +version = "1.0.0b5" +description = "Microsoft Azure Ai Inference Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure_ai_inference-1.0.0b5-py3-none-any.whl", hash = "sha256:0147653088033f1fd059d5f4bd0fedac82529fdcc7a0d2183d9508b3f80cf549"}, + {file = "azure_ai_inference-1.0.0b5.tar.gz", hash = "sha256:c95b490bcd670ccdeb1048dc2b45e0f8252a4d69a348ca15d4510d327b64dd0d"}, +] + +[package.dependencies] +azure-core = ">=1.30.0" +isodate = ">=0.6.1" +typing-extensions = ">=4.6.0" + +[package.extras] +opentelemetry = ["azure-core-tracing-opentelemetry"] + +[[package]] +name = "azure-ai-ml" +version = "1.20.0" +description = "Microsoft Azure Machine Learning Client Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-ai-ml-1.20.0.tar.gz", hash = "sha256:6432a0da1b7250cb0db5a1c33202e0419935e19ea32d4c2b3220705f8f1d4101"}, + {file = "azure_ai_ml-1.20.0-py3-none-any.whl", hash = "sha256:c7eb3c5ccf82a6ee94403c3e5060763decd38cf03ff2620a4a6577526e605104"}, +] + +[package.dependencies] +azure-common = ">=1.1" +azure-core = ">=1.23.0" +azure-mgmt-core = ">=1.3.0" +azure-storage-blob = ">=12.10.0" +azure-storage-file-datalake = ">=12.2.0" +azure-storage-file-share = "*" +colorama = "*" +isodate = "*" +jsonschema = ">=4.0.0" +marshmallow = ">=3.5" +msrest = ">=0.6.18" +opencensus-ext-azure = "*" +opencensus-ext-logging = "*" +pydash = ">=6.0.0" +pyjwt = "*" +pyyaml = ">=5.1.0" +strictyaml = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +designer = ["mldesigner"] +mount = ["azureml-dataprep-rslex (>=2.22.0)"] + +[[package]] +name = "azure-common" +version = "1.1.28" +description = "Microsoft Azure Client Library for Python (Common)" +optional = false +python-versions = "*" +files = [ + {file = "azure-common-1.1.28.zip", hash = "sha256:4ac0cd3214e36b6a1b6a442686722a5d8cc449603aa833f3f0f40bda836704a3"}, + {file = "azure_common-1.1.28-py2.py3-none-any.whl", hash = "sha256:5c12d3dcf4ec20599ca6b0d3e09e86e146353d443e7fcc050c9a19c1f9df20ad"}, +] + [[package]] name = "azure-core" -version = "1.30.2" +version = "1.31.0" description = "Microsoft Azure Core Library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "azure-core-1.30.2.tar.gz", hash = "sha256:a14dc210efcd608821aa472d9fb8e8d035d29b68993819147bc290a8ac224472"}, - {file = "azure_core-1.30.2-py3-none-any.whl", hash = "sha256:cf019c1ca832e96274ae85abd3d9f752397194d9fea3b41487290562ac8abe4a"}, + {file = "azure_core-1.31.0-py3-none-any.whl", hash = "sha256:22954de3777e0250029360ef31d80448ef1be13b80a459bff80ba7073379e2cd"}, + {file = "azure_core-1.31.0.tar.gz", hash = "sha256:656a0dd61e1869b1506b7c6a3b31d62f15984b1a573d6326f6aa2f3e4123284b"}, ] [package.dependencies] 
@@ -572,6 +653,20 @@ cryptography = ">=2.5" msal = ">=1.24.0" msal-extensions = ">=0.3.0" +[[package]] +name = "azure-mgmt-core" +version = "1.4.0" +description = "Microsoft Azure Management Core Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "azure-mgmt-core-1.4.0.zip", hash = "sha256:d195208340094f98e5a6661b781cde6f6a051e79ce317caabd8ff97030a9b3ae"}, + {file = "azure_mgmt_core-1.4.0-py3-none-any.whl", hash = "sha256:81071675f186a585555ef01816f2774d49c1c9024cb76e5720c3c0f6b337bb7d"}, +] + +[package.dependencies] +azure-core = ">=1.26.2,<2.0.0" + [[package]] name = "azure-storage-blob" version = "12.13.0" @@ -588,6 +683,42 @@ azure-core = ">=1.23.1,<2.0.0" cryptography = ">=2.1.4" msrest = ">=0.6.21" +[[package]] +name = "azure-storage-file-datalake" +version = "12.8.0" +description = "Microsoft Azure File DataLake Storage Client Library for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "azure-storage-file-datalake-12.8.0.zip", hash = "sha256:12e6306e5efb5ca28e0ccd9fa79a2c61acd589866d6109fe5601b18509da92f4"}, + {file = "azure_storage_file_datalake-12.8.0-py3-none-any.whl", hash = "sha256:b6cf5733fe794bf3c866efbe3ce1941409e35b6b125028ac558b436bf90f2de7"}, +] + +[package.dependencies] +azure-core = ">=1.23.1,<2.0.0" +azure-storage-blob = ">=12.13.0,<13.0.0" +msrest = ">=0.6.21" + +[[package]] +name = "azure-storage-file-share" +version = "12.19.0" +description = "Microsoft Azure Azure File Share Storage Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure_storage_file_share-12.19.0-py3-none-any.whl", hash = "sha256:eac6cf1a454aba58af4e6ba450b36d16aa1d0c49679fb64ea8756bb896698c5b"}, + {file = "azure_storage_file_share-12.19.0.tar.gz", hash = "sha256:ea7a4174dc6c52f50ac8c30f228159fcc3675d1f8ba771b8d0efcbc310740278"}, +] + +[package.dependencies] +azure-core = ">=1.30.0" +cryptography = ">=2.1.4" +isodate = ">=0.6.1" +typing-extensions = ">=4.6.0" + +[package.extras] +aio = ["azure-core[aio] (>=1.30.0)"] + [[package]] name = "backoff" version = "2.2.1" @@ -599,6 +730,22 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "bce-python-sdk" +version = "0.9.23" +description = "BCE SDK for python" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7" +files = [ + {file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"}, + {file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"}, +] + +[package.dependencies] +future = ">=0.6.0" +pycryptodome = ">=3.8.0" +six = ">=1.4.0" + [[package]] name = "bcrypt" version = "4.2.0" @@ -659,13 +806,13 @@ lxml = ["lxml"] [[package]] name = "billiard" -version = "4.2.0" +version = "4.2.1" description = "Python multiprocessing fork with improvements and bugfixes" optional = false python-versions = ">=3.7" files = [ - {file = "billiard-4.2.0-py3-none-any.whl", hash = "sha256:07aa978b308f334ff8282bd4a746e681b3513db5c9a514cbdd810cbbdc19714d"}, - {file = "billiard-4.2.0.tar.gz", hash = "sha256:9a3c3184cb275aa17a732f93f65b20c525d3d9f253722d26a82194803ade5a2c"}, + {file = "billiard-4.2.1-py3-none-any.whl", hash = "sha256:40b59a4ac8806ba2c2369ea98d876bc6108b051c227baffd928c644d15d8f3cb"}, + {file = "billiard-4.2.1.tar.gz", hash = "sha256:12b641b0c539073fc8d3f5b8b7be998956665c4233c7c1fcd66a7e677c4fb36f"}, ] 
[[package]] @@ -681,17 +828,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.148" +version = "1.35.17" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.148-py3-none-any.whl", hash = "sha256:d63d36e5a34533ba69188d56f96da132730d5e9932c4e11c02d79319cd1afcec"}, - {file = "boto3-1.34.148.tar.gz", hash = "sha256:2058397f0a92c301e3116e9e65fbbc70ea49270c250882d65043d19b7c6e2d17"}, + {file = "boto3-1.35.17-py3-none-any.whl", hash = "sha256:67268aa6c4043e9fdeb4ab3c1e9032f44a6fa168c789af5e351f63f1f8880a2f"}, + {file = "boto3-1.35.17.tar.gz", hash = "sha256:4a32db8793569ee5f13c5bf3efb260193353cb8946bf6426e3c330b61c68e59d"}, ] [package.dependencies] -botocore = ">=1.34.148,<1.35.0" +botocore = ">=1.35.17,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -700,13 +847,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.155" +version = "1.35.52" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.155-py3-none-any.whl", hash = "sha256:f2696c11bb0cad627d42512937befd2e3f966aedd15de00d90ee13cf7a16b328"}, - {file = "botocore-1.34.155.tar.gz", hash = "sha256:3aa88abfef23909f68d3e6679a3d4b4bb3c6288a6cfbf9e253aa68dac8edad64"}, + {file = "botocore-1.35.52-py3-none-any.whl", hash = "sha256:cdbb5e43c9c3a977763e2a10d3b8b9c405d51279f9fcfd4ca4800763b22acba5"}, + {file = "botocore-1.35.52.tar.gz", hash = "sha256:1fe7485ea13d638b089103addd818c12984ff1e4d208de15f180b1e25ad944c5"}, ] [package.dependencies] @@ -715,57 +862,51 @@ python-dateutil = ">=2.1,<3.0.0" urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""} [package.extras] -crt = ["awscrt (==0.21.2)"] +crt = ["awscrt (==0.22.0)"] [[package]] name = "bottleneck" -version = "1.4.0" +version = "1.4.2" description = "Fast NumPy array functions written in C" optional = false -python-versions = "*" +python-versions = ">=3.9" files = [ - {file = "Bottleneck-1.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2110af22aa8c2779faba8aa021d6b559df04449bdf21d510eacd7910934189fe"}, - {file = "Bottleneck-1.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:381cbd1e52338fcdf9ff01c962e6aa187b2d8b3b369d42e779b6d33ac61f8d35"}, - {file = "Bottleneck-1.4.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a91e40bbb8452e77772614d882be2c34b3b514d9f15460f703293525a6e173d"}, - {file = "Bottleneck-1.4.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:59604949aea476f5075b965129eaa3c2d90891fd43b0dfaf2ad7621bb5db14a5"}, - {file = "Bottleneck-1.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c2c92545e1bc8e859d8d137aefa3b24843bd374b17c9814dafa3bbcea9fc4ec0"}, - {file = "Bottleneck-1.4.0-cp310-cp310-win32.whl", hash = "sha256:f63e79bfa2f82a7432c8b147ed321d01ca7769bc17cc04644286a4ce58d30549"}, - {file = "Bottleneck-1.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:d69907d8d679cb5091a3f479c46bf1076f149f6311ff3298bac5089b86a2fab1"}, - {file = "Bottleneck-1.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:67347b0f01f32a232a6269c37afc1c079e08f6455fa12e91f4a1cd12eb0d11a5"}, - {file = "Bottleneck-1.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1490348b3bbc0225523dc2c00c6bb3e66168c537d62797bd29783c0826c09838"}, - {file = 
"Bottleneck-1.4.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a704165552496cbcc8bcc5921bb679fd6fa66bb1e758888de091b1223231c9f0"}, - {file = "Bottleneck-1.4.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:ffb4e4edf7997069719b9269926cc00a2a12c6e015422d1ebc2f621c4541396a"}, - {file = "Bottleneck-1.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5d6bf45ed58d5e7414c0011ef2da75474fe597a51970df83596b0bcb79c14c5e"}, - {file = "Bottleneck-1.4.0-cp311-cp311-win32.whl", hash = "sha256:ed209f8f3cb9954773764b0fa2510a7a9247ad245593187ac90bd0747771bc5c"}, - {file = "Bottleneck-1.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:d53f1a72b12cfd76b56934c33bc0cb7c1a295f23a2d3ffba8c764514c9b5e0ff"}, - {file = "Bottleneck-1.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e720ff24370324c84a82b1a18195274715c23181748b2b9e3dacad24198ca06f"}, - {file = "Bottleneck-1.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:44305c70c2a1539b0ae968e033f301ad868a6146b47e3cccd73fdfe3fc07c4ee"}, - {file = "Bottleneck-1.4.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b4dac5d2a871b7bd296c2b92426daa27d5b07aa84ef2557db097d29135da4eb"}, - {file = "Bottleneck-1.4.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fbcdd01db9e27741fb16a02b720cf02389d4b0b99cefe3c834c7df88c2d7412d"}, - {file = "Bottleneck-1.4.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:14b3334a39308fbb05dacd35ac100842aa9e9bc70afbdcebe43e46179d183fd0"}, - {file = "Bottleneck-1.4.0-cp312-cp312-win32.whl", hash = "sha256:520d7a83cd48b3f58e5df1a258acb547f8a5386a8c21ca9e1058d83a0d622fdf"}, - {file = "Bottleneck-1.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b1339b9ad3ee217253f246cde5c3789eb527cf9dd31ff0a1f5a8bf7fc89eadad"}, - {file = "Bottleneck-1.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2749602200aaa0e12a0f3f936dd6d4035384ad10d3acf7ac4f418c501683397"}, - {file = "Bottleneck-1.4.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb79a2ac135567694f13339f0bebcee96aec09c596b324b61cd7fd5e306f49d"}, - {file = "Bottleneck-1.4.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c6097bf39723e76ff5bba160daab92ae599df212c859db8d46648548584d04a8"}, - {file = "Bottleneck-1.4.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b5f72b66ccc0272de46b67346cf8490737ba2adc6a302664f5326e7741b6d5ab"}, - {file = "Bottleneck-1.4.0-cp37-cp37m-win32.whl", hash = "sha256:9903f017b9d6f2f69ce241b424ddad7265624f64dc6eafbe257d45661febf8bd"}, - {file = "Bottleneck-1.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:834816c316ad184cae7ecb615b69876a42cd2cafb07ee66c57a9c1ccacb63339"}, - {file = "Bottleneck-1.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:03c43150f180d86a5633a6da788660d335983f6798fca306ba7f47ff27a1b7e7"}, - {file = "Bottleneck-1.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea333dbcadb780356c54f5c4fa7754f143573b57508fff43d5daf63298eb26a"}, - {file = "Bottleneck-1.4.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6179791c0119aec3708ef74ddadab8d183e3742adb93a9028718e8696bdf572b"}, - {file = "Bottleneck-1.4.0-cp38-cp38-musllinux_1_1_i686.whl", hash = 
"sha256:220b72405f77aebb0137b733b464c2526ded471e4289ac1e840bab8852759a55"}, - {file = "Bottleneck-1.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8746f0f727997ce4c7457dc1fec4e4e3c0fdd8803514baa3d1c4ea6515ab04b2"}, - {file = "Bottleneck-1.4.0-cp38-cp38-win32.whl", hash = "sha256:6a36280ee33d9db799163f04e88b950261e590cc71d089f5e179b21680b5d491"}, - {file = "Bottleneck-1.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:de17e012694e6a987bb4eb050dd7f0cf939195a8e00cb23aa93ebee5fd5e64a8"}, - {file = "Bottleneck-1.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28260197ab8a4a6b7adf810523147b1a3e85607f4e26a0f685eb9d155cfc75af"}, - {file = "Bottleneck-1.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:90d5d188a0cca0b9655ff2904ee61e7f183079e97550be98c2541a2eec358a72"}, - {file = "Bottleneck-1.4.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2861ff645d236f1a6f5c6d1ddb3db37d19af1d91057bdc4fd7b76299a15b3079"}, - {file = "Bottleneck-1.4.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6136ce7dcf825c432a20b80ab1c460264a437d8430fff32536176147e0b6b832"}, - {file = "Bottleneck-1.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:889e6855b77345622b4ba927335d3118745d590492941f5f78554f157d259e92"}, - {file = "Bottleneck-1.4.0-cp39-cp39-win32.whl", hash = "sha256:817aa43a671ede696ea023d8f35839a391244662340cc95a0f46965dda8b35cf"}, - {file = "Bottleneck-1.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:23834d82177d6997f21fa63156550668cd07a9a6e5a1b66ea80f1a14ac6ffd07"}, - {file = "bottleneck-1.4.0.tar.gz", hash = "sha256:beb36df519b8709e7d357c0c9639b03b885ca6355bbf5e53752c685de51605b8"}, + {file = "Bottleneck-1.4.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:125436df93751a226eab1732783aa8f6125e88e779587aa61be071fb66e41f9d"}, + {file = "Bottleneck-1.4.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c6df9a60ec6ab88fec934ca864266ba95edd89c490af71dc9cd8afb2a54ebd9"}, + {file = "Bottleneck-1.4.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e2fe327dc2d0564e295a5857a252755103f8c6e05b07d3ff80a69afaa9f5065"}, + {file = "Bottleneck-1.4.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:6b7790ca8658cd69e3cc0d0e4ff0e9829d60849bf7945fbd7344fbce05b2bbb8"}, + {file = "Bottleneck-1.4.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6282fa925ac3768f66e3547f89a512376d3f9de7ef53bdd37aa29232fd864054"}, + {file = "Bottleneck-1.4.2-cp310-cp310-win32.whl", hash = "sha256:e56a206fbf48e3b8054a964398bf1ed843e9625d3c6bdbeb7898cb48bf97441b"}, + {file = "Bottleneck-1.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:eb0c611d15b0fd8f511d288e8964e4725b4b3b0d9d310880cf0ff6b8dd03c859"}, + {file = "Bottleneck-1.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b6902ebf3e85315b481bc084f10c5770f8240275ad1e039ac69c7c8d2013b040"}, + {file = "Bottleneck-1.4.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c2fd34b9b490204f95288f0dd35d37042486a95029617246c88c0f94a0ab49fe"}, + {file = "Bottleneck-1.4.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:122845e3106c85465551d4a9a3777841347cfedfbebb3aa985cca110e07030b1"}, + {file = "Bottleneck-1.4.2-cp311-cp311-musllinux_1_2_i686.whl", hash = 
"sha256:1f61658ebdf5a178298544336b65020730bf86cc092dab5f6579a99a86bd888b"}, + {file = "Bottleneck-1.4.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:7c7d29c044a3511b36fd744503c3e697e279c273a8477a6d91a2831d04fd19e0"}, + {file = "Bottleneck-1.4.2-cp311-cp311-win32.whl", hash = "sha256:c663cbba8f52011fd82ee08c6a85c93b34b19e0e7ebba322d2d67809f34e0597"}, + {file = "Bottleneck-1.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:89651ef18c06616850203bf8875c958c5d316ea48d8ba60d9b450199d39ae391"}, + {file = "Bottleneck-1.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a74ddd0417f42eeaba37375f0fc065b28451e0fba45cb2f99e88880b10b3fa43"}, + {file = "Bottleneck-1.4.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:070d22f2f62ab81297380a89492cca931e4d9443fa4b84c2baeb52db09c3b1b4"}, + {file = "Bottleneck-1.4.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1fc4e7645bd425c05e05acd5541e9e09cb4179e71164e862f082561bf4509eac"}, + {file = "Bottleneck-1.4.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:037315c56605128a39f77d19af6a6019dc8c21a63694a4bfef3c026ed963be2e"}, + {file = "Bottleneck-1.4.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:99778329331d5fae8df19772a019e8b73ba4d9d1650f110cd995ab7657114db0"}, + {file = "Bottleneck-1.4.2-cp312-cp312-win32.whl", hash = "sha256:7363b3c8ce6ca433779cd7e96bcb94c0e516dcacadff0011adcbf0b3ac86bc9d"}, + {file = "Bottleneck-1.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:48c6b9d9287c4102b803fcb01ae66ae7ef6b310b711b4b7b7e23bf952894dc05"}, + {file = "Bottleneck-1.4.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c1c885ad02a6a8fa1f7ee9099f29b9d4c03eb1da2c7ab25839482d5cce739021"}, + {file = "Bottleneck-1.4.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7a1b023de1de3d84b18826462718fba548fed41870df44354f9ab6a414ea82f"}, + {file = "Bottleneck-1.4.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c9dbaf737b605b30c81611f2c1d197c2fd2e46c33f605876c1d332d3360c4fc"}, + {file = "Bottleneck-1.4.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7ebbcbe5d4062e37507b9a81e2aacdb1fcccc6193f7feff124ef2b5a6a5eb740"}, + {file = "Bottleneck-1.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:964f6ac4118ddab3bbbac79d4f726b093459be751baba73ee0aa364666e8068e"}, + {file = "Bottleneck-1.4.2-cp313-cp313-win32.whl", hash = "sha256:2db287f6ecdbb1c998085eca9b717fec2bfc48a4ab6ae070a9820ba8ab59c90b"}, + {file = "Bottleneck-1.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:26b5f0531f7044befaad95c20365dd666372e66bdacbfaf009ff65d60285534d"}, + {file = "Bottleneck-1.4.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:72d6aa95cdd782833d2589f81434fd865ba004b8938e07920b6ef02796ce8918"}, + {file = "Bottleneck-1.4.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b33e83665e7daf7f513fe1f7b04b13944d44b6635c45d5a9c89c9e5ed11811b6"}, + {file = "Bottleneck-1.4.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52248f3e0fead78c17912fb086a585c86f567019247d21c69e87645241b97b02"}, + {file = "Bottleneck-1.4.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:dce1a3c5ff89a56fb2678c9bda17b89f60f710d6002ab7cd72b7661bc3fae64d"}, + {file = "Bottleneck-1.4.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = 
"sha256:48d2e101d99a9d72aa86da1a048d2094f4e1db0cf77519d1c33239f9d62da162"}, + {file = "Bottleneck-1.4.2-cp39-cp39-win32.whl", hash = "sha256:9d7b12936516f944e3d981a64038f99acb21f0e99f92fad16d9a468248c2b231"}, + {file = "Bottleneck-1.4.2-cp39-cp39-win_amd64.whl", hash = "sha256:7b459d08f1f3e2da85db0a9e2d3e6e3541105f5866e9026dbca32dafc5106f2b"}, + {file = "bottleneck-1.4.2.tar.gz", hash = "sha256:fa8e8e1799dea5483ce6669462660f9d9a95649f6f98a80d315b84ec89f449f4"}, ] [package.dependencies] @@ -921,13 +1062,13 @@ beautifulsoup4 = "*" [[package]] name = "build" -version = "1.2.1" +version = "1.2.2.post1" description = "A simple, correct Python build frontend" optional = false python-versions = ">=3.8" files = [ - {file = "build-1.2.1-py3-none-any.whl", hash = "sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4"}, - {file = "build-1.2.1.tar.gz", hash = "sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d"}, + {file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"}, + {file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"}, ] [package.dependencies] @@ -957,13 +1098,13 @@ files = [ [[package]] name = "celery" -version = "5.3.6" +version = "5.4.0" description = "Distributed Task Queue." optional = false python-versions = ">=3.8" files = [ - {file = "celery-5.3.6-py3-none-any.whl", hash = "sha256:9da4ea0118d232ce97dff5ed4974587fb1c0ff5c10042eb15278487cdd27d1af"}, - {file = "celery-5.3.6.tar.gz", hash = "sha256:870cc71d737c0200c397290d730344cc991d13a057534353d124c9380267aab9"}, + {file = "celery-5.4.0-py3-none-any.whl", hash = "sha256:369631eb580cf8c51a82721ec538684994f8277637edde2dfc0dacd73ed97f64"}, + {file = "celery-5.4.0.tar.gz", hash = "sha256:504a19140e8d3029d5acad88330c541d4c3f64c789d85f94756762d8bca7e706"}, ] [package.dependencies] @@ -979,7 +1120,7 @@ vine = ">=5.1.0,<6.0" [package.extras] arangodb = ["pyArango (>=2.0.2)"] -auth = ["cryptography (==41.0.5)"] +auth = ["cryptography (==42.0.5)"] azureblockblob = ["azure-storage-blob (>=12.15.0)"] brotli = ["brotli (>=1.0.0)", "brotlipy (>=0.7.0)"] cassandra = ["cassandra-driver (>=3.25.0,<4)"] @@ -989,22 +1130,23 @@ couchbase = ["couchbase (>=3.0.0)"] couchdb = ["pycouchdb (==1.14.2)"] django = ["Django (>=2.2.28)"] dynamodb = ["boto3 (>=1.26.143)"] -elasticsearch = ["elastic-transport (<=8.10.0)", "elasticsearch (<=8.11.0)"] +elasticsearch = ["elastic-transport (<=8.13.0)", "elasticsearch (<=8.13.0)"] eventlet = ["eventlet (>=0.32.0)"] +gcs = ["google-cloud-storage (>=2.10.0)"] gevent = ["gevent (>=1.5.0)"] librabbitmq = ["librabbitmq (>=2.0.0)"] memcache = ["pylibmc (==1.6.3)"] mongodb = ["pymongo[srv] (>=4.0.2)"] -msgpack = ["msgpack (==1.0.7)"] -pymemcache = ["python-memcached (==1.59)"] +msgpack = ["msgpack (==1.0.8)"] +pymemcache = ["python-memcached (>=1.61)"] pyro = ["pyro4 (==4.82)"] -pytest = ["pytest-celery (==0.0.0)"] +pytest = ["pytest-celery[all] (>=1.0.0)"] redis = ["redis (>=4.5.2,!=4.5.5,<6.0.0)"] s3 = ["boto3 (>=1.26.143)"] slmq = ["softlayer-messaging (>=1.0.3)"] solar = ["ephem (==4.1.5)"] sqlalchemy = ["sqlalchemy (>=1.4.48,<2.1)"] -sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.0)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] +sqs = ["boto3 (>=1.26.143)", "kombu[sqs] (>=5.3.4)", "pycurl (>=7.43.0.5)", "urllib3 (>=1.26.16)"] tblib = ["tblib (>=1.3.0)", "tblib (>=1.5.0)"] yaml = ["PyYAML (>=3.10)"] zookeeper = ["kazoo (>=1.3.1)"] @@ -1012,89 +1154,89 
@@ zstd = ["zstandard (==0.22.0)"] [[package]] name = "certifi" -version = "2024.7.4" +version = "2024.8.30" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, - {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, + {file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"}, + {file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"}, ] [[package]] name = "cffi" -version = "1.17.0" +version = "1.17.1" description = "Foreign Function Interface for Python calling C code." optional = false python-versions = ">=3.8" files = [ - {file = "cffi-1.17.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f9338cc05451f1942d0d8203ec2c346c830f8e86469903d5126c1f0a13a2bcbb"}, - {file = "cffi-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a0ce71725cacc9ebf839630772b07eeec220cbb5f03be1399e0457a1464f8e1a"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c815270206f983309915a6844fe994b2fa47e5d05c4c4cef267c3b30e34dbe42"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6bdcd415ba87846fd317bee0774e412e8792832e7805938987e4ede1d13046d"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a98748ed1a1df4ee1d6f927e151ed6c1a09d5ec21684de879c7ea6aa96f58f2"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0a048d4f6630113e54bb4b77e315e1ba32a5a31512c31a273807d0027a7e69ab"}, - {file = "cffi-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24aa705a5f5bd3a8bcfa4d123f03413de5d86e497435693b638cbffb7d5d8a1b"}, - {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:856bf0924d24e7f93b8aee12a3a1095c34085600aa805693fb7f5d1962393206"}, - {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:4304d4416ff032ed50ad6bb87416d802e67139e31c0bde4628f36a47a3164bfa"}, - {file = "cffi-1.17.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:331ad15c39c9fe9186ceaf87203a9ecf5ae0ba2538c9e898e3a6967e8ad3db6f"}, - {file = "cffi-1.17.0-cp310-cp310-win32.whl", hash = "sha256:669b29a9eca6146465cc574659058ed949748f0809a2582d1f1a324eb91054dc"}, - {file = "cffi-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:48b389b1fd5144603d61d752afd7167dfd205973a43151ae5045b35793232aa2"}, - {file = "cffi-1.17.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c5d97162c196ce54af6700949ddf9409e9833ef1003b4741c2b39ef46f1d9720"}, - {file = "cffi-1.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5ba5c243f4004c750836f81606a9fcb7841f8874ad8f3bf204ff5e56332b72b9"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bb9333f58fc3a2296fb1d54576138d4cf5d496a2cc118422bd77835e6ae0b9cb"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:435a22d00ec7d7ea533db494da8581b05977f9c37338c80bc86314bec2619424"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:d1df34588123fcc88c872f5acb6f74ae59e9d182a2707097f9e28275ec26a12d"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df8bb0010fdd0a743b7542589223a2816bdde4d94bb5ad67884348fa2c1c67e8"}, - {file = "cffi-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8b5b9712783415695663bd463990e2f00c6750562e6ad1d28e072a611c5f2a6"}, - {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ffef8fd58a36fb5f1196919638f73dd3ae0db1a878982b27a9a5a176ede4ba91"}, - {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e67d26532bfd8b7f7c05d5a766d6f437b362c1bf203a3a5ce3593a645e870b8"}, - {file = "cffi-1.17.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45f7cd36186db767d803b1473b3c659d57a23b5fa491ad83c6d40f2af58e4dbb"}, - {file = "cffi-1.17.0-cp311-cp311-win32.whl", hash = "sha256:a9015f5b8af1bb6837a3fcb0cdf3b874fe3385ff6274e8b7925d81ccaec3c5c9"}, - {file = "cffi-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:b50aaac7d05c2c26dfd50c3321199f019ba76bb650e346a6ef3616306eed67b0"}, - {file = "cffi-1.17.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aec510255ce690d240f7cb23d7114f6b351c733a74c279a84def763660a2c3bc"}, - {file = "cffi-1.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2770bb0d5e3cc0e31e7318db06efcbcdb7b31bcb1a70086d3177692a02256f59"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db9a30ec064129d605d0f1aedc93e00894b9334ec74ba9c6bdd08147434b33eb"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a47eef975d2b8b721775a0fa286f50eab535b9d56c70a6e62842134cf7841195"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f3e0992f23bbb0be00a921eae5363329253c3b86287db27092461c887b791e5e"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6107e445faf057c118d5050560695e46d272e5301feffda3c41849641222a828"}, - {file = "cffi-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb862356ee9391dc5a0b3cbc00f416b48c1b9a52d252d898e5b7696a5f9fe150"}, - {file = "cffi-1.17.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c1c13185b90bbd3f8b5963cd8ce7ad4ff441924c31e23c975cb150e27c2bf67a"}, - {file = "cffi-1.17.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:17c6d6d3260c7f2d94f657e6872591fe8733872a86ed1345bda872cfc8c74885"}, - {file = "cffi-1.17.0-cp312-cp312-win32.whl", hash = "sha256:c3b8bd3133cd50f6b637bb4322822c94c5ce4bf0d724ed5ae70afce62187c492"}, - {file = "cffi-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:dca802c8db0720ce1c49cce1149ff7b06e91ba15fa84b1d59144fef1a1bc7ac2"}, - {file = "cffi-1.17.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6ce01337d23884b21c03869d2f68c5523d43174d4fc405490eb0091057943118"}, - {file = "cffi-1.17.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cab2eba3830bf4f6d91e2d6718e0e1c14a2f5ad1af68a89d24ace0c6b17cced7"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:14b9cbc8f7ac98a739558eb86fabc283d4d564dafed50216e7f7ee62d0d25377"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b00e7bcd71caa0282cbe3c90966f738e2db91e64092a877c3ff7f19a1628fdcb"}, - {file = 
"cffi-1.17.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:41f4915e09218744d8bae14759f983e466ab69b178de38066f7579892ff2a555"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4760a68cab57bfaa628938e9c2971137e05ce48e762a9cb53b76c9b569f1204"}, - {file = "cffi-1.17.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:011aff3524d578a9412c8b3cfaa50f2c0bd78e03eb7af7aa5e0df59b158efb2f"}, - {file = "cffi-1.17.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:a003ac9edc22d99ae1286b0875c460351f4e101f8c9d9d2576e78d7e048f64e0"}, - {file = "cffi-1.17.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ef9528915df81b8f4c7612b19b8628214c65c9b7f74db2e34a646a0a2a0da2d4"}, - {file = "cffi-1.17.0-cp313-cp313-win32.whl", hash = "sha256:70d2aa9fb00cf52034feac4b913181a6e10356019b18ef89bc7c12a283bf5f5a"}, - {file = "cffi-1.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:b7b6ea9e36d32582cda3465f54c4b454f62f23cb083ebc7a94e2ca6ef011c3a7"}, - {file = "cffi-1.17.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:964823b2fc77b55355999ade496c54dde161c621cb1f6eac61dc30ed1b63cd4c"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:516a405f174fd3b88829eabfe4bb296ac602d6a0f68e0d64d5ac9456194a5b7e"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dec6b307ce928e8e112a6bb9921a1cb00a0e14979bf28b98e084a4b8a742bd9b"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4094c7b464cf0a858e75cd14b03509e84789abf7b79f8537e6a72152109c76e"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2404f3de742f47cb62d023f0ba7c5a916c9c653d5b368cc966382ae4e57da401"}, - {file = "cffi-1.17.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aa9d43b02a0c681f0bfbc12d476d47b2b2b6a3f9287f11ee42989a268a1833c"}, - {file = "cffi-1.17.0-cp38-cp38-win32.whl", hash = "sha256:0bb15e7acf8ab35ca8b24b90af52c8b391690ef5c4aec3d31f38f0d37d2cc499"}, - {file = "cffi-1.17.0-cp38-cp38-win_amd64.whl", hash = "sha256:93a7350f6706b31f457c1457d3a3259ff9071a66f312ae64dc024f049055f72c"}, - {file = "cffi-1.17.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a2ddbac59dc3716bc79f27906c010406155031a1c801410f1bafff17ea304d2"}, - {file = "cffi-1.17.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6327b572f5770293fc062a7ec04160e89741e8552bf1c358d1a23eba68166759"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbc183e7bef690c9abe5ea67b7b60fdbca81aa8da43468287dae7b5c046107d4"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bdc0f1f610d067c70aa3737ed06e2726fd9d6f7bfee4a351f4c40b6831f4e82"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6d872186c1617d143969defeadac5a904e6e374183e07977eedef9c07c8953bf"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d46ee4764b88b91f16661a8befc6bfb24806d885e27436fdc292ed7e6f6d058"}, - {file = "cffi-1.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f76a90c345796c01d85e6332e81cab6d70de83b829cf1d9762d0a3da59c7932"}, - {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:0e60821d312f99d3e1569202518dddf10ae547e799d75aef3bca3a2d9e8ee693"}, - {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:eb09b82377233b902d4c3fbeeb7ad731cdab579c6c6fda1f763cd779139e47c3"}, - {file = "cffi-1.17.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:24658baf6224d8f280e827f0a50c46ad819ec8ba380a42448e24459daf809cf4"}, - {file = "cffi-1.17.0-cp39-cp39-win32.whl", hash = "sha256:0fdacad9e0d9fc23e519efd5ea24a70348305e8d7d85ecbb1a5fa66dc834e7fb"}, - {file = "cffi-1.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:7cbc78dc018596315d4e7841c8c3a7ae31cc4d638c9b627f87d52e8abaaf2d29"}, - {file = "cffi-1.17.0.tar.gz", hash = "sha256:f3157624b7558b914cb039fd1af735e5e8049a87c817cc215109ad1c8779df76"}, + {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, + {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"}, + {file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"}, + {file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"}, + {file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"}, + {file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"}, + {file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"}, + {file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"}, + {file = 
"cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"}, + {file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"}, + {file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"}, + {file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"}, + {file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"}, + {file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"}, + {file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"}, + {file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"}, + {file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"}, ] [package.dependencies] @@ -1113,101 +1255,116 @@ files = [ [[package]] name = "charset-normalizer" -version = "3.3.2" +version = "3.4.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7.0" files = [ - {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"}, - {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"}, - {file = 
"charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"}, - {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, - {file = 
"charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, - {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"}, - {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = 
"sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"}, - {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"}, - {file = 
"charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"}, - {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"}, - {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4f9fc98dad6c2eaa32fc3af1417d95b5e3d08aff968df0cd320066def971f9a6"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0de7b687289d3c1b3e8660d0741874abe7888100efe14bd0f9fd7141bcbda92b"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5ed2e36c3e9b4f21dd9422f6893dec0abf2cca553af509b10cd630f878d3eb99"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40d3ff7fc90b98c637bda91c89d51264a3dcf210cade3a2c6f838c7268d7a4ca"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1110e22af8ca26b90bd6364fe4c763329b0ebf1ee213ba32b68c73de5752323d"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:86f4e8cca779080f66ff4f191a685ced73d2f72d50216f7112185dc02b90b9b7"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f683ddc7eedd742e2889d2bfb96d69573fde1d92fcb811979cdb7165bb9c7d3"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27623ba66c183eca01bf9ff833875b459cad267aeeb044477fedac35e19ba907"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f606a1881d2663630ea5b8ce2efe2111740df4b687bd78b34a8131baa007f79b"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0b309d1747110feb25d7ed6b01afdec269c647d382c857ef4663bbe6ad95a912"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:136815f06a3ae311fae551c3df1f998a1ebd01ddd424aa5603a4336997629e95"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = 
"sha256:14215b71a762336254351b00ec720a8e85cada43b987da5a042e4ce3e82bd68e"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:79983512b108e4a164b9c8d34de3992f76d48cadc9554c9e60b43f308988aabe"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-win32.whl", hash = "sha256:c94057af19bc953643a33581844649a7fdab902624d2eb739738a30e2b3e60fc"}, + {file = "charset_normalizer-3.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:55f56e2ebd4e3bc50442fbc0888c9d8c94e4e06a933804e2af3e89e2f9c1c749"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0d99dd8ff461990f12d6e42c7347fd9ab2532fb70e9621ba520f9e8637161d7c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c57516e58fd17d03ebe67e181a4e4e2ccab1168f8c2976c6a334d4f819fe5944"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6dba5d19c4dfab08e58d5b36304b3f92f3bd5d42c1a3fa37b5ba5cdf6dfcbcee"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf4475b82be41b07cc5e5ff94810e6a01f276e37c2d55571e3fe175e467a1a1c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce031db0408e487fd2775d745ce30a7cd2923667cf3b69d48d219f1d8f5ddeb6"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ff4e7cdfdb1ab5698e675ca622e72d58a6fa2a8aa58195de0c0061288e6e3ea"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3710a9751938947e6327ea9f3ea6332a09bf0ba0c09cae9cb1f250bd1f1549bc"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82357d85de703176b5587dbe6ade8ff67f9f69a41c0733cf2425378b49954de5"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:47334db71978b23ebcf3c0f9f5ee98b8d65992b65c9c4f2d34c2eaf5bcaf0594"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8ce7fd6767a1cc5a92a639b391891bf1c268b03ec7e021c7d6d902285259685c"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f1a2f519ae173b5b6a2c9d5fa3116ce16e48b3462c8b96dfdded11055e3d6365"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:63bc5c4ae26e4bc6be6469943b8253c0fd4e4186c43ad46e713ea61a0ba49129"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bcb4f8ea87d03bc51ad04add8ceaf9b0f085ac045ab4d74e73bbc2dc033f0236"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-win32.whl", hash = "sha256:9ae4ef0b3f6b41bad6366fb0ea4fc1d7ed051528e113a60fa2a65a9abb5b1d99"}, + {file = "charset_normalizer-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cee4373f4d3ad28f1ab6290684d8e2ebdb9e7a1b74fdc39e4c211995f77bec27"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0713f3adb9d03d49d365b70b84775d0a0d18e4ab08d12bc46baa6132ba78aaf6"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:de7376c29d95d6719048c194a9cf1a1b0393fbe8488a22008610b0361d834ecf"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4a51b48f42d9358460b78725283f04bddaf44a9358197b889657deba38f329db"}, + {file = 
"charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b295729485b06c1a0683af02a9e42d2caa9db04a373dc38a6a58cdd1e8abddf1"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee803480535c44e7f5ad00788526da7d85525cfefaf8acf8ab9a310000be4b03"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d59d125ffbd6d552765510e3f31ed75ebac2c7470c7274195b9161a32350284"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cda06946eac330cbe6598f77bb54e690b4ca93f593dee1568ad22b04f347c15"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07afec21bbbbf8a5cc3651aa96b980afe2526e7f048fdfb7f1014d84acc8b6d8"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6b40e8d38afe634559e398cc32b1472f376a4099c75fe6299ae607e404c033b2"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b8dcd239c743aa2f9c22ce674a145e0a25cb1566c495928440a181ca1ccf6719"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:84450ba661fb96e9fd67629b93d2941c871ca86fc38d835d19d4225ff946a631"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:44aeb140295a2f0659e113b31cfe92c9061622cadbc9e2a2f7b8ef6b1e29ef4b"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1db4e7fefefd0f548d73e2e2e041f9df5c59e178b4c72fbac4cc6f535cfb1565"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-win32.whl", hash = "sha256:5726cf76c982532c1863fb64d8c6dd0e4c90b6ece9feb06c9f202417a31f7dd7"}, + {file = "charset_normalizer-3.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b197e7094f232959f8f20541ead1d9862ac5ebea1d58e9849c1bf979255dfac9"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dd4eda173a9fcccb5f2e2bd2a9f423d180194b1bf17cf59e3269899235b2a114"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9e3c4c9e1ed40ea53acf11e2a386383c3304212c965773704e4603d589343ed"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:92a7e36b000bf022ef3dbb9c46bfe2d52c047d5e3f3343f43204263c5addc250"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:54b6a92d009cbe2fb11054ba694bc9e284dad30a26757b1e372a1fdddaf21920"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ffd9493de4c922f2a38c2bf62b831dcec90ac673ed1ca182fe11b4d8e9f2a64"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35c404d74c2926d0287fbd63ed5d27eb911eb9e4a3bb2c6d294f3cfd4a9e0c23"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4796efc4faf6b53a18e3d46343535caed491776a22af773f366534056c4e1fbc"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7fdd52961feb4c96507aa649550ec2a0d527c086d284749b2f582f2d40a2e0d"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:92db3c28b5b2a273346bebb24857fda45601aef6ae1c011c0a997106581e8a88"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ab973df98fc99ab39080bfb0eb3a925181454d7c3ac8a1e695fddfae696d9e90"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4b67fdab07fdd3c10bb21edab3cbfe8cf5696f453afce75d815d9d7223fbe88b"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:aa41e526a5d4a9dfcfbab0716c7e8a1b215abd3f3df5a45cf18a12721d31cb5d"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ffc519621dce0c767e96b9c53f09c5d215578e10b02c285809f76509a3931482"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-win32.whl", hash = "sha256:f19c1585933c82098c2a520f8ec1227f20e339e33aca8fa6f956f6691b784e67"}, + {file = "charset_normalizer-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:707b82d19e65c9bd28b81dde95249b07bf9f5b90ebe1ef17d9b57473f8a64b7b"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:dbe03226baf438ac4fda9e2d0715022fd579cb641c4cf639fa40d53b2fe6f3e2"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd9a8bd8900e65504a305bf8ae6fa9fbc66de94178c420791d0293702fce2df7"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8831399554b92b72af5932cdbbd4ddc55c55f631bb13ff8fe4e6536a06c5c51"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a14969b8691f7998e74663b77b4c36c0337cb1df552da83d5c9004a93afdb574"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcaf7c1524c0542ee2fc82cc8ec337f7a9f7edee2532421ab200d2b920fc97cf"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425c5f215d0eecee9a56cdb703203dda90423247421bf0d67125add85d0c4455"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:d5b054862739d276e09928de37c79ddeec42a6e1bfc55863be96a36ba22926f6"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:f3e73a4255342d4eb26ef6df01e3962e73aa29baa3124a8e824c5d3364a65748"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:2f6c34da58ea9c1a9515621f4d9ac379871a8f21168ba1b5e09d74250de5ad62"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_s390x.whl", hash = "sha256:f09cb5a7bbe1ecae6e87901a2eb23e0256bb524a79ccc53eb0b7629fbe7677c4"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:0099d79bdfcf5c1f0c2c72f91516702ebf8b0b8ddd8905f97a8aecf49712c621"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-win32.whl", hash = "sha256:9c98230f5042f4945f957d006edccc2af1e03ed5e37ce7c373f00a5a4daa6149"}, + {file = "charset_normalizer-3.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:62f60aebecfc7f4b82e3f639a7d1433a20ec32824db2199a11ad4f5e146ef5ee"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:af73657b7a68211996527dbfeffbb0864e043d270580c5aef06dc4b659a4b578"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cab5d0b79d987c67f3b9e9c53f54a61360422a5a0bc075f43cab5621d530c3b6"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:9289fd5dddcf57bab41d044f1756550f9e7cf0c8e373b8cdf0ce8773dc4bd417"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b493a043635eb376e50eedf7818f2f322eabbaa974e948bd8bdd29eb7ef2a51"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fa2566ca27d67c86569e8c85297aaf413ffab85a8960500f12ea34ff98e4c41"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8e538f46104c815be19c975572d74afb53f29650ea2025bbfaef359d2de2f7f"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fd30dc99682dc2c603c2b315bded2799019cea829f8bf57dc6b61efde6611c8"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2006769bd1640bdf4d5641c69a3d63b71b81445473cac5ded39740a226fa88ab"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:dc15e99b2d8a656f8e666854404f1ba54765871104e50c8e9813af8a7db07f12"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ab2e5bef076f5a235c3774b4f4028a680432cded7cad37bba0fd90d64b187d19"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:4ec9dd88a5b71abfc74e9df5ebe7921c35cbb3b641181a531ca65cdb5e8e4dea"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:43193c5cda5d612f247172016c4bb71251c784d7a4d9314677186a838ad34858"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:aa693779a8b50cd97570e5a0f343538a8dbd3e496fa5dcb87e29406ad0299654"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-win32.whl", hash = "sha256:7706f5850360ac01d80c89bcef1640683cc12ed87f42579dab6c5d3ed6888613"}, + {file = "charset_normalizer-3.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:c3e446d253bd88f6377260d07c895816ebf33ffffd56c1c792b13bff9c3e1ade"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:980b4f289d1d90ca5efcf07958d3eb38ed9c0b7676bf2831a54d4f66f9c27dfa"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f28f891ccd15c514a0981f3b9db9aa23d62fe1a99997512b0491d2ed323d229a"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8aacce6e2e1edcb6ac625fb0f8c3a9570ccc7bfba1f63419b3769ccf6a00ed0"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd7af3717683bea4c87acd8c0d3d5b44d56120b26fd3f8a692bdd2d5260c620a"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5ff2ed8194587faf56555927b3aa10e6fb69d931e33953943bc4f837dfee2242"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e91f541a85298cf35433bf66f3fab2a4a2cff05c127eeca4af174f6d497f0d4b"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309a7de0a0ff3040acaebb35ec45d18db4b28232f21998851cfa709eeff49d62"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:285e96d9d53422efc0d7a17c60e59f37fbf3dfa942073f666db4ac71e8d726d0"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = 
"sha256:5d447056e2ca60382d460a604b6302d8db69476fd2015c81e7c35417cfabe4cd"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:20587d20f557fe189b7947d8e7ec5afa110ccf72a3128d61a2a387c3313f46be"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:130272c698667a982a5d0e626851ceff662565379baf0ff2cc58067b81d4f11d"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ab22fbd9765e6954bc0bcff24c25ff71dcbfdb185fcdaca49e81bac68fe724d3"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7782afc9b6b42200f7362858f9e73b1f8316afb276d316336c0ec3bd73312742"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-win32.whl", hash = "sha256:2de62e8801ddfff069cd5c504ce3bc9672b23266597d4e4f50eda28846c322f2"}, + {file = "charset_normalizer-3.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:95c3c157765b031331dd4db3c775e58deaee050a3042fcad72cbc4189d7c8dca"}, + {file = "charset_normalizer-3.4.0-py3-none-any.whl", hash = "sha256:fe9f97feb71aa9896b81973a7bbada8c49501dc73e58a10fcef6663af95e5079"}, + {file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"}, ] [[package]] @@ -1288,6 +1445,17 @@ typer = ">=0.9.0" typing-extensions = ">=4.5.0" uvicorn = {version = ">=0.18.3", extras = ["standard"]} +[[package]] +name = "circuitbreaker" +version = "2.0.0" +description = "Python Circuit Breaker pattern implementation" +optional = false +python-versions = "*" +files = [ + {file = "circuitbreaker-2.0.0-py2.py3-none-any.whl", hash = "sha256:c8c6f044b616cd5066368734ce4488020392c962b4bd2869d406d883c36d9859"}, + {file = "circuitbreaker-2.0.0.tar.gz", hash = "sha256:28110761ca81a2accbd6b33186bc8c433e69b0933d85e89f280028dbb8c1dd14"}, +] + [[package]] name = "click" version = "8.1.7" @@ -1370,77 +1538,77 @@ testing = ["pytest (>=7.2.1)", "pytest-cov (>=4.0.0)", "tox (>=4.4.3)"] [[package]] name = "clickhouse-connect" -version = "0.7.18" +version = "0.7.19" description = "ClickHouse Database Core Driver for Python, Pandas, and Superset" optional = false python-versions = "~=3.8" files = [ - {file = "clickhouse-connect-0.7.18.tar.gz", hash = "sha256:516aba1fdcf58973b0d0d90168a60c49f6892b6db1183b932f80ae057994eadb"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:43e712b8fada717160153022314473826adffde00e8cbe8068e0aa1c187c2395"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0a21244d24c9b2a7d1ea2cf23f254884113e0f6d9950340369ce154d7d377165"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:347b19f3674b57906dea94dd0e8b72aaedc822131cc2a2383526b19933ed7a33"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23c5aa1b144491211f662ed26f279845fb367c37d49b681b783ca4f8c51c7891"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e99b4271ed08cc59162a6025086f1786ded5b8a29f4c38e2d3b2a58af04f85f5"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:27d76d1dbe988350567dab7fbcc0a54cdd25abedc5585326c753974349818694"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:d2cd40b4e07df277192ab6bcb187b3f61e0074ad0e256908bf443b3080be4a6c"}, - {file = 
"clickhouse_connect-0.7.18-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8f4ae2c4fb66b2b49f2e7f893fe730712a61a068e79f7272e60d4dd7d64df260"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-win32.whl", hash = "sha256:ed871195b25a4e1acfd37f59527ceb872096f0cd65d76af8c91f581c033b1cc0"}, - {file = "clickhouse_connect-0.7.18-cp310-cp310-win_amd64.whl", hash = "sha256:0c4989012e434b9c167bddf9298ca6eb076593e48a2cab7347cd70a446a7b5d3"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:52cfcd77fc63561e7b51940e32900c13731513d703d7fc54a3a6eb1fa4f7be4e"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:71d7bb9a24b0eacf8963044d6a1dd9e86dfcdd30afe1bd4a581c00910c83895a"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:395cfe09d1d39be4206fc1da96fe316f270077791f9758fcac44fd2765446dba"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac55b2b2eb068b02cbb1afbfc8b2255734e28a646d633c43a023a9b95e08023b"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4d59bb1df3814acb321f0fe87a4a6eea658463d5e59f6dc8ae10072df1205591"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:da5ea738641a7ad0ab7a8e1d8d6234639ea1e61c6eac970bbc6b94547d2c2fa7"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:72eb32a75026401777e34209694ffe64db0ce610475436647ed45589b4ab4efe"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:43bdd638b1ff27649d0ed9ed5000a8b8d754be891a8d279b27c72c03e3d12dcb"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-win32.whl", hash = "sha256:f45bdcba1dc84a1f60a8d827310f615ecbc322518c2d36bba7bf878631007152"}, - {file = "clickhouse_connect-0.7.18-cp311-cp311-win_amd64.whl", hash = "sha256:6df629ab4b646a49a74e791e14a1b6a73ccbe6c4ee25f864522588d376b66279"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:32a35e1e63e4ae708432cbe29c8d116518d2d7b9ecb575b912444c3078b20e20"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:357529b8c08305ab895cdc898b60a3dc9b36637dfa4dbfedfc1d00548fc88edc"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2aa124d2bb65e29443779723e52398e8724e4bf56db94c9a93fd8208b9d6e2bf"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e3646254607e38294e20bf2e20b780b1c3141fb246366a1ad2021531f2c9c1b"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:433e50309af9d46d1b52e5b93ea105332565558be35296c7555c9c2753687586"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:251e67753909f76f8b136cad734501e0daf5977ed62747e18baa2b187f41c92c"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a9980916495da3ed057e56ce2c922fc23de614ea5d74ed470b8450b58902ccee"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:555e00660c04a524ea00409f783265ccd0d0192552eb9d4dc10d2aeaf2fa6575"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-win32.whl", hash = 
"sha256:f4770c100f0608511f7e572b63a6b222fb780fc67341c11746d361c2b03d36d3"}, - {file = "clickhouse_connect-0.7.18-cp312-cp312-win_amd64.whl", hash = "sha256:fd44a7885d992410668d083ba38d6a268a1567f49709300b4ff84eb6aef63b70"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9ac122dcabe1a9d3c14d331fade70a0adc78cf4006c8b91ee721942cdaa1190e"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e89db8e8cc9187f2e9cd6aa32062f67b3b4de7b21b8703f103e89d659eda736"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c34bb25e5ab9a97a4154d43fdcd16751c9aa4a6e6f959016e4c5fe5b692728ed"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:929441a6689a78c63c6a05ee7eb39a183601d93714835ebd537c0572101f7ab1"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8852df54b04361e57775d8ae571cd87e6983f7ed968890c62bbba6a2f2c88fd"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:56333eb772591162627455e2c21c8541ed628a9c6e7c115193ad00f24fc59440"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ac6633d2996100552d2ae47ac5e4eb551e11f69d05637ea84f1e13ac0f2bc21a"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:265085ab548fb49981fe2aef9f46652ee24d5583bf12e652abb13ee2d7e77581"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-win32.whl", hash = "sha256:5ee6c1f74df5fb19b341c389cfed7535fb627cbb9cb1a9bdcbda85045b86cd49"}, - {file = "clickhouse_connect-0.7.18-cp38-cp38-win_amd64.whl", hash = "sha256:c7a28f810775ce68577181e752ecd2dc8caae77f288b6b9f6a7ce4d36657d4fb"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:67f9a3953693b609ab068071be5ac9521193f728b29057e913b386582f84b0c2"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77e202b8606096769bf45e68b46e6bb8c78c2c451c29cb9b3a7bf505b4060d44"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8abcbd17f243ca8399a06fb08970d68e73d1ad671f84bb38518449248093f655"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192605c2a9412e4c7d4baab85e432a58a0a5520615f05bc14f13c2836cfc6eeb"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c17108b190ab34645ee1981440ae129ecd7ca0cb6a93b4e5ce3ffc383355243f"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ac1be43360a6e602784eb60547a03a6c2c574744cb8982ec15aac0e0e57709bd"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:cf403781d4ffd5a47aa7eff591940df182de4d9c423cfdc7eb6ade1a1b100e22"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:937c6481ec083e2a0bcf178ea363b72d437ab0c8fcbe65143db64b12c1e077c0"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-win32.whl", hash = "sha256:77635fea4b3fc4b1568a32674f04d35f4e648e3180528a9bb776e46e76090e4a"}, - {file = "clickhouse_connect-0.7.18-cp39-cp39-win_amd64.whl", hash = "sha256:5ef60eb76be54b6d6bd8f189b076939e2cca16b50b92b763e7a9c7a62b488045"}, - {file = 
"clickhouse_connect-0.7.18-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:7bf76743d7b92b6cac6b4ef2e7a4c2d030ecf2fd542fcfccb374b2432b8d1027"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65b344f174d63096eec098137b5d9c3bb545d67dd174966246c4aa80f9c0bc1e"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24dcc19338cd540e6a3e32e8a7c72c5fc4930c0dd5a760f76af9d384b3e57ddc"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:31f5e42d5fd4eaab616926bae344c17202950d9d9c04716d46bccce6b31dbb73"}, - {file = "clickhouse_connect-0.7.18-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a890421403c7a59ef85e3afc4ff0d641c5553c52fbb9d6ce30c0a0554649fac6"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d61de71d2b82446dd66ade1b925270366c36a2b11779d5d1bcf71b1bfdd161e6"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e81c4f2172e8d6f3dc4dd64ff2dc426920c0caeed969b4ec5bdd0b2fad1533e4"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:092cb8e8acdcccce01d239760405fbd8c266052def49b13ad0a96814f5e521ca"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1ae8b1bab7f06815abf9d833a66849faa2b9dfadcc5728fd14c494e2879afa8"}, - {file = "clickhouse_connect-0.7.18-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e08ebec4db83109024c97ca2d25740bf57915160d7676edd5c4390777c3e3ec0"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e5e42ec23b59597b512b994fec68ac1c2fa6def8594848cc3ae2459cf5e9d76a"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1aad4543a1ae4d40dc815ef85031a1809fe101687380d516383b168a7407ab2"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46cb4c604bd696535b1e091efb8047b833ff4220d31dbd95558c3587fda533a7"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:05e1ef335b81bf6b5908767c3b55e842f1f8463742992653551796eeb8f2d7d6"}, - {file = "clickhouse_connect-0.7.18-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:094e089de4a50a170f5fd1c0ebb2ea357e055266220bb11dfd7ddf2d4e9c9123"}, + {file = "clickhouse-connect-0.7.19.tar.gz", hash = "sha256:ce8f21f035781c5ef6ff57dc162e8150779c009b59f14030ba61f8c9c10c06d0"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6ac74eb9e8d6331bae0303d0fc6bdc2125aa4c421ef646348b588760b38c29e9"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:300f3dea7dd48b2798533ed2486e4b0c3bb03c8d9df9aed3fac44161b92a30f9"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c72629f519105e21600680c791459d729889a290440bbdc61e43cd5eb61d928"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ece0fb202cd9267b3872210e8e0974e4c33c8f91ca9f1c4d92edea997189c72"}, + {file = 
"clickhouse_connect-0.7.19-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6e5adf0359043d4d21c9a668cc1b6323a1159b3e1a77aea6f82ce528b5e4c5b"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:63432180179e90f6f3c18861216f902d1693979e3c26a7f9ef9912c92ce00d14"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:754b9c58b032835caaa9177b69059dc88307485d2cf6d0d545b3dedb13cb512a"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:24e2694e89d12bba405a14b84c36318620dc50f90adbc93182418742d8f6d73f"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-win32.whl", hash = "sha256:52929826b39b5b0f90f423b7a035930b8894b508768e620a5086248bcbad3707"}, + {file = "clickhouse_connect-0.7.19-cp310-cp310-win_amd64.whl", hash = "sha256:5c301284c87d132963388b6e8e4a690c0776d25acc8657366eccab485e53738f"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee47af8926a7ec3a970e0ebf29a82cbbe3b1b7eae43336a81b3a0ca18091de5f"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ce429233b2d21a8a149c8cd836a2555393cbcf23d61233520db332942ffb8964"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:617c04f5c46eed3344a7861cd96fb05293e70d3b40d21541b1e459e7574efa96"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f08e33b8cc2dc1873edc5ee4088d4fc3c0dbb69b00e057547bcdc7e9680b43e5"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:921886b887f762e5cc3eef57ef784d419a3f66df85fd86fa2e7fbbf464c4c54a"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6ad0cf8552a9e985cfa6524b674ae7c8f5ba51df5bd3ecddbd86c82cdbef41a7"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:70f838ef0861cdf0e2e198171a1f3fd2ee05cf58e93495eeb9b17dfafb278186"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c5f0d207cb0dcc1adb28ced63f872d080924b7562b263a9d54d4693b670eb066"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-win32.whl", hash = "sha256:8c96c4c242b98fcf8005e678a26dbd4361748721b6fa158c1fe84ad15c7edbbe"}, + {file = "clickhouse_connect-0.7.19-cp311-cp311-win_amd64.whl", hash = "sha256:bda092bab224875ed7c7683707d63f8a2322df654c4716e6611893a18d83e908"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8f170d08166438d29f0dcfc8a91b672c783dc751945559e65eefff55096f9274"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:26b80cb8f66bde9149a9a2180e2cc4895c1b7d34f9dceba81630a9b9a9ae66b2"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ba80e3598acf916c4d1b2515671f65d9efee612a783c17c56a5a646f4db59b9"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d38c30bd847af0ce7ff738152478f913854db356af4d5824096394d0eab873d"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d41d4b159071c0e4f607563932d4fa5c2a8fc27d3ba1200d0929b361e5191864"}, + {file = 
"clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3682c2426f5dbda574611210e3c7c951b9557293a49eb60a7438552435873889"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6d492064dca278eb61be3a2d70a5f082e2ebc8ceebd4f33752ae234116192020"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:62612da163b934c1ff35df6155a47cf17ac0e2d2f9f0f8f913641e5c02cdf39f"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-win32.whl", hash = "sha256:196e48c977affc045794ec7281b4d711e169def00535ecab5f9fdeb8c177f149"}, + {file = "clickhouse_connect-0.7.19-cp312-cp312-win_amd64.whl", hash = "sha256:b771ca6a473d65103dcae82810d3a62475c5372fc38d8f211513c72b954fb020"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:85a016eebff440b76b90a4725bb1804ddc59e42bba77d21c2a2ec4ac1df9e28d"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f059d3e39be1bafbf3cf0e12ed19b3cbf30b468a4840ab85166fd023ce8c3a17"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39ed54ba0998fd6899fcc967af2b452da28bd06de22e7ebf01f15acbfd547eac"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e4b4d786572cb695a087a71cfdc53999f76b7f420f2580c9cffa8cc51442058"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3710ca989ceae03d5ae56a436b4fe246094dbc17a2946ff318cb460f31b69450"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:d104f25a054cb663495a51ccb26ea11bcdc53e9b54c6d47a914ee6fba7523e62"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ee23b80ee4c5b05861582dd4cd11f0ca0d215a899e9ba299a6ec6e9196943b1b"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:942ec21211d369068ab0ac082312d4df53c638bfc41545d02c41a9055e212df8"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-win32.whl", hash = "sha256:cb8f0a59d1521a6b30afece7c000f6da2cd9f22092e90981aa83342032e5df99"}, + {file = "clickhouse_connect-0.7.19-cp38-cp38-win_amd64.whl", hash = "sha256:98d5779dba942459d5dc6aa083e3a8a83e1cf6191eaa883832118ad7a7e69c87"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9f57aaa32d90f3bd18aa243342b3e75f062dc56a7f988012a22f65fb7946e81d"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5fb25143e4446d3a73fdc1b7d976a0805f763c37bf8f9b2d612a74f65d647830"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b4e19c9952b7b9fe24a99cca0b36a37e17e2a0e59b14457a2ce8868aa32e30e"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9876509aa25804f1377cb1b54dd55c1f5f37a9fbc42fa0c4ac8ac51b38db5926"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:04cfb1dae8fb93117211cfe4e04412b075e47580391f9eee9a77032d8e7d46f4"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b04f7c57f61b5dfdbf49d4b5e4fa5e91ce86bee09bb389b641268afa8f511ab4"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:e5b563f32dcc9cb6ff1f6ed238e83c3e80eb15814b1ea130817c004c241a3c2e"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6018675a231130bd03a7b39a3e875e683286d98115085bfa3ac0918f555f4bfe"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-win32.whl", hash = "sha256:5cb67ae3309396033b825626d60fe2cd789c1d2a183faabef8ffdbbef153d7fb"}, + {file = "clickhouse_connect-0.7.19-cp39-cp39-win_amd64.whl", hash = "sha256:fd225af60478c068cde0952e8df8f731f24c828b75cc1a2e61c21057ff546ecd"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6f31898e0281f820e35710b5c4ad1d40a6c01ffae5278afaef4a16877ac8cbfb"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51c911b0b8281ab4a909320f41dd9c0662796bec157c8f2704de702c552104db"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1088da11789c519f9bb8927a14b16892e3c65e2893abe2680eae68bf6c63835"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:03953942cc073078b40619a735ebeaed9bf98efc71c6f43ce92a38540b1308ce"}, + {file = "clickhouse_connect-0.7.19-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4ac0602fa305d097a0cd40cebbe10a808f6478c9f303d57a48a3a0ad09659544"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4fdefe9eb2d38063835f8f1f326d666c3f61de9d6c3a1607202012c386ca7631"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff6469822fe8d83f272ffbb3fb99dcc614e20b1d5cddd559505029052eff36e7"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46298e23f7e7829f0aa880a99837a82390c1371a643b21f8feb77702707b9eaa"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c6409390b13e09c19435ff65e2ebfcf01f9b2382e4b946191979a5d54ef8625c"}, + {file = "clickhouse_connect-0.7.19-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:cd7e7097b30b70eb695b7b3b6c79ba943548c053cc465fa74efa67a2354f6acd"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:15e080aead66e43c1f214b3e76ab26e3f342a4a4f50e3bbc3118bdd013d12e5f"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:194d2a32ba1b370cb5ac375dd4153871bb0394ff040344d8f449cb36ea951a96"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ac93aafd6a542fdcad4a2b6778575eab6dbdbf8806e86d92e1c1aa00d91cfee"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b208dd3e29db7154b02652c26157a1903bea03d27867ca5b749edc2285c62161"}, + {file = "clickhouse_connect-0.7.19-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9724fdf3563b2335791443cb9e2114be7f77c20c8c4bbfb3571a3020606f0773"}, ] [package.dependencies] @@ -1459,127 +1627,16 @@ sqlalchemy = ["sqlalchemy (>1.3.21,<2.0)"] tzlocal = ["tzlocal (>=4.0)"] [[package]] -name = "clickhouse-driver" -version = "0.2.8" -description = "Python driver with native interface for ClickHouse" -optional = false -python-versions = "<4,>=3.7" -files 
= [ - {file = "clickhouse-driver-0.2.8.tar.gz", hash = "sha256:844b3080e558acbacd42ee569ec83ca7aaa3728f7077b9314c8d09aaa393d752"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3a3a708e020ed2df59e424631f1822ffef4353912fcee143f3b7fc34e866621d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d258d3c3ac0f03527e295eeaf3cebb0a976bc643f6817ccd1d0d71ce970641b4"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f63fb64a55dea29ed6a7d1d6805ebc95c37108c8a36677bc045d904ad600828"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1b16d5dbd53fe32a99d3c4ab6c478c8aa9ae02aec5a2bd2f24180b0b4c03e1a5"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ad2e1850ce91301ae203bc555fb83272dfebb09ad4df99db38c608d45fc22fa4"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ae9239f61a18050164185ec0a3e92469d084377a66ae033cc6b4efa15922867"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8f222f2577bf304e86eec73dbca9c19d7daa6abcafc0bef68bbf31dd461890b"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:451ac3de1191531d030751b05f122219b93b3c509e781fad81c2c91f0e9256b6"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5a2c4fea88e91f1d5217b760ffea84631e647d8db2265b821cbe7b0e015c7807"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:19825a3044c48ab65dc6659eb9763e2f0821887bdd9ee14a2f9ae8c539281ebf"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ae13044a10015225297868658a6f1843c2e34b9fcaa6268880e25c4fca9f3c4d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:548a77efb86012800e76db6d45b3dcffea9a1a26fa3d5fd42021298f0b9a6f16"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-win32.whl", hash = "sha256:ebe4328eaaf937365114b5bab5626600ee57e57d4d099ba2ddbae48c2493f73d"}, - {file = "clickhouse_driver-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:7beaeb4d7e6c3aba7e02375eeca85b20cc8e54dc31fcdb25d3c4308f2cd9465f"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8e06ef6bb701c8e42a9c686d77ad30805cf431bb79fa8fe0f4d3dee819e9a12c"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4afbcfa557419ed1783ecde3abbee1134e09b26c3ab0ada5b2118ae587357c2b"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85f628b4bf6db0fe8fe13da8576a9b95c23b463dff59f4c7aa58cedf529d7d97"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:036f4b3283796ca51610385c7b24bdac1bb873f8a2e97a179f66544594aa9840"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c8916d3d324ce8fd31f8dedd293dc2c29204b94785a5398d1ec1e7ea4e16a26"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30bee7cddd85c04ec49c753b53580364d907cc05c44daafe31b924a352e5e525"}, - {file = 
"clickhouse_driver-0.2.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:03c8a844f6b128348d099dc5d75fad70f4e85802d1649c1b835916ac94ae750a"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:33965329393fd7740b445758787ddacdf70f35fa3411f98a1a86918fff679a46"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8cf85a7ebb0a56182c5b659602e20bae6b36c48a0edf518a6e6f56042d3fcee0"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c10fd1f921ff82638cb9513b9b4acfb575b421c44ef6bf6cf57ee3c487b9d538"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:0a30d49bb6c34e3f5fe42e43dd6a7da0523ddfd05834ef02bd70b9363ea7de7e"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ea32c377a347b0801fc7f2b242f2ec7d78df58047097352672d0de5fbfa9e390"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-win32.whl", hash = "sha256:2a85529d1c0c3f2eedf7a4f736d0efc6e6c8032ac90ca5a63f7a067db58384fe"}, - {file = "clickhouse_driver-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:1f438f83a7473ce7fe9c16cda8750e2fdda1b09fb87f0ec6b87a2b89acb13f24"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9b71bbef6ee08252cee0593329c8ca8e623547627807d38195331f476eaf8136"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f30b3dd388f28eb4052851effe671354db55aea87de748aaf607e7048f72413e"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3bb27ce7ca61089c04dc04dbf207c9165d62a85eb9c99d1451fd686b6b773f9"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59c04ec0b45602b6a63e0779ca7c3d3614be4710ec5ac7214da1b157d43527c5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a61b14244993c7e0f312983455b7851576a85ab5a9fcc6374e75d2680a985e76"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c99a1b0b7759ccd1bf44c65210543c228ba704e3153014fd3aabfe56a227b1a5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f14d860088ab2c7eeb3782c9490ad3f6bf6b1e9235e9db9c3b0079cd4751ffa"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:303887a14a71faddcdee150bc8cde498c25c446b0a72ae586bd67d0c366dbff5"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:359814e4f989c138bfb83e3c81f8f88c8449721dcf32cb8cc25fdb86f4b53c99"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:42de61b4cf9053698b14dbe29e1e3d78cb0a7aaef874fd854df390de5c9cc1f1"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:3bf3089f220480e5a69cbec79f3b65c23afb5c2836e7285234140e5f237f2768"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:41daa4ae5ada22f10c758b0b3b477a51f5df56eef8569cff8e2275de6d9b1b96"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-win32.whl", hash = "sha256:03ea71c7167c6c38c3ba2bbed43615ce0c41ebf3bfa28d96ffcd93cd1cdd07d8"}, - {file = "clickhouse_driver-0.2.8-cp312-cp312-win_amd64.whl", hash = 
"sha256:76985286e10adb2115da116ae25647319bc485ad9e327cbc27296ccf0b052180"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:271529124914c439a5bbcf8a90e3101311d60c1813e03c0467e01fbabef489ee"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f8f499746bc027c6d05de09efa7b2e4f2241f66c1ac2d6b7748f90709b00e10"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f29f256520bb718c532e7fcd85250d4001f49acbaa9e6896bdf4a70d5557e2ef"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:104d062bdf7eab74e92efcbf72088b3241365242b4f119b3fe91057c4d80825c"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee34ed08592a6eff5e176f42897c6ab4dfd8c07df16e9f392e18f1f2ee3fe3ca"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5be9a8d89de881d5ea9d46f9d293caa72dbc7f40b105374cafd88f52b2099ea"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c57efc768fa87e83d6778e7bbd180dd1ff5d647044983ec7d238a8577bd25fa5"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:e1a003475f2d54e9fea8de86b57bc26b409c9efea3d298409ab831f194d62c3b"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:fba71cf41934a23156290a70ef794a5dadc642b21cc25eb13e1f99f2512c8594"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:7289b0e9d1019fed418c577963edd66770222554d1da0c491ca436593667256e"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:16e810cc9be18fdada545b9a521054214dd607bb7aa2f280ca488da23a077e48"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:ed4a6590015f18f414250149255dc2ae81ae956b6e670b290d52c2ecb61ed517"}, - {file = "clickhouse_driver-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:9d454f16ccf1b2185cc630f6fb2160b1abde27759c4e94c42e30b9ea911d58f0"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2e487d49c24448873a6802c34aa21858b9e3fb4a2605268a980a5c02b54a6bae"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e877de75b97ddb11a027a7499171ea0aa9cad569b18fce53c9d508353000cfae"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c60dcefddf6e2c65c92b7e6096c222ff6ed73b01b6c5712f9ce8a23f2ec80f1a"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:422cbbabfad3f9b533d9f517f6f4e174111a613cba878402f7ef632b0eadec3a"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ff8a8e25ff6051ff3d0528dbe36305b0140075d2fa49432149ee2a7841f23ed"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19c7a5960d4f7f9a8f9a560ae05020ff5afe874b565cce06510586a0096bb626"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f5b3333257b46f307b713ba507e4bf11b7531ba3765a4150924532298d645ffd"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_aarch64.whl", hash = 
"sha256:bbc2252a697c674e1b8b6123cf205d2b15979eddf74e7ada0e62a0ecc81a75c3"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:af7f1a9a99dafb0f2a91d1a2d4a3e37f86076147d59abbe69b28d39308fe20fb"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:580c34cc505c492a8abeacbd863ce46158643bece914d8fe2fadea0e94c4e0c1"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5b905eaa6fd3b453299f946a2c8f4a6392f379597e51e46297c6a37699226cda"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6e2b5891c52841aedf803b8054085eb8a611ad4bf57916787a1a9aabf618fb77"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-win32.whl", hash = "sha256:b58a5612db8b3577dc2ae6fda4c783d61c2376396bb364545530aa6a767f166d"}, - {file = "clickhouse_driver-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:96b0424bb5dd698c10b899091562a78f4933a9a039409f310fb74db405d73854"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:22cbed52daa584ca9a93efd772ee5c8c1f68ceaaeb21673985004ec2fd411c49"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e36156fe8a355fc830cc0ea1267c804c631c9dbd9b6accdca868a426213e5929"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c1341325f4180e1318d0d2cf0b268008ea250715c6f30a5ccce586860c000b5"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cb52161276f7d77d4af09f1aab97a16edf86014a89e3d9923f0a6b8fdaa12438"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0d1ccd47040c0a8753684a20a0f83b8a0820386889fdf460a3248e0eed142032"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcda48e938d011e5f4dcebf965e6ec19e020e8efa207b98eeb99c12fa873236d"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2252ab3f8b3bbd705e1d7dc80395c7bea14f5ae51a268fc7be5328da77c0e200"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e1b9ef3fa0cc6c9de77daa74a2f183186d0b5556c4f6870fc966a41fde6cae2b"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d0afa3c68fed6b5e6f23eb3f053d3aba86d09dbbc7706a0120ab5595d5c37003"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:102027bb687ff7a978f7110348f39f0dce450ab334787edbc64b8a9927238e32"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:9fc1ae52a171ded7d9f1f971b9b5bb0ce4d0490a54e102f3717cea51011d0308"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5a62c691be83b1da72ff3455790b50b0f894b7932ac962a8133f3f9c04c943b3"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-win32.whl", hash = "sha256:8b5068cef07cfba5be25a9a461c010ce7a0fe2de5b0b0262c6030684f43fa7f5"}, - {file = "clickhouse_driver-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:cd71965d00b0f3ba992652d577b1d46b87100a67b3e0dc5c191c88092e484c81"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4db0812c43f67e7b1805c05e2bc08f7d670ddfd8d8c671c9b47cdb52f4f74129"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:56622ffefe94a82d9a30747e3486819104d1310d7a94f0e37da461d7112e9864"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c47c8ed61b2f35bb29d991f66d6e03d5cc786def56533480331b2a584854dd5"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dec001a1a49b993522dd134d2fca161352f139d42edcda0e983b8ea8f5023cda"}, - {file = "clickhouse_driver-0.2.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:c03bd486540a6c03aa5a164b7ec6c50980df9642ab1ce22cb70327e4090bdc60"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c059c3da454f0cc0a6f056b542a0c1784cd0398613d25326b11fd1c6f9f7e8d2"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc7f9677c637b710046ec6c6c0cab25b4c4ff21620e44f462041d7455e9e8d13"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba3f6b8fdd7a2e6a831ebbcaaf346f7c8c5eb5085a350c9d4d1ce7053a050b70"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:20c2db3ae29950c80837d270b5ab63c74597afce226b474930060cac7969287b"}, - {file = "clickhouse_driver-0.2.8-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:b7767019a301dad314e7b515046535a45eda84bd9c29590bc3e99b1c334f69e7"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ba8b8b80fa8850546aa40acc952835b1f149af17182cdf3db4f2133b2a241fe8"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:924f11e87e3dcbbc1c9e8158af9917f182cd5e96d37385485d6268f59b564142"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c39e1477ad310a4d276db17c1e1cf6fb059c29eb8d21351afefd5a22de381c6"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e950b9a63af5fa233e3da0e57a7ebd85d4b319e65eef5f9daac84532836f4123"}, - {file = "clickhouse_driver-0.2.8-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:0698dc57373b2f42f3a95bd419d9fa07f2d02150f13a0db2909a2651208262b9"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e0694ca2fb459c23e44036d975fe89544a7c9918618b5d8bda9a8aa2d24e5c37"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:62620348aeae5a905ccb8f7e6bff8d76aae9a95d81aa8c8f6fce0f2af7e104b8"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66276fd5092cccdd6f3123df4357a068fb1972b7e2622fab6f235948c50b6eed"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f86fe87327662b597824d0d7505cc600b0919473b22bbbd178a1a4d4e29283e1"}, - {file = "clickhouse_driver-0.2.8-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:54b9c6ff0aaabdcf7e80a6d9432459611b3413d6a66bec41cbcdad7212721cc7"}, +name = "cloudpickle" +version = "2.2.1" +description = "Extended pickling support for Python objects" +optional = false +python-versions = ">=3.6" +files = [ + {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = 
"sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"}, + {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"}, ] -[package.dependencies] -pytz = "*" -tzlocal = "*" - -[package.extras] -lz4 = ["clickhouse-cityhash (>=1.0.2.1)", "lz4", "lz4 (<=3.0.1)"] -numpy = ["numpy (>=1.12.0)", "pandas (>=0.24.0)"] -zstd = ["clickhouse-cityhash (>=1.0.2.1)", "zstd"] - [[package]] name = "cloudscraper" version = "1.2.71" @@ -1646,66 +1703,87 @@ cron = ["capturer (>=2.4)"] [[package]] name = "contourpy" -version = "1.2.1" +version = "1.3.0" description = "Python library for calculating contours of 2D quadrilateral grids" optional = false python-versions = ">=3.9" files = [ - {file = "contourpy-1.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bd7c23df857d488f418439686d3b10ae2fbf9bc256cd045b37a8c16575ea1040"}, - {file = "contourpy-1.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b9eb0ca724a241683c9685a484da9d35c872fd42756574a7cfbf58af26677fd"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c75507d0a55378240f781599c30e7776674dbaf883a46d1c90f37e563453480"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11959f0ce4a6f7b76ec578576a0b61a28bdc0696194b6347ba3f1c53827178b9"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb3315a8a236ee19b6df481fc5f997436e8ade24a9f03dfdc6bd490fea20c6da"}, - {file = "contourpy-1.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39f3ecaf76cd98e802f094e0d4fbc6dc9c45a8d0c4d185f0f6c2234e14e5f75b"}, - {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94b34f32646ca0414237168d68a9157cb3889f06b096612afdd296003fdd32fd"}, - {file = "contourpy-1.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:457499c79fa84593f22454bbd27670227874cd2ff5d6c84e60575c8b50a69619"}, - {file = "contourpy-1.2.1-cp310-cp310-win32.whl", hash = "sha256:ac58bdee53cbeba2ecad824fa8159493f0bf3b8ea4e93feb06c9a465d6c87da8"}, - {file = "contourpy-1.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:9cffe0f850e89d7c0012a1fb8730f75edd4320a0a731ed0c183904fe6ecfc3a9"}, - {file = "contourpy-1.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6022cecf8f44e36af10bd9118ca71f371078b4c168b6e0fab43d4a889985dbb5"}, - {file = "contourpy-1.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef5adb9a3b1d0c645ff694f9bca7702ec2c70f4d734f9922ea34de02294fdf72"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6150ffa5c767bc6332df27157d95442c379b7dce3a38dff89c0f39b63275696f"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c863140fafc615c14a4bf4efd0f4425c02230eb8ef02784c9a156461e62c965"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:00e5388f71c1a0610e6fe56b5c44ab7ba14165cdd6d695429c5cd94021e390b2"}, - {file = "contourpy-1.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4492d82b3bc7fbb7e3610747b159869468079fe149ec5c4d771fa1f614a14df"}, - {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:49e70d111fee47284d9dd867c9bb9a7058a3c617274900780c43e38d90fe1205"}, - {file = "contourpy-1.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = 
"sha256:b59c0ffceff8d4d3996a45f2bb6f4c207f94684a96bf3d9728dbb77428dd8cb8"}, - {file = "contourpy-1.2.1-cp311-cp311-win32.whl", hash = "sha256:7b4182299f251060996af5249c286bae9361fa8c6a9cda5efc29fe8bfd6062ec"}, - {file = "contourpy-1.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2855c8b0b55958265e8b5888d6a615ba02883b225f2227461aa9127c578a4922"}, - {file = "contourpy-1.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:62828cada4a2b850dbef89c81f5a33741898b305db244904de418cc957ff05dc"}, - {file = "contourpy-1.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:309be79c0a354afff9ff7da4aaed7c3257e77edf6c1b448a779329431ee79d7e"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e785e0f2ef0d567099b9ff92cbfb958d71c2d5b9259981cd9bee81bd194c9a4"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cac0a8f71a041aa587410424ad46dfa6a11f6149ceb219ce7dd48f6b02b87a7"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af3f4485884750dddd9c25cb7e3915d83c2db92488b38ccb77dd594eac84c4a0"}, - {file = "contourpy-1.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ce6889abac9a42afd07a562c2d6d4b2b7134f83f18571d859b25624a331c90b"}, - {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1eea9aecf761c661d096d39ed9026574de8adb2ae1c5bd7b33558af884fb2ce"}, - {file = "contourpy-1.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:187fa1d4c6acc06adb0fae5544c59898ad781409e61a926ac7e84b8f276dcef4"}, - {file = "contourpy-1.2.1-cp312-cp312-win32.whl", hash = "sha256:c2528d60e398c7c4c799d56f907664673a807635b857df18f7ae64d3e6ce2d9f"}, - {file = "contourpy-1.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:1a07fc092a4088ee952ddae19a2b2a85757b923217b7eed584fdf25f53a6e7ce"}, - {file = "contourpy-1.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bb6834cbd983b19f06908b45bfc2dad6ac9479ae04abe923a275b5f48f1a186b"}, - {file = "contourpy-1.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1d59e739ab0e3520e62a26c60707cc3ab0365d2f8fecea74bfe4de72dc56388f"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd3db01f59fdcbce5b22afad19e390260d6d0222f35a1023d9adc5690a889364"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a12a813949e5066148712a0626895c26b2578874e4cc63160bb007e6df3436fe"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe0ccca550bb8e5abc22f530ec0466136379c01321fd94f30a22231e8a48d985"}, - {file = "contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445"}, - {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f32c38afb74bd98ce26de7cc74a67b40afb7b05aae7b42924ea990d51e4dac02"}, - {file = "contourpy-1.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d31a63bc6e6d87f77d71e1abbd7387ab817a66733734883d1fc0021ed9bfa083"}, - {file = "contourpy-1.2.1-cp39-cp39-win32.whl", hash = "sha256:ddcb8581510311e13421b1f544403c16e901c4e8f09083c881fab2be80ee31ba"}, - {file = "contourpy-1.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:10a37ae557aabf2509c79715cd20b62e4c7c28b8cd62dd7d99e5ed3ce28c3fd9"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:a31f94983fecbac95e58388210427d68cd30fe8a36927980fab9c20062645609"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef2b055471c0eb466033760a521efb9d8a32b99ab907fc8358481a1dd29e3bd3"}, - {file = "contourpy-1.2.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:b33d2bc4f69caedcd0a275329eb2198f560b325605810895627be5d4b876bf7f"}, - {file = "contourpy-1.2.1.tar.gz", hash = "sha256:4d8908b3bee1c889e547867ca4cdc54e5ab6be6d3e078556814a22457f49423c"}, -] - -[package.dependencies] -numpy = ">=1.20" + {file = "contourpy-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:880ea32e5c774634f9fcd46504bf9f080a41ad855f4fef54f5380f5133d343c7"}, + {file = "contourpy-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:76c905ef940a4474a6289c71d53122a4f77766eef23c03cd57016ce19d0f7b42"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92f8557cbb07415a4d6fa191f20fd9d2d9eb9c0b61d1b2f52a8926e43c6e9af7"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:36f965570cff02b874773c49bfe85562b47030805d7d8360748f3eca570f4cab"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cacd81e2d4b6f89c9f8a5b69b86490152ff39afc58a95af002a398273e5ce589"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69375194457ad0fad3a839b9e29aa0b0ed53bb54db1bfb6c3ae43d111c31ce41"}, + {file = "contourpy-1.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a52040312b1a858b5e31ef28c2e865376a386c60c0e248370bbea2d3f3b760d"}, + {file = "contourpy-1.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3faeb2998e4fcb256542e8a926d08da08977f7f5e62cf733f3c211c2a5586223"}, + {file = "contourpy-1.3.0-cp310-cp310-win32.whl", hash = "sha256:36e0cff201bcb17a0a8ecc7f454fe078437fa6bda730e695a92f2d9932bd507f"}, + {file = "contourpy-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:87ddffef1dbe5e669b5c2440b643d3fdd8622a348fe1983fad7a0f0ccb1cd67b"}, + {file = "contourpy-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fa4c02abe6c446ba70d96ece336e621efa4aecae43eaa9b030ae5fb92b309ad"}, + {file = "contourpy-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:834e0cfe17ba12f79963861e0f908556b2cedd52e1f75e6578801febcc6a9f49"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbc4c3217eee163fa3984fd1567632b48d6dfd29216da3ded3d7b844a8014a66"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4865cd1d419e0c7a7bf6de1777b185eebdc51470800a9f42b9e9decf17762081"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:303c252947ab4b14c08afeb52375b26781ccd6a5ccd81abcdfc1fafd14cf93c1"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637f674226be46f6ba372fd29d9523dd977a291f66ab2a74fbeb5530bb3f445d"}, + {file = "contourpy-1.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:76a896b2f195b57db25d6b44e7e03f221d32fe318d03ede41f8b4d9ba1bff53c"}, + {file = "contourpy-1.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e1fd23e9d01591bab45546c089ae89d926917a66dceb3abcf01f6105d927e2cb"}, + {file = "contourpy-1.3.0-cp311-cp311-win32.whl", hash = "sha256:d402880b84df3bec6eab53cd0cf802cae6a2ef9537e70cf75e91618a3801c20c"}, + {file = 
"contourpy-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:6cb6cc968059db9c62cb35fbf70248f40994dfcd7aa10444bbf8b3faeb7c2d67"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35"}, + {file = "contourpy-1.3.0-cp312-cp312-win32.whl", hash = "sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb"}, + {file = "contourpy-1.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8"}, + {file = "contourpy-1.3.0-cp313-cp313-win32.whl", hash = "sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294"}, + {file = "contourpy-1.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b"}, + {file = 
"contourpy-1.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927"}, + {file = "contourpy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a11077e395f67ffc2c44ec2418cfebed032cd6da3022a94fc227b6faf8e2acb8"}, + {file = "contourpy-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e8134301d7e204c88ed7ab50028ba06c683000040ede1d617298611f9dc6240c"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e12968fdfd5bb45ffdf6192a590bd8ddd3ba9e58360b29683c6bb71a7b41edca"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fd2a0fc506eccaaa7595b7e1418951f213cf8255be2600f1ea1b61e46a60c55f"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4cfb5c62ce023dfc410d6059c936dcf96442ba40814aefbfa575425a3a7f19dc"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2"}, + {file = "contourpy-1.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:94e848a6b83da10898cbf1311a815f770acc9b6a3f2d646f330d57eb4e87592e"}, + {file = "contourpy-1.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d78ab28a03c854a873787a0a42254a0ccb3cb133c672f645c9f9c8f3ae9d0800"}, + {file = "contourpy-1.3.0-cp39-cp39-win32.whl", hash = "sha256:81cb5ed4952aae6014bc9d0421dec7c5835c9c8c31cdf51910b708f548cf58e5"}, + {file = "contourpy-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:14e262f67bd7e6eb6880bc564dcda30b15e351a594657e55b7eec94b6ef72843"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:fe41b41505a5a33aeaed2a613dccaeaa74e0e3ead6dd6fd3a118fb471644fd6c"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eca7e17a65f72a5133bdbec9ecf22401c62bcf4821361ef7811faee695799779"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ec4dc6bf570f5b22ed0d7efba0dfa9c5b9e0431aeea7581aa217542d9e809a4"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ca947601224119117f7c19c9cdf6b3ab54c5726ef1d906aa4a69dfb6dd58102"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6ec93afeb848a0845a18989da3beca3eec2c0f852322efe21af1931147d12cb"}, + {file = "contourpy-1.3.0.tar.gz", hash = 
"sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4"}, +] + +[package.dependencies] +numpy = ">=1.23" [package.extras] bokeh = ["bokeh", "selenium"] docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] -mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.8.0)", "types-Pillow"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.11.1)", "types-Pillow"] test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] -test-no-images = ["pytest", "pytest-cov", "pytest-xdist", "wurlitzer"] +test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "wurlitzer"] [[package]] name = "cos-python-sdk-v5" @@ -1724,6 +1802,46 @@ requests = ">=2.8" six = "*" xmltodict = "*" +[[package]] +name = "couchbase" +version = "4.3.3" +description = "Python Client for Couchbase" +optional = false +python-versions = ">=3.7" +files = [ + {file = "couchbase-4.3.3-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:d8069e4f01332859d56cca597874645c914699162b3979d1b432f0dfc186b124"}, + {file = "couchbase-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1caa6cfef49c785b35b1702102f718227f351df87bba2694b9334520c41e9eb5"}, + {file = "couchbase-4.3.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f4a9a65c44935249fa078fb90a3c28ea71da9d2d5889fcd514b12d0538010ae0"}, + {file = "couchbase-4.3.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4f144b8c482c18283d8e419b844630d41f3249b07d43d40b5e3535444e57d0fb"}, + {file = "couchbase-4.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1c534fba6fdc7cf47eed9dee8a57d1e9eb867bf008574e321fa380a77cebf32f"}, + {file = "couchbase-4.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b841be06e0e4370b69ebef6bca3409c378186f7d6e964cd645ba18e97216c022"}, + {file = "couchbase-4.3.3-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:eee7a73b3acbdc78ae314fddf7f975b3c9e05df07df255f4dcc878939a2abae0"}, + {file = "couchbase-4.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:53417cafcf90ff4e2fd81ebba2a08b7ad56f17160d1c5019ad3b09c758aeb363"}, + {file = "couchbase-4.3.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0cefd13bea8b0f150f1b9d27fd7614f971f77419b31817781d26ba315ed658bb"}, + {file = "couchbase-4.3.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:78fa1054d7740e2fe38fce0a2aab4e9a2d30263d894e0615ee5df297f02f59a3"}, + {file = "couchbase-4.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:eb093899cfad5a7472258a9b6a57775dbf23a6e0180241507ba89ce3ab241e41"}, + {file = "couchbase-4.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:f7cfbdc699af5715f49365ffbb05a6a7366a534c0d7161edf270ad3e735a6c5d"}, + {file = "couchbase-4.3.3-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:58352cae9b8affdaa2ac012e0a03c8c2632ee6297a878232888b4e0360d0d5df"}, + {file = "couchbase-4.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:728e7e3b5e1682706cb9d63993d289226d02a25089527b8ecb4e3889dabc38cf"}, + {file = "couchbase-4.3.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:73014bf098cf14187a39cc13453e0d859c1d54568df28f69cc308a9a5f24feb2"}, + {file = "couchbase-4.3.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a743375804068ae01b73c916bfca738764c8c12f381bb399ef04e784935856a1"}, + {file = "couchbase-4.3.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:394c122cfe02a76a99e7d5178e64129f6da49843225e78d8629abcab556c24af"}, + {file = 
"couchbase-4.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:bf85d7a5cda548d9801614651206068b4445fa37972e62b14d7521a958198693"}, + {file = "couchbase-4.3.3-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:92d23c9cedd571631070791f2afee0e3d7d8c9ce1bf2ea6e9a4f2fdbc37a0f1e"}, + {file = "couchbase-4.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:38c42eb29a73cce2998ae5df45bd61b16dce9765d3bff968ec5cf6a622faa291"}, + {file = "couchbase-4.3.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:afed137bf0edc642d7b201b6ab7b1e7117bb4c8eac6b2f253cc6e106f334a2a1"}, + {file = "couchbase-4.3.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:954d991377d47883aaf903934c5d0f19577680a2abf80d3ce5bb9b3c80991fc7"}, + {file = "couchbase-4.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d5552b9fa684630698dc98d6f3b1082540634c1b7ad5bf53b843b5da57b0169c"}, + {file = "couchbase-4.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:f88f2b7e0c894f7237d9f3fb5c46abc44b8151a97b3ca8e75f57d23ebf59f9da"}, + {file = "couchbase-4.3.3-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:769e1e2367ea1d4de181fcd4b4e353e9abef97d15b581a6c5aea49ece3dc7d59"}, + {file = "couchbase-4.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:47f59a0b35ffce060583fd11f98f049f3b70701cf14aab9ac092594aca486aeb"}, + {file = "couchbase-4.3.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:440bb93d611827ba0ea2403c6f204fe931467a6cb5811f0e03bf1779204ef843"}, + {file = "couchbase-4.3.3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cdb4dde62e1d41c0b8707121ab68fa78b7a1508541bd48fc850be396f91bc8d9"}, + {file = "couchbase-4.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7f8cf45f317b39cc19db5c67b565662f08d6c90305b3aa14e04bc22707258213"}, + {file = "couchbase-4.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:c97d48ad486c8f201b4482d5594258f949369cb44792ed148d5159a3d12ae21b"}, + {file = "couchbase-4.3.3.tar.gz", hash = "sha256:27808500551564b39b46943cf3daab572694889c1eb638425d363edb48b20da7"}, +] + [[package]] name = "coverage" version = "7.2.7" @@ -1808,38 +1926,38 @@ files = [ [[package]] name = "cryptography" -version = "43.0.0" +version = "43.0.3" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
optional = false python-versions = ">=3.7" files = [ - {file = "cryptography-43.0.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:64c3f16e2a4fc51c0d06af28441881f98c5d91009b8caaff40cf3548089e9c74"}, - {file = "cryptography-43.0.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3dcdedae5c7710b9f97ac6bba7e1052b95c7083c9d0e9df96e02a1932e777895"}, - {file = "cryptography-43.0.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d9a1eca329405219b605fac09ecfc09ac09e595d6def650a437523fcd08dd22"}, - {file = "cryptography-43.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ea9e57f8ea880eeea38ab5abf9fbe39f923544d7884228ec67d666abd60f5a47"}, - {file = "cryptography-43.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9a8d6802e0825767476f62aafed40532bd435e8a5f7d23bd8b4f5fd04cc80ecf"}, - {file = "cryptography-43.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cc70b4b581f28d0a254d006f26949245e3657d40d8857066c2ae22a61222ef55"}, - {file = "cryptography-43.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4a997df8c1c2aae1e1e5ac49c2e4f610ad037fc5a3aadc7b64e39dea42249431"}, - {file = "cryptography-43.0.0-cp37-abi3-win32.whl", hash = "sha256:6e2b11c55d260d03a8cf29ac9b5e0608d35f08077d8c087be96287f43af3ccdc"}, - {file = "cryptography-43.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:31e44a986ceccec3d0498e16f3d27b2ee5fdf69ce2ab89b52eaad1d2f33d8778"}, - {file = "cryptography-43.0.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:7b3f5fe74a5ca32d4d0f302ffe6680fcc5c28f8ef0dc0ae8f40c0f3a1b4fca66"}, - {file = "cryptography-43.0.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac1955ce000cb29ab40def14fd1bbfa7af2017cca696ee696925615cafd0dce5"}, - {file = "cryptography-43.0.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:299d3da8e00b7e2b54bb02ef58d73cd5f55fb31f33ebbf33bd00d9aa6807df7e"}, - {file = "cryptography-43.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ee0c405832ade84d4de74b9029bedb7b31200600fa524d218fc29bfa371e97f5"}, - {file = "cryptography-43.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb013933d4c127349b3948aa8aaf2f12c0353ad0eccd715ca789c8a0f671646f"}, - {file = "cryptography-43.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fdcb265de28585de5b859ae13e3846a8e805268a823a12a4da2597f1f5afc9f0"}, - {file = "cryptography-43.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2905ccf93a8a2a416f3ec01b1a7911c3fe4073ef35640e7ee5296754e30b762b"}, - {file = "cryptography-43.0.0-cp39-abi3-win32.whl", hash = "sha256:47ca71115e545954e6c1d207dd13461ab81f4eccfcb1345eac874828b5e3eaaf"}, - {file = "cryptography-43.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:0663585d02f76929792470451a5ba64424acc3cd5227b03921dab0e2f27b1709"}, - {file = "cryptography-43.0.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:2c6d112bf61c5ef44042c253e4859b3cbbb50df2f78fa8fae6747a7814484a70"}, - {file = "cryptography-43.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:844b6d608374e7d08f4f6e6f9f7b951f9256db41421917dfb2d003dde4cd6b66"}, - {file = "cryptography-43.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:51956cf8730665e2bdf8ddb8da0056f699c1a5715648c1b0144670c1ba00b48f"}, - {file = "cryptography-43.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:aae4d918f6b180a8ab8bf6511a419473d107df4dbb4225c7b48c5c9602c38c7f"}, - {file = "cryptography-43.0.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:232ce02943a579095a339ac4b390fbbe97f5b5d5d107f8a08260ea2768be8cc2"}, - {file = "cryptography-43.0.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:5bcb8a5620008a8034d39bce21dc3e23735dfdb6a33a06974739bfa04f853947"}, - {file = "cryptography-43.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:08a24a7070b2b6804c1940ff0f910ff728932a9d0e80e7814234269f9d46d069"}, - {file = "cryptography-43.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:e9c5266c432a1e23738d178e51c2c7a5e2ddf790f248be939448c0ba2021f9d1"}, - {file = "cryptography-43.0.0.tar.gz", hash = "sha256:b88075ada2d51aa9f18283532c9f60e72170041bba88d7f37e49cbb10275299e"}, + {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63efa177ff54aec6e1c0aefaa1a241232dcd37413835a9b674b6e3f0ae2bfd3e"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e1ce50266f4f70bf41a2c6dc4358afadae90e2a1e5342d3c08883df1675374f"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:443c4a81bb10daed9a8f334365fe52542771f25aedaf889fd323a853ce7377d6"}, + {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:74f57f24754fe349223792466a709f8e0c093205ff0dca557af51072ff47ab18"}, + {file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9762ea51a8fc2a88b70cf2995e5675b38d93bf36bd67d91721c309df184f49bd"}, + {file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:81ef806b1fef6b06dcebad789f988d3b37ccaee225695cf3e07648eee0fc6b73"}, + {file = "cryptography-43.0.3-cp37-abi3-win32.whl", hash = "sha256:cbeb489927bd7af4aa98d4b261af9a5bc025bd87f0e3547e11584be9e9427be2"}, + {file = "cryptography-43.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:f46304d6f0c6ab8e52770addfa2fc41e6629495548862279641972b6215451cd"}, + {file = "cryptography-43.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:8ac43ae87929a5982f5948ceda07001ee5e83227fd69cf55b109144938d96984"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:846da004a5804145a5f441b8530b4bf35afbf7da70f82409f151695b127213d5"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f996e7268af62598f2fc1204afa98a3b5712313a55c4c9d434aef49cadc91d4"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f7b178f11ed3664fd0e995a47ed2b5ff0a12d893e41dd0494f406d1cf555cab7"}, + {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c2e6fc39c4ab499049df3bdf567f768a723a5e8464816e8f009f121a5a9f4405"}, + {file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e1be4655c7ef6e1bbe6b5d0403526601323420bcf414598955968c9ef3eb7d16"}, + {file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:df6b6c6d742395dd77a23ea3728ab62f98379eff8fb61be2744d4679ab678f73"}, + {file = "cryptography-43.0.3-cp39-abi3-win32.whl", hash = "sha256:d56e96520b1020449bbace2b78b603442e7e378a9b3bd68de65c782db1507995"}, + {file = "cryptography-43.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:0c580952eef9bf68c4747774cde7ec1d85a6e61de97281f2dba83c7d2c806362"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:d03b5621a135bffecad2c73e9f4deb1a0f977b9a8ffe6f8e002bf6c9d07b918c"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a2a431ee15799d6db9fe80c82b055bae5a752bef645bba795e8e52687c69efe3"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:281c945d0e28c92ca5e5930664c1cefd85efe80e5c0d2bc58dd63383fda29f83"}, + {file = "cryptography-43.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:f18c716be16bc1fea8e95def49edf46b82fccaa88587a45f8dc0ff6ab5d8e0a7"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4a02ded6cd4f0a5562a8887df8b3bd14e822a90f97ac5e544c162899bc467664"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:53a583b6637ab4c4e3591a15bc9db855b8d9dee9a669b550f311480acab6eb08"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1ec0bcf7e17c0c5669d881b1cd38c4972fade441b27bda1051665faaa89bdcaa"}, + {file = "cryptography-43.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ce6fae5bdad59577b44e4dfed356944fbf1d925269114c28be377692643b4ff"}, + {file = "cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805"}, ] [package.dependencies] @@ -1852,7 +1970,7 @@ nox = ["nox"] pep8test = ["check-sdist", "click", "mypy", "ruff"] sdist = ["build"] ssh = ["bcrypt (>=3.1.5)"] -test = ["certifi", "cryptography-vectors (==43.0.0)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] +test = ["certifi", "cryptography-vectors (==43.0.3)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test-randomorder = ["pytest-randomly"] [[package]] @@ -1901,13 +2019,13 @@ tokenizer = ["tiktoken"] [[package]] name = "dataclass-wizard" -version = "0.22.3" +version = "0.23.0" description = "Marshal dataclasses to/from JSON. Use field properties with initial values. Construct a dataclass schema with JSON input." 
optional = false python-versions = "*" files = [ - {file = "dataclass-wizard-0.22.3.tar.gz", hash = "sha256:4c46591782265058f1148cfd1f54a3a91221e63986fdd04c9d59f4ced61f4424"}, - {file = "dataclass_wizard-0.22.3-py2.py3-none-any.whl", hash = "sha256:63751203e54b9b9349212cc185331da73c1adc99c51312575eb73bb5c00c1962"}, + {file = "dataclass-wizard-0.23.0.tar.gz", hash = "sha256:da29ec19846d46a1eef0692ba7c59c8a86ecd3a9eaddc0511cfc7485ad6d9c50"}, + {file = "dataclass_wizard-0.23.0-py2.py3-none-any.whl", hash = "sha256:50207dec6d36494421366b49b7a9ba6a4d831e2650c0af25cb4c057103d4a97c"}, ] [package.extras] @@ -1915,17 +2033,6 @@ dev = ["Sphinx (==5.3.0)", "bump2version (==1.0.1)", "coverage (>=6.2)", "datacl timedelta = ["pytimeparse (>=1.1.7)"] yaml = ["PyYAML (>=5.3)"] -[[package]] -name = "dataclasses" -version = "0.6" -description = "A backport of the dataclasses module for Python 3.6" -optional = false -python-versions = "*" -files = [ - {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"}, - {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"}, -] - [[package]] name = "dataclasses-json" version = "0.6.7" @@ -1943,13 +2050,13 @@ typing-inspect = ">=0.4.0,<1" [[package]] name = "db-dtypes" -version = "1.2.0" +version = "1.3.0" description = "Pandas Data Types for SQL systems (BigQuery, Spanner)" optional = false python-versions = ">=3.7" files = [ - {file = "db-dtypes-1.2.0.tar.gz", hash = "sha256:3531bb1fb8b5fbab33121fe243ccc2ade16ab2524f4c113b05cc702a1908e6ea"}, - {file = "db_dtypes-1.2.0-py2.py3-none-any.whl", hash = "sha256:6320bddd31d096447ef749224d64aab00972ed20e4392d86f7d8b81ad79f7ff0"}, + {file = "db_dtypes-1.3.0-py2.py3-none-any.whl", hash = "sha256:7e65c59f849ccbe6f7bc4d0253edcc212a7907662906921caba3e4aadd0bc277"}, + {file = "db_dtypes-1.3.0.tar.gz", hash = "sha256:7bcbc8858b07474dc85b77bb2f3ae488978d1336f5ea73b58c39d9118bc3e91b"}, ] [package.dependencies] @@ -1958,6 +2065,17 @@ packaging = ">=17.0" pandas = ">=0.24.2" pyarrow = ">=3.0.0" +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1986,6 +2104,35 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +optional = false +python-versions = "*" +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + +[[package]] +name = "dill" +version = "0.3.9" +description = "serialize all of Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, + {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, +] + 
+[package.extras] +graph = ["objgraph (>=1.7.2)"] +profile = ["gprof2dot (>=2022.7.29)"] + [[package]] name = "distro" version = "1.9.0" @@ -1997,6 +2144,28 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "docker" +version = "7.1.0" +description = "A Python library for the Docker Engine API." +optional = false +python-versions = ">=3.8" +files = [ + {file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"}, + {file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"}, +] + +[package.dependencies] +pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} +requests = ">=2.26.0" +urllib3 = ">=1.26.0" + +[package.extras] +dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"] +docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] +ssh = ["paramiko (>=2.4.3)"] +websockets = ["websocket-client (>=1.3.0)"] + [[package]] name = "docstring-parser" version = "0.16" @@ -2028,87 +2197,104 @@ typing_extensions = ">=4.0,<5.0" [[package]] name = "duckdb" -version = "1.0.0" +version = "1.1.2" description = "DuckDB in-process database" optional = false python-versions = ">=3.7.0" files = [ - {file = "duckdb-1.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4a8ce2d1f9e1c23b9bab3ae4ca7997e9822e21563ff8f646992663f66d050211"}, - {file = "duckdb-1.0.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:19797670f20f430196e48d25d082a264b66150c264c1e8eae8e22c64c2c5f3f5"}, - {file = "duckdb-1.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:b71c342090fe117b35d866a91ad6bffce61cd6ff3e0cff4003f93fc1506da0d8"}, - {file = "duckdb-1.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25dd69f44ad212c35ae2ea736b0e643ea2b70f204b8dff483af1491b0e2a4cec"}, - {file = "duckdb-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8da5f293ecb4f99daa9a9352c5fd1312a6ab02b464653a0c3a25ab7065c45d4d"}, - {file = "duckdb-1.0.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3207936da9967ddbb60644ec291eb934d5819b08169bc35d08b2dedbe7068c60"}, - {file = "duckdb-1.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1128d6c9c33e883b1f5df6b57c1eb46b7ab1baf2650912d77ee769aaa05111f9"}, - {file = "duckdb-1.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:02310d263474d0ac238646677feff47190ffb82544c018b2ff732a4cb462c6ef"}, - {file = "duckdb-1.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:75586791ab2702719c284157b65ecefe12d0cca9041da474391896ddd9aa71a4"}, - {file = "duckdb-1.0.0-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:83bb415fc7994e641344f3489e40430ce083b78963cb1057bf714ac3a58da3ba"}, - {file = "duckdb-1.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:bee2e0b415074e84c5a2cefd91f6b5ebeb4283e7196ba4ef65175a7cef298b57"}, - {file = "duckdb-1.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fa5a4110d2a499312609544ad0be61e85a5cdad90e5b6d75ad16b300bf075b90"}, - {file = "duckdb-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fa389e6a382d4707b5f3d1bc2087895925ebb92b77e9fe3bfb23c9b98372fdc"}, - {file = "duckdb-1.0.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:7ede6f5277dd851f1a4586b0c78dc93f6c26da45e12b23ee0e88c76519cbdbe0"}, - {file = "duckdb-1.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0b88cdbc0d5c3e3d7545a341784dc6cafd90fc035f17b2f04bf1e870c68456e5"}, - {file = "duckdb-1.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd1693cdd15375156f7fff4745debc14e5c54928589f67b87fb8eace9880c370"}, - {file = "duckdb-1.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:c65a7fe8a8ce21b985356ee3ec0c3d3b3b2234e288e64b4cfb03356dbe6e5583"}, - {file = "duckdb-1.0.0-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:e5a8eda554379b3a43b07bad00968acc14dd3e518c9fbe8f128b484cf95e3d16"}, - {file = "duckdb-1.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:a1b6acdd54c4a7b43bd7cb584975a1b2ff88ea1a31607a2b734b17960e7d3088"}, - {file = "duckdb-1.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a677bb1b6a8e7cab4a19874249d8144296e6e39dae38fce66a80f26d15e670df"}, - {file = "duckdb-1.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:752e9d412b0a2871bf615a2ede54be494c6dc289d076974eefbf3af28129c759"}, - {file = "duckdb-1.0.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3aadb99d098c5e32d00dc09421bc63a47134a6a0de9d7cd6abf21780b678663c"}, - {file = "duckdb-1.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83b7091d4da3e9301c4f9378833f5ffe934fb1ad2b387b439ee067b2c10c8bb0"}, - {file = "duckdb-1.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:6a8058d0148b544694cb5ea331db44f6c2a00a7b03776cc4dd1470735c3d5ff7"}, - {file = "duckdb-1.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e40cb20e5ee19d44bc66ec99969af791702a049079dc5f248c33b1c56af055f4"}, - {file = "duckdb-1.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7bce1bc0de9af9f47328e24e6e7e39da30093179b1c031897c042dd94a59c8e"}, - {file = "duckdb-1.0.0-cp37-cp37m-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8355507f7a04bc0a3666958f4414a58e06141d603e91c0fa5a7c50e49867fb6d"}, - {file = "duckdb-1.0.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:39f1a46f5a45ad2886dc9b02ce5b484f437f90de66c327f86606d9ba4479d475"}, - {file = "duckdb-1.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:a6d29ba477b27ae41676b62c8fae8d04ee7cbe458127a44f6049888231ca58fa"}, - {file = "duckdb-1.0.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:1bea713c1925918714328da76e79a1f7651b2b503511498ccf5e007a7e67d49e"}, - {file = "duckdb-1.0.0-cp38-cp38-macosx_12_0_universal2.whl", hash = "sha256:bfe67f3bcf181edbf6f918b8c963eb060e6aa26697d86590da4edc5707205450"}, - {file = "duckdb-1.0.0-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:dbc6093a75242f002be1d96a6ace3fdf1d002c813e67baff52112e899de9292f"}, - {file = "duckdb-1.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba1881a2b11c507cee18f8fd9ef10100be066fddaa2c20fba1f9a664245cd6d8"}, - {file = "duckdb-1.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:445d0bb35087c522705c724a75f9f1c13f1eb017305b694d2686218d653c8142"}, - {file = "duckdb-1.0.0-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:224553432e84432ffb9684f33206572477049b371ce68cc313a01e214f2fbdda"}, - {file = "duckdb-1.0.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:d3914032e47c4e76636ad986d466b63fdea65e37be8a6dfc484ed3f462c4fde4"}, - {file = "duckdb-1.0.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:af9128a2eb7e1bb50cd2c2020d825fb2946fdad0a2558920cd5411d998999334"}, - {file = "duckdb-1.0.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dd2659a5dbc0df0de68f617a605bf12fe4da85ba24f67c08730984a0892087e8"}, - {file = "duckdb-1.0.0-cp39-cp39-macosx_12_0_universal2.whl", hash = "sha256:ac5a4afb0bc20725e734e0b2c17e99a274de4801aff0d4e765d276b99dad6d90"}, - {file = "duckdb-1.0.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:2c5a53bee3668d6e84c0536164589d5127b23d298e4c443d83f55e4150fafe61"}, - {file = "duckdb-1.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b980713244d7708b25ee0a73de0c65f0e5521c47a0e907f5e1b933d79d972ef6"}, - {file = "duckdb-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21cbd4f9fe7b7a56eff96c3f4d6778770dd370469ca2212eddbae5dd63749db5"}, - {file = "duckdb-1.0.0-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ed228167c5d49888c5ef36f6f9cbf65011c2daf9dcb53ea8aa7a041ce567b3e4"}, - {file = "duckdb-1.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:46d8395fbcea7231fd5032a250b673cc99352fef349b718a23dea2c0dd2b8dec"}, - {file = "duckdb-1.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:6ad1fc1a4d57e7616944166a5f9417bdbca1ea65c490797e3786e3a42e162d8a"}, - {file = "duckdb-1.0.0.tar.gz", hash = "sha256:a2a059b77bc7d5b76ae9d88e267372deff19c291048d59450c431e166233d453"}, + {file = "duckdb-1.1.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:91e7f99cf5cab1d26f92cb014429153497d805e79689baa44f4c4585a8cb243f"}, + {file = "duckdb-1.1.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:0107de622fe208142a1108263a03c43956048dcc99be3702d8e5d2aeaf99554c"}, + {file = "duckdb-1.1.2-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:8a09610f780857677725897856f8cdf3cafd8a991f871e6cb8ba88b2dbc8d737"}, + {file = "duckdb-1.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c0f0ddac0482f0f3fece54d720d13819e82ae26c01a939ffa66a87be53f7f665"}, + {file = "duckdb-1.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84103373e818758dfa361d27781d0f096553843c5ffb9193260a0786c5248270"}, + {file = "duckdb-1.1.2-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bfdfd23e2bf58014ad0673973bd0ed88cd048dfe8e82420814a71d7d52ef2288"}, + {file = "duckdb-1.1.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:25889e6e29b87047b1dd56385ac08156e4713c59326cc6fff89657d01b2c417b"}, + {file = "duckdb-1.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:312570fa5277c3079de18388b86c2d87cbe1044838bb152b235c0227581d5d42"}, + {file = "duckdb-1.1.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:568439ea4fce8cb72ec1f767cd510686a9e7e29a011fc7c56d990059a6e94e48"}, + {file = "duckdb-1.1.2-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:74974f2d7210623a5d61b1fb0cb589c6e5ffcbf7dbb757a04c5ba24adcfc8cac"}, + {file = "duckdb-1.1.2-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:e26422a3358c816d764639070945b73eef55d1b4df990989e3492c85ef725c21"}, + {file = "duckdb-1.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87e972bd452eeeab197fe39dcaeecdb7c264b1f75a0ee67e532e235fe45b84df"}, + {file = "duckdb-1.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a6b73e70b73c8df85da383f6e557c03cad5c877868b9a7e41715761e8166c1e"}, + {file = "duckdb-1.1.2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:623cb1952466aae5907af84107bcdec25a5ca021a8b6441e961f41edc724f6f2"}, + {file = "duckdb-1.1.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d9fc0b550f96901fa7e76dc70a13f6477ad3e18ef1cb21d414c3a5569de3f27e"}, + {file = "duckdb-1.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:181edb1973bd8f493bcb6ecfa035f1a592dff4667758592f300619012ba251c0"}, + {file = "duckdb-1.1.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:83372b1b411086cac01ab2071122772fa66170b1b41ddbc37527464066083668"}, + {file = "duckdb-1.1.2-cp312-cp312-macosx_12_0_universal2.whl", hash = "sha256:db37441deddfee6ac35a0c742d2f9e90e4e50b9e76d586a060d122b8fc56dada"}, + {file = "duckdb-1.1.2-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:19142a77e72874aeaa6fda30aeb13612c6de5e8c60fbcc3392cea6ef0694eeaf"}, + {file = "duckdb-1.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:099d99dd48d6e4682a3dd6233ceab73d977ebe1a87afaac54cf77c844e24514a"}, + {file = "duckdb-1.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be86e586ca7af7e807f72479a2b8d0983565360b19dbda4ef8a9d7b3909b8e2c"}, + {file = "duckdb-1.1.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:578e0953e4d8ba8da0cd69fb2930c45f51ce47d213b77d8a4cd461f9c0960b87"}, + {file = "duckdb-1.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:72b5eb5762c1a5e68849c7143f3b3747a9f15c040e34e41559f233a1569ad16f"}, + {file = "duckdb-1.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:9b4c6b6a08180261d98330d97355503961a25ca31cd9ef296e0681f7895b4a2c"}, + {file = "duckdb-1.1.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:695dcbc561374b126e86659709feadf883c9969ed718e94713edd4ba15d16619"}, + {file = "duckdb-1.1.2-cp313-cp313-macosx_12_0_universal2.whl", hash = "sha256:ada29be1e889f486c6cf1f6dffd15463e748faf361f33996f2e862779edc24a9"}, + {file = "duckdb-1.1.2-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:6ca722738fa9eb6218619740631de29acfdd132de6f6a6350fee5e291c2f6117"}, + {file = "duckdb-1.1.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c796d33f1e5a0c8c570d22da0c0b1db8578687e427029e1ce2c8ce3f9fffa6a3"}, + {file = "duckdb-1.1.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5c0996988a70dd3bc8111d9b9aeab7e38ed1999a52607c5f1b528e362b4dd1c"}, + {file = "duckdb-1.1.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c37b039f6d6fed14d89450f5ccf54922b3304192d7412e12d6cc8d9e757f7a2"}, + {file = "duckdb-1.1.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8c766b87f675c76d6d17103bf6fb9fb1a9e2fcb3d9b25c28bbc634bde31223e"}, + {file = "duckdb-1.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:e3e6300b7ccaf64b609f4f0780a6e1d25ab8cf34cceed46e62c35b6c4c5cb63b"}, + {file = "duckdb-1.1.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a01fae9604a54ecbc26e7503c522311f15afbd2870e6d8f6fbef4545dfae550"}, + {file = "duckdb-1.1.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:492b1d86a696428bd3f14dc1c7c3230e2dbca8978f288be64b04a26e0e00fad5"}, + {file = "duckdb-1.1.2-cp37-cp37m-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1bba58459ad897a78c4e478a097626fc266459a40338cecc68a49a8d5dc72fb7"}, + {file = "duckdb-1.1.2-cp37-cp37m-win_amd64.whl", hash = "sha256:d395a3bf510bf24686821eec15802624797dcb33e8f14f8a7cc8e17d909474af"}, + {file = "duckdb-1.1.2-cp38-cp38-macosx_12_0_arm64.whl", hash = 
"sha256:fd800f75728727fe699ed1eb22b636867cf48c9dd105ee88b977e20c89df4509"}, + {file = "duckdb-1.1.2-cp38-cp38-macosx_12_0_universal2.whl", hash = "sha256:d8caaf43909e49537e26df51d80d075ae2b25a610d28ed8bd31d6ccebeaf3c65"}, + {file = "duckdb-1.1.2-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:564166811c68d9c7f9911eb707ad32ec9c2507b98336d894fbe658b85bf1c697"}, + {file = "duckdb-1.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19386aa09f0d6f97634ba2972096d1c80d880176dfb0e949eadc91c98262a663"}, + {file = "duckdb-1.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9e8387bcc9a591ad14011ddfec0d408d1d9b1889c6c9b495a04c7016a24b9b3"}, + {file = "duckdb-1.1.2-cp38-cp38-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8c5ff4970403ed3ff0ac71fe0ce1e6be3199df9d542afc84c424b444ba4ffe8"}, + {file = "duckdb-1.1.2-cp38-cp38-win_amd64.whl", hash = "sha256:9283dcca87c3260eb631a99d738fa72b8545ed45b475bc72ad254f7310e14284"}, + {file = "duckdb-1.1.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:f87edaf20001530e63a4f7bda13b55dc3152d7171226915f2bf34e0813c8759e"}, + {file = "duckdb-1.1.2-cp39-cp39-macosx_12_0_universal2.whl", hash = "sha256:efec169b3fe0b821e3207ba3e445f227d42dd62b4440ff79c37fa168a4fc5a71"}, + {file = "duckdb-1.1.2-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:89164a2d29d56605a95ee5032aa415dd487028c4fd3e06d971497840e74c56e7"}, + {file = "duckdb-1.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6858e10c60ff7e70e61d3dd53d2545c8b2609942e45fd6de38cd0dee52932de3"}, + {file = "duckdb-1.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ca967c5a57b1d0cb0fd5e539ab24110e5a59dcbedd365bb2dc80533d6e44a8d"}, + {file = "duckdb-1.1.2-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4ce949f1d7999aa6a046eb64067eee41d4c5c2872ba4fa408c9947742d0c7231"}, + {file = "duckdb-1.1.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9ba6d1f918e6ca47a368a0c32806016405cb9beb2c245806b0ca998f569d2bdf"}, + {file = "duckdb-1.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:7111fd3e7b334a7be383313ce29918b7c643e4f6ef44d6d63c3ab3fa6716c114"}, + {file = "duckdb-1.1.2.tar.gz", hash = "sha256:c8232861dc8ec6daa29067056d5a0e5789919f2ab22ab792787616d7cd52f02a"}, ] [[package]] name = "duckduckgo-search" -version = "6.2.6" +version = "6.3.3" description = "Search for words, documents, images, news, maps and text translation using the DuckDuckGo.com search engine." 
optional = false python-versions = ">=3.8" files = [ - {file = "duckduckgo_search-6.2.6-py3-none-any.whl", hash = "sha256:c8171bcd6ff4d051f78c70ea23bd34c0d8e779d72973829d3a6b40ccc05cd7c2"}, - {file = "duckduckgo_search-6.2.6.tar.gz", hash = "sha256:96529ecfbd55afa28705b38413003cb3cfc620e55762d33184887545de27dc96"}, + {file = "duckduckgo_search-6.3.3-py3-none-any.whl", hash = "sha256:63e5d6b958bd532016bc8a53e8b18717751bf7ef51b1c83e59b9f5780c79e64c"}, + {file = "duckduckgo_search-6.3.3.tar.gz", hash = "sha256:4d49508f01f85c8675765fdd4cc25eedbb3450e129b35209897fded874f6568f"}, ] [package.dependencies] click = ">=8.1.7" -primp = ">=0.5.5" +primp = ">=0.6.5" [package.extras] -dev = ["mypy (>=1.11.0)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.5.5)"] +dev = ["mypy (>=1.11.1)", "pytest (>=8.3.1)", "pytest-asyncio (>=0.23.8)", "ruff (>=0.6.1)"] lxml = ["lxml (>=5.2.2)"] +[[package]] +name = "durationpy" +version = "0.9" +description = "Module for converting between datetime.timedelta and Go's Duration strings." +optional = false +python-versions = "*" +files = [ + {file = "durationpy-0.9-py3-none-any.whl", hash = "sha256:e65359a7af5cedad07fb77a2dd3f390f8eb0b74cb845589fa6c057086834dd38"}, + {file = "durationpy-0.9.tar.gz", hash = "sha256:fd3feb0a69a0057d582ef643c355c40d2fa1c942191f914d12203b1a01ac722a"}, +] + [[package]] name = "elastic-transport" -version = "8.15.0" +version = "8.15.1" description = "Transport classes and utilities shared among Python Elastic client libraries" optional = false python-versions = ">=3.8" files = [ - {file = "elastic_transport-8.15.0-py3-none-any.whl", hash = "sha256:d7080d1dada2b4eee69e7574f9c17a76b42f2895eff428e562f94b0360e158c0"}, - {file = "elastic_transport-8.15.0.tar.gz", hash = "sha256:85d62558f9baafb0868c801233a59b235e61d7b4804c28c2fadaa866b6766233"}, + {file = "elastic_transport-8.15.1-py3-none-any.whl", hash = "sha256:b5e82ff1679d8c7705a03fd85c7f6ef85d6689721762d41228dd312e34f331fc"}, + {file = "elastic_transport-8.15.1.tar.gz", hash = "sha256:9cac4ab5cf9402668cf305ae0b7d93ddc0c7b61461d6d1027850db6da9cc5742"}, ] [package.dependencies] @@ -2116,7 +2302,7 @@ certifi = "*" urllib3 = ">=1.26.2,<3" [package.extras] -develop = ["aiohttp", "furo", "httpx", "opentelemetry-api", "opentelemetry-sdk", "orjson", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"] +develop = ["aiohttp", "furo", "httpcore (<1.0.6)", "httpx", "opentelemetry-api", "opentelemetry-sdk", "orjson", "pytest", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "pytest-mock", "requests", "respx", "sphinx (>2)", "sphinx-autodoc-typehints", "trustme"] [[package]] name = "elasticsearch" @@ -2140,18 +2326,15 @@ vectorstore-mmr = ["numpy (>=1)", "simsimd (>=3)"] [[package]] name = "emoji" -version = "2.12.1" +version = "2.14.0" description = "Emoji for Python" optional = false python-versions = ">=3.7" files = [ - {file = "emoji-2.12.1-py3-none-any.whl", hash = "sha256:a00d62173bdadc2510967a381810101624a2f0986145b8da0cffa42e29430235"}, - {file = "emoji-2.12.1.tar.gz", hash = "sha256:4aa0488817691aa58d83764b6c209f8a27c0b3ab3f89d1b8dceca1a62e4973eb"}, + {file = "emoji-2.14.0-py3-none-any.whl", hash = "sha256:fcc936bf374b1aec67dda5303ae99710ba88cc9cdce2d1a71c5f2204e6d78799"}, + {file = "emoji-2.14.0.tar.gz", hash = "sha256:f68ac28915a2221667cddb3e6c589303c3c6954c6c5af6fefaec7f9bdf72fdca"}, ] -[package.dependencies] -typing-extensions = ">=4.7.0" - [package.extras] dev = ["coverage", 
"pytest (>=7.4.4)"] @@ -2176,17 +2359,44 @@ django = ["dj-database-url", "dj-email-url", "django-cache-url"] lint = ["flake8 (==4.0.1)", "flake8-bugbear (==21.9.2)", "mypy (==0.910)", "pre-commit (>=2.4,<3.0)"] tests = ["dj-database-url", "dj-email-url", "django-cache-url", "pytest"] +[[package]] +name = "esdk-obs-python" +version = "3.24.6.1" +description = "OBS Python SDK" +optional = false +python-versions = "*" +files = [ + {file = "esdk-obs-python-3.24.6.1.tar.gz", hash = "sha256:c45fed143e99d9256c8560c1d78f651eae0d2e809d16e962f8b286b773c33bf0"}, +] + +[package.dependencies] +pycryptodome = ">=3.10.1" + [[package]] name = "et-xmlfile" -version = "1.1.0" +version = "2.0.0" description = "An implementation of lxml.xmlfile for the standard library" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" +files = [ + {file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"}, + {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"}, +] + +[[package]] +name = "eval-type-backport" +version = "0.2.0" +description = "Like `typing._eval_type`, but lets older Python versions use newer typing features." +optional = false +python-versions = ">=3.8" files = [ - {file = "et_xmlfile-1.1.0-py3-none-any.whl", hash = "sha256:a2ba85d1d6a74ef63837eed693bcb89c3f752169b0e3e7ae5b16ca5e1b3deada"}, - {file = "et_xmlfile-1.1.0.tar.gz", hash = "sha256:8eb9e2bc2f8c97e37a2dc85a09ecdcdec9d8a396530a6d5a33b30b9a92da0c5c"}, + {file = "eval_type_backport-0.2.0-py3-none-any.whl", hash = "sha256:ac2f73d30d40c5a30a80b8739a789d6bb5e49fdffa66d7912667e2015d9c9933"}, + {file = "eval_type_backport-0.2.0.tar.gz", hash = "sha256:68796cfbc7371ebf923f03bdf7bef415f3ec098aeced24e054b253a0e78f7b37"}, ] +[package.extras] +tests = ["pytest"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -2203,62 +2413,62 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.112.0" +version = "0.115.4" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.112.0-py3-none-any.whl", hash = "sha256:3487ded9778006a45834b8c816ec4a48d522e2631ca9e75ec5a774f1b052f821"}, - {file = "fastapi-0.112.0.tar.gz", hash = "sha256:d262bc56b7d101d1f4e8fc0ad2ac75bb9935fec504d2b7117686cec50710cf05"}, + {file = "fastapi-0.115.4-py3-none-any.whl", hash = "sha256:0b504a063ffb3cf96a5e27dc1bc32c80ca743a2528574f9cdc77daa2d31b4742"}, + {file = "fastapi-0.115.4.tar.gz", hash = "sha256:db653475586b091cb8b2fec2ac54a680ac6a158e07406e1abae31679e8826349"}, ] [package.dependencies] pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.37.2,<0.38.0" +starlette = ">=0.40.0,<0.42.0" typing-extensions = ">=4.8.0" [package.extras] -all = ["email_validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] -standard = ["email_validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] +all = ["email-validator (>=2.0.0)", 
"fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] [[package]] name = "fastavro" -version = "1.9.5" +version = "1.9.7" description = "Fast read/write of AVRO files" optional = false python-versions = ">=3.8" files = [ - {file = "fastavro-1.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:61253148e95dd2b6457247b441b7555074a55de17aef85f5165bfd5facf600fc"}, - {file = "fastavro-1.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b604935d671ad47d888efc92a106f98e9440874108b444ac10e28d643109c937"}, - {file = "fastavro-1.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0adbf4956fd53bd74c41e7855bb45ccce953e0eb0e44f5836d8d54ad843f9944"}, - {file = "fastavro-1.9.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:53d838e31457db8bf44460c244543f75ed307935d5fc1d93bc631cc7caef2082"}, - {file = "fastavro-1.9.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:07b6288e8681eede16ff077632c47395d4925c2f51545cd7a60f194454db2211"}, - {file = "fastavro-1.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:ef08cf247fdfd61286ac0c41854f7194f2ad05088066a756423d7299b688d975"}, - {file = "fastavro-1.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c52d7bb69f617c90935a3e56feb2c34d4276819a5c477c466c6c08c224a10409"}, - {file = "fastavro-1.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85e05969956003df8fa4491614bc62fe40cec59e94d06e8aaa8d8256ee3aab82"}, - {file = "fastavro-1.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06e6df8527493a9f0d9a8778df82bab8b1aa6d80d1b004e5aec0a31dc4dc501c"}, - {file = "fastavro-1.9.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:27820da3b17bc01cebb6d1687c9d7254b16d149ef458871aaa207ed8950f3ae6"}, - {file = "fastavro-1.9.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:195a5b8e33eb89a1a9b63fa9dce7a77d41b3b0cd785bac6044df619f120361a2"}, - {file = "fastavro-1.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:be612c109efb727bfd36d4d7ed28eb8e0506617b7dbe746463ebbf81e85eaa6b"}, - {file = "fastavro-1.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b133456c8975ec7d2a99e16a7e68e896e45c821b852675eac4ee25364b999c14"}, - {file = "fastavro-1.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf586373c3d1748cac849395aad70c198ee39295f92e7c22c75757b5c0300fbe"}, - {file = "fastavro-1.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:724ef192bc9c55d5b4c7df007f56a46a21809463499856349d4580a55e2b914c"}, - {file = "fastavro-1.9.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bfd11fe355a8f9c0416803afac298960eb4c603a23b1c74ff9c1d3e673ea7185"}, - {file = "fastavro-1.9.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9827d1654d7bcb118ef5efd3e5b2c9ab2a48d44dac5e8c6a2327bc3ac3caa828"}, - {file = "fastavro-1.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:d84b69dca296667e6137ae7c9a96d060123adbc0c00532cc47012b64d38b47e9"}, - {file = "fastavro-1.9.5-cp38-cp38-macosx_11_0_universal2.whl", hash = 
"sha256:fb744e9de40fb1dc75354098c8db7da7636cba50a40f7bef3b3fb20f8d189d88"}, - {file = "fastavro-1.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:240df8bacd13ff5487f2465604c007d686a566df5cbc01d0550684eaf8ff014a"}, - {file = "fastavro-1.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3bb35c25bbc3904e1c02333bc1ae0173e0a44aa37a8e95d07e681601246e1f1"}, - {file = "fastavro-1.9.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:b47a54a9700de3eabefd36dabfb237808acae47bc873cada6be6990ef6b165aa"}, - {file = "fastavro-1.9.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:48c7b5e6d2f3bf7917af301c275b05c5be3dd40bb04e80979c9e7a2ab31a00d1"}, - {file = "fastavro-1.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:05d13f98d4e325be40387e27da9bd60239968862fe12769258225c62ec906f04"}, - {file = "fastavro-1.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5b47948eb196263f6111bf34e1cd08d55529d4ed46eb50c1bc8c7c30a8d18868"}, - {file = "fastavro-1.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85b7a66ad521298ad9373dfe1897a6ccfc38feab54a47b97922e213ae5ad8870"}, - {file = "fastavro-1.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44cb154f863ad80e41aea72a709b12e1533b8728c89b9b1348af91a6154ab2f5"}, - {file = "fastavro-1.9.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b5f7f2b1fe21231fd01f1a2a90e714ae267fe633cd7ce930c0aea33d1c9f4901"}, - {file = "fastavro-1.9.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:88fbbe16c61d90a89d78baeb5a34dc1c63a27b115adccdbd6b1fb6f787deacf2"}, - {file = "fastavro-1.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:753f5eedeb5ca86004e23a9ce9b41c5f25eb64a876f95edcc33558090a7f3e4b"}, - {file = "fastavro-1.9.5.tar.gz", hash = "sha256:6419ebf45f88132a9945c51fe555d4f10bb97c236288ed01894f957c6f914553"}, + {file = "fastavro-1.9.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:cc811fb4f7b5ae95f969cda910241ceacf82e53014c7c7224df6f6e0ca97f52f"}, + {file = "fastavro-1.9.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb8749e419a85f251bf1ac87d463311874972554d25d4a0b19f6bdc56036d7cf"}, + {file = "fastavro-1.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b2f9bafa167cb4d1c3dd17565cb5bf3d8c0759e42620280d1760f1e778e07fc"}, + {file = "fastavro-1.9.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e87d04b235b29f7774d226b120da2ca4e60b9e6fdf6747daef7f13f218b3517a"}, + {file = "fastavro-1.9.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b525c363e267ed11810aaad8fbdbd1c3bd8837d05f7360977d72a65ab8c6e1fa"}, + {file = "fastavro-1.9.7-cp310-cp310-win_amd64.whl", hash = "sha256:6312fa99deecc319820216b5e1b1bd2d7ebb7d6f221373c74acfddaee64e8e60"}, + {file = "fastavro-1.9.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ec8499dc276c2d2ef0a68c0f1ad11782b2b956a921790a36bf4c18df2b8d4020"}, + {file = "fastavro-1.9.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d9d96f98052615ab465c63ba8b76ed59baf2e3341b7b169058db104cbe2aa0"}, + {file = "fastavro-1.9.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:919f3549e07a8a8645a2146f23905955c35264ac809f6c2ac18142bc5b9b6022"}, + {file = "fastavro-1.9.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9de1fa832a4d9016724cd6facab8034dc90d820b71a5d57c7e9830ffe90f31e4"}, + {file = "fastavro-1.9.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:1d09227d1f48f13281bd5ceac958650805aef9a4ef4f95810128c1f9be1df736"}, + {file = "fastavro-1.9.7-cp311-cp311-win_amd64.whl", hash = "sha256:2db993ae6cdc63e25eadf9f93c9e8036f9b097a3e61d19dca42536dcc5c4d8b3"}, + {file = "fastavro-1.9.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4e1289b731214a7315884c74b2ec058b6e84380ce9b18b8af5d387e64b18fc44"}, + {file = "fastavro-1.9.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eac69666270a76a3a1d0444f39752061195e79e146271a568777048ffbd91a27"}, + {file = "fastavro-1.9.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9be089be8c00f68e343bbc64ca6d9a13e5e5b0ba8aa52bcb231a762484fb270e"}, + {file = "fastavro-1.9.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d576eccfd60a18ffa028259500df67d338b93562c6700e10ef68bbd88e499731"}, + {file = "fastavro-1.9.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ee9bf23c157bd7dcc91ea2c700fa3bd924d9ec198bb428ff0b47fa37fe160659"}, + {file = "fastavro-1.9.7-cp312-cp312-win_amd64.whl", hash = "sha256:b6b2ccdc78f6afc18c52e403ee68c00478da12142815c1bd8a00973138a166d0"}, + {file = "fastavro-1.9.7-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:7313def3aea3dacface0a8b83f6d66e49a311149aa925c89184a06c1ef99785d"}, + {file = "fastavro-1.9.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:536f5644737ad21d18af97d909dba099b9e7118c237be7e4bd087c7abde7e4f0"}, + {file = "fastavro-1.9.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2af559f30383b79cf7d020a6b644c42ffaed3595f775fe8f3d7f80b1c43dfdc5"}, + {file = "fastavro-1.9.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:edc28ab305e3c424de5ac5eb87b48d1e07eddb6aa08ef5948fcda33cc4d995ce"}, + {file = "fastavro-1.9.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ec2e96bdabd58427fe683329b3d79f42c7b4f4ff6b3644664a345a655ac2c0a1"}, + {file = "fastavro-1.9.7-cp38-cp38-win_amd64.whl", hash = "sha256:3b683693c8a85ede496ebebe115be5d7870c150986e34a0442a20d88d7771224"}, + {file = "fastavro-1.9.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:58f76a5c9a312fbd37b84e49d08eb23094d36e10d43bc5df5187bc04af463feb"}, + {file = "fastavro-1.9.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56304401d2f4f69f5b498bdd1552c13ef9a644d522d5de0dc1d789cf82f47f73"}, + {file = "fastavro-1.9.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fcce036c6aa06269fc6a0428050fcb6255189997f5e1a728fc461e8b9d3e26b"}, + {file = "fastavro-1.9.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:17de68aae8c2525f5631d80f2b447a53395cdc49134f51b0329a5497277fc2d2"}, + {file = "fastavro-1.9.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7c911366c625d0a997eafe0aa83ffbc6fd00d8fd4543cb39a97c6f3b8120ea87"}, + {file = "fastavro-1.9.7-cp39-cp39-win_amd64.whl", hash = "sha256:912283ed48578a103f523817fdf0c19b1755cea9b4a6387b73c79ecb8f8f84fc"}, + {file = "fastavro-1.9.7.tar.gz", hash = "sha256:13e11c6cb28626da85290933027cd419ce3f9ab8e45410ef24ce6b89d20a1f6c"}, ] [package.extras] @@ -2298,19 +2508,19 @@ sgmllib3k = "*" [[package]] name = "filelock" -version = "3.15.4" +version = "3.16.1" description = "A platform independent file lock." 
optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, - {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, + {file = "filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0"}, + {file = "filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] -typing = ["typing-extensions (>=4.8)"] +docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4.1)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] +typing = ["typing-extensions (>=4.12.2)"] [[package]] name = "filetype" @@ -2323,6 +2533,37 @@ files = [ {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, ] +[[package]] +name = "fire" +version = "0.7.0" +description = "A library for automatically generating command line interfaces." +optional = false +python-versions = "*" +files = [ + {file = "fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf"}, +] + +[package.dependencies] +termcolor = "*" + +[[package]] +name = "flasgger" +version = "0.9.7.1" +description = "Extract swagger specs from your flask project" +optional = false +python-versions = "*" +files = [ + {file = "flasgger-0.9.7.1.tar.gz", hash = "sha256:ca098e10bfbb12f047acc6299cc70a33851943a746e550d86e65e60d4df245fb"}, +] + +[package.dependencies] +Flask = ">=0.10" +jsonschema = ">=3.0.1" +mistune = "*" +packaging = "*" +PyYAML = ">=3.0" +six = ">=1.10.0" + [[package]] name = "flask" version = "3.0.3" @@ -2363,13 +2604,13 @@ flask = "*" [[package]] name = "flask-cors" -version = "4.0.1" +version = "4.0.2" description = "A Flask extension adding a decorator for CORS support" optional = false python-versions = "*" files = [ - {file = "Flask_Cors-4.0.1-py2.py3-none-any.whl", hash = "sha256:f2a704e4458665580c074b714c4627dd5a306b333deb9074d0b1794dfa2fb677"}, - {file = "flask_cors-4.0.1.tar.gz", hash = "sha256:eeb69b342142fdbf4766ad99357a7f3876a2ceb77689dc10ff912aac06c389e4"}, + {file = "Flask_Cors-4.0.2-py2.py3-none-any.whl", hash = "sha256:38364faf1a7a5d0a55bd1d2e2f83ee9e359039182f5e6a029557e1f56d92c09a"}, + {file = "flask_cors-4.0.2.tar.gz", hash = "sha256:493b98e2d1e2f1a4720a7af25693ef2fe32fbafec09a2f72c59f3e475eda61d2"}, ] [package.dependencies] @@ -2470,55 +2711,74 @@ files = [ {file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"}, ] +[[package]] +name = "fontmeta" +version = "1.6.1" +description = "An Utility to get ttf/otf font metadata" +optional = false +python-versions = "*" +files = [ + {file = "fontmeta-1.6.1.tar.gz", hash = "sha256:837e5bc4da879394b41bda1428a8a480eb7c4e993799a93cfb582bab771a9c24"}, +] + +[package.dependencies] +fonttools = "*" + [[package]] name = "fonttools" 
-version = "4.53.1" +version = "4.54.1" description = "Tools to manipulate font files" optional = false python-versions = ">=3.8" files = [ - {file = "fonttools-4.53.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397"}, - {file = "fonttools-4.53.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8bf06b94694251861ba7fdeea15c8ec0967f84c3d4143ae9daf42bbc7717fe3"}, - {file = "fonttools-4.53.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b96cd370a61f4d083c9c0053bf634279b094308d52fdc2dd9a22d8372fdd590d"}, - {file = "fonttools-4.53.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1c7c5aa18dd3b17995898b4a9b5929d69ef6ae2af5b96d585ff4005033d82f0"}, - {file = "fonttools-4.53.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e013aae589c1c12505da64a7d8d023e584987e51e62006e1bb30d72f26522c41"}, - {file = "fonttools-4.53.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9efd176f874cb6402e607e4cc9b4a9cd584d82fc34a4b0c811970b32ba62501f"}, - {file = "fonttools-4.53.1-cp310-cp310-win32.whl", hash = "sha256:c8696544c964500aa9439efb6761947393b70b17ef4e82d73277413f291260a4"}, - {file = "fonttools-4.53.1-cp310-cp310-win_amd64.whl", hash = "sha256:8959a59de5af6d2bec27489e98ef25a397cfa1774b375d5787509c06659b3671"}, - {file = "fonttools-4.53.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:da33440b1413bad53a8674393c5d29ce64d8c1a15ef8a77c642ffd900d07bfe1"}, - {file = "fonttools-4.53.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5ff7e5e9bad94e3a70c5cd2fa27f20b9bb9385e10cddab567b85ce5d306ea923"}, - {file = "fonttools-4.53.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6e7170d675d12eac12ad1a981d90f118c06cf680b42a2d74c6c931e54b50719"}, - {file = "fonttools-4.53.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bee32ea8765e859670c4447b0817514ca79054463b6b79784b08a8df3a4d78e3"}, - {file = "fonttools-4.53.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6e08f572625a1ee682115223eabebc4c6a2035a6917eac6f60350aba297ccadb"}, - {file = "fonttools-4.53.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b21952c092ffd827504de7e66b62aba26fdb5f9d1e435c52477e6486e9d128b2"}, - {file = "fonttools-4.53.1-cp311-cp311-win32.whl", hash = "sha256:9dfdae43b7996af46ff9da520998a32b105c7f098aeea06b2226b30e74fbba88"}, - {file = "fonttools-4.53.1-cp311-cp311-win_amd64.whl", hash = "sha256:d4d0096cb1ac7a77b3b41cd78c9b6bc4a400550e21dc7a92f2b5ab53ed74eb02"}, - {file = "fonttools-4.53.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:d92d3c2a1b39631a6131c2fa25b5406855f97969b068e7e08413325bc0afba58"}, - {file = "fonttools-4.53.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3b3c8ebafbee8d9002bd8f1195d09ed2bd9ff134ddec37ee8f6a6375e6a4f0e8"}, - {file = "fonttools-4.53.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32f029c095ad66c425b0ee85553d0dc326d45d7059dbc227330fc29b43e8ba60"}, - {file = "fonttools-4.53.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10f5e6c3510b79ea27bb1ebfcc67048cde9ec67afa87c7dd7efa5c700491ac7f"}, - {file = "fonttools-4.53.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f677ce218976496a587ab17140da141557beb91d2a5c1a14212c994093f2eae2"}, - {file = "fonttools-4.53.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:9e6ceba2a01b448e36754983d376064730690401da1dd104ddb543519470a15f"}, - {file = "fonttools-4.53.1-cp312-cp312-win32.whl", hash = "sha256:791b31ebbc05197d7aa096bbc7bd76d591f05905d2fd908bf103af4488e60670"}, - {file = "fonttools-4.53.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ed170b5e17da0264b9f6fae86073be3db15fa1bd74061c8331022bca6d09bab"}, - {file = "fonttools-4.53.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c818c058404eb2bba05e728d38049438afd649e3c409796723dfc17cd3f08749"}, - {file = "fonttools-4.53.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:651390c3b26b0c7d1f4407cad281ee7a5a85a31a110cbac5269de72a51551ba2"}, - {file = "fonttools-4.53.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e54f1bba2f655924c1138bbc7fa91abd61f45c68bd65ab5ed985942712864bbb"}, - {file = "fonttools-4.53.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9cd19cf4fe0595ebdd1d4915882b9440c3a6d30b008f3cc7587c1da7b95be5f"}, - {file = "fonttools-4.53.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2af40ae9cdcb204fc1d8f26b190aa16534fcd4f0df756268df674a270eab575d"}, - {file = "fonttools-4.53.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:35250099b0cfb32d799fb5d6c651220a642fe2e3c7d2560490e6f1d3f9ae9169"}, - {file = "fonttools-4.53.1-cp38-cp38-win32.whl", hash = "sha256:f08df60fbd8d289152079a65da4e66a447efc1d5d5a4d3f299cdd39e3b2e4a7d"}, - {file = "fonttools-4.53.1-cp38-cp38-win_amd64.whl", hash = "sha256:7b6b35e52ddc8fb0db562133894e6ef5b4e54e1283dff606fda3eed938c36fc8"}, - {file = "fonttools-4.53.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75a157d8d26c06e64ace9df037ee93a4938a4606a38cb7ffaf6635e60e253b7a"}, - {file = "fonttools-4.53.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4824c198f714ab5559c5be10fd1adf876712aa7989882a4ec887bf1ef3e00e31"}, - {file = "fonttools-4.53.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:becc5d7cb89c7b7afa8321b6bb3dbee0eec2b57855c90b3e9bf5fb816671fa7c"}, - {file = "fonttools-4.53.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84ec3fb43befb54be490147b4a922b5314e16372a643004f182babee9f9c3407"}, - {file = "fonttools-4.53.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:73379d3ffdeecb376640cd8ed03e9d2d0e568c9d1a4e9b16504a834ebadc2dfb"}, - {file = "fonttools-4.53.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122"}, - {file = "fonttools-4.53.1-cp39-cp39-win32.whl", hash = "sha256:aae7bd54187e8bf7fd69f8ab87b2885253d3575163ad4d669a262fe97f0136cb"}, - {file = "fonttools-4.53.1-cp39-cp39-win_amd64.whl", hash = "sha256:e5b708073ea3d684235648786f5f6153a48dc8762cdfe5563c57e80787c29fbb"}, - {file = "fonttools-4.53.1-py3-none-any.whl", hash = "sha256:f1f8758a2ad110bd6432203a344269f445a2907dc24ef6bccfd0ac4e14e0d71d"}, - {file = "fonttools-4.53.1.tar.gz", hash = "sha256:e128778a8e9bc11159ce5447f76766cefbd876f44bd79aff030287254e4752c4"}, + {file = "fonttools-4.54.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7ed7ee041ff7b34cc62f07545e55e1468808691dddfd315d51dd82a6b37ddef2"}, + {file = "fonttools-4.54.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41bb0b250c8132b2fcac148e2e9198e62ff06f3cc472065dff839327945c5882"}, + {file = "fonttools-4.54.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7965af9b67dd546e52afcf2e38641b5be956d68c425bef2158e95af11d229f10"}, + {file = 
"fonttools-4.54.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:278913a168f90d53378c20c23b80f4e599dca62fbffae4cc620c8eed476b723e"}, + {file = "fonttools-4.54.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0e88e3018ac809b9662615072dcd6b84dca4c2d991c6d66e1970a112503bba7e"}, + {file = "fonttools-4.54.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4aa4817f0031206e637d1e685251ac61be64d1adef111060df84fdcbc6ab6c44"}, + {file = "fonttools-4.54.1-cp310-cp310-win32.whl", hash = "sha256:7e3b7d44e18c085fd8c16dcc6f1ad6c61b71ff463636fcb13df7b1b818bd0c02"}, + {file = "fonttools-4.54.1-cp310-cp310-win_amd64.whl", hash = "sha256:dd9cc95b8d6e27d01e1e1f1fae8559ef3c02c76317da650a19047f249acd519d"}, + {file = "fonttools-4.54.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5419771b64248484299fa77689d4f3aeed643ea6630b2ea750eeab219588ba20"}, + {file = "fonttools-4.54.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:301540e89cf4ce89d462eb23a89464fef50915255ece765d10eee8b2bf9d75b2"}, + {file = "fonttools-4.54.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76ae5091547e74e7efecc3cbf8e75200bc92daaeb88e5433c5e3e95ea8ce5aa7"}, + {file = "fonttools-4.54.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82834962b3d7c5ca98cb56001c33cf20eb110ecf442725dc5fdf36d16ed1ab07"}, + {file = "fonttools-4.54.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d26732ae002cc3d2ecab04897bb02ae3f11f06dd7575d1df46acd2f7c012a8d8"}, + {file = "fonttools-4.54.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:58974b4987b2a71ee08ade1e7f47f410c367cdfc5a94fabd599c88165f56213a"}, + {file = "fonttools-4.54.1-cp311-cp311-win32.whl", hash = "sha256:ab774fa225238986218a463f3fe151e04d8c25d7de09df7f0f5fce27b1243dbc"}, + {file = "fonttools-4.54.1-cp311-cp311-win_amd64.whl", hash = "sha256:07e005dc454eee1cc60105d6a29593459a06321c21897f769a281ff2d08939f6"}, + {file = "fonttools-4.54.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:54471032f7cb5fca694b5f1a0aaeba4af6e10ae989df408e0216f7fd6cdc405d"}, + {file = "fonttools-4.54.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fa92cb248e573daab8d032919623cc309c005086d743afb014c836636166f08"}, + {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a911591200114969befa7f2cb74ac148bce5a91df5645443371aba6d222e263"}, + {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93d458c8a6a354dc8b48fc78d66d2a8a90b941f7fec30e94c7ad9982b1fa6bab"}, + {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5eb2474a7c5be8a5331146758debb2669bf5635c021aee00fd7c353558fc659d"}, + {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c9c563351ddc230725c4bdf7d9e1e92cbe6ae8553942bd1fb2b2ff0884e8b714"}, + {file = "fonttools-4.54.1-cp312-cp312-win32.whl", hash = "sha256:fdb062893fd6d47b527d39346e0c5578b7957dcea6d6a3b6794569370013d9ac"}, + {file = "fonttools-4.54.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4564cf40cebcb53f3dc825e85910bf54835e8a8b6880d59e5159f0f325e637e"}, + {file = "fonttools-4.54.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6e37561751b017cf5c40fce0d90fd9e8274716de327ec4ffb0df957160be3bff"}, + {file = "fonttools-4.54.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:357cacb988a18aace66e5e55fe1247f2ee706e01debc4b1a20d77400354cddeb"}, + {file = 
"fonttools-4.54.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e953cc0bddc2beaf3a3c3b5dd9ab7554677da72dfaf46951e193c9653e515a"}, + {file = "fonttools-4.54.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:58d29b9a294573d8319f16f2f79e42428ba9b6480442fa1836e4eb89c4d9d61c"}, + {file = "fonttools-4.54.1-cp313-cp313-win32.whl", hash = "sha256:9ef1b167e22709b46bf8168368b7b5d3efeaaa746c6d39661c1b4405b6352e58"}, + {file = "fonttools-4.54.1-cp313-cp313-win_amd64.whl", hash = "sha256:262705b1663f18c04250bd1242b0515d3bbae177bee7752be67c979b7d47f43d"}, + {file = "fonttools-4.54.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ed2f80ca07025551636c555dec2b755dd005e2ea8fbeb99fc5cdff319b70b23b"}, + {file = "fonttools-4.54.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9dc080e5a1c3b2656caff2ac2633d009b3a9ff7b5e93d0452f40cd76d3da3b3c"}, + {file = "fonttools-4.54.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d152d1be65652fc65e695e5619e0aa0982295a95a9b29b52b85775243c06556"}, + {file = "fonttools-4.54.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8583e563df41fdecef31b793b4dd3af8a9caa03397be648945ad32717a92885b"}, + {file = "fonttools-4.54.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0d1d353ef198c422515a3e974a1e8d5b304cd54a4c2eebcae708e37cd9eeffb1"}, + {file = "fonttools-4.54.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fda582236fee135d4daeca056c8c88ec5f6f6d88a004a79b84a02547c8f57386"}, + {file = "fonttools-4.54.1-cp38-cp38-win32.whl", hash = "sha256:e7d82b9e56716ed32574ee106cabca80992e6bbdcf25a88d97d21f73a0aae664"}, + {file = "fonttools-4.54.1-cp38-cp38-win_amd64.whl", hash = "sha256:ada215fd079e23e060157aab12eba0d66704316547f334eee9ff26f8c0d7b8ab"}, + {file = "fonttools-4.54.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f5b8a096e649768c2f4233f947cf9737f8dbf8728b90e2771e2497c6e3d21d13"}, + {file = "fonttools-4.54.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4e10d2e0a12e18f4e2dd031e1bf7c3d7017be5c8dbe524d07706179f355c5dac"}, + {file = "fonttools-4.54.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31c32d7d4b0958600eac75eaf524b7b7cb68d3a8c196635252b7a2c30d80e986"}, + {file = "fonttools-4.54.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c39287f5c8f4a0c5a55daf9eaf9ccd223ea59eed3f6d467133cc727d7b943a55"}, + {file = "fonttools-4.54.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a7a310c6e0471602fe3bf8efaf193d396ea561486aeaa7adc1f132e02d30c4b9"}, + {file = "fonttools-4.54.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d3b659d1029946f4ff9b6183984578041b520ce0f8fb7078bb37ec7445806b33"}, + {file = "fonttools-4.54.1-cp39-cp39-win32.whl", hash = "sha256:e96bc94c8cda58f577277d4a71f51c8e2129b8b36fd05adece6320dd3d57de8a"}, + {file = "fonttools-4.54.1-cp39-cp39-win_amd64.whl", hash = "sha256:e8a4b261c1ef91e7188a30571be6ad98d1c6d9fa2427244c545e2fa0a2494dd7"}, + {file = "fonttools-4.54.1-py3-none-any.whl", hash = "sha256:37cddd62d83dc4f72f7c3f3c2bcf2697e89a30efb152079896544a93907733bd"}, + {file = "fonttools-4.54.1.tar.gz", hash = "sha256:957f669d4922f92c171ba01bef7f29410668db09f6c02111e22b2bce446f3285"}, ] [package.extras] @@ -2537,141 +2797,162 @@ woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] [[package]] name = "frozendict" -version = "2.4.4" +version = "2.4.6" description = "A simple immutable dictionary" optional = false python-versions = 
">=3.6" files = [ - {file = "frozendict-2.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4a59578d47b3949437519b5c39a016a6116b9e787bb19289e333faae81462e59"}, - {file = "frozendict-2.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12a342e439aef28ccec533f0253ea53d75fe9102bd6ea928ff530e76eac38906"}, - {file = "frozendict-2.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f79c26dff10ce11dad3b3627c89bb2e87b9dd5958c2b24325f16a23019b8b94"}, - {file = "frozendict-2.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:2bd009cf4fc47972838a91e9b83654dc9a095dc4f2bb3a37c3f3124c8a364543"}, - {file = "frozendict-2.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:87ebcde21565a14fe039672c25550060d6f6d88cf1f339beac094c3b10004eb0"}, - {file = "frozendict-2.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:fefeb700bc7eb8b4c2dc48704e4221860d254c8989fb53488540bc44e44a1ac2"}, - {file = "frozendict-2.4.4-cp310-cp310-win_arm64.whl", hash = "sha256:4297d694eb600efa429769125a6f910ec02b85606f22f178bafbee309e7d3ec7"}, - {file = "frozendict-2.4.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:812ab17522ba13637826e65454115a914c2da538356e85f43ecea069813e4b33"}, - {file = "frozendict-2.4.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7fee9420475bb6ff357000092aa9990c2f6182b2bab15764330f4ad7de2eae49"}, - {file = "frozendict-2.4.4-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:3148062675536724502c6344d7c485dd4667fdf7980ca9bd05e338ccc0c4471e"}, - {file = "frozendict-2.4.4-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:78c94991944dd33c5376f720228e5b252ee67faf3bac50ef381adc9e51e90d9d"}, - {file = "frozendict-2.4.4-cp36-cp36m-win_amd64.whl", hash = "sha256:1697793b5f62b416c0fc1d94638ec91ed3aa4ab277f6affa3a95216ecb3af170"}, - {file = "frozendict-2.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:199a4d32194f3afed6258de7e317054155bc9519252b568d9cfffde7e4d834e5"}, - {file = "frozendict-2.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85375ec6e979e6373bffb4f54576a68bf7497c350861d20686ccae38aab69c0a"}, - {file = "frozendict-2.4.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:2d8536e068d6bf281f23fa835ac07747fb0f8851879dd189e9709f9567408b4d"}, - {file = "frozendict-2.4.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:259528ba6b56fa051bc996f1c4d8b57e30d6dd3bc2f27441891b04babc4b5e73"}, - {file = "frozendict-2.4.4-cp37-cp37m-win_amd64.whl", hash = "sha256:07c3a5dee8bbb84cba770e273cdbf2c87c8e035903af8f781292d72583416801"}, - {file = "frozendict-2.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6874fec816b37b6eb5795b00e0574cba261bf59723e2de607a195d5edaff0786"}, - {file = "frozendict-2.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8f92425686323a950337da4b75b4c17a3327b831df8c881df24038d560640d4"}, - {file = "frozendict-2.4.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d58d9a8d9e49662c6dafbea5e641f97decdb3d6ccd76e55e79818415362ba25"}, - {file = "frozendict-2.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:93a7b19afb429cbf99d56faf436b45ef2fa8fe9aca89c49eb1610c3bd85f1760"}, - {file = "frozendict-2.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2b70b431e3a72d410a2cdf1497b3aba2f553635e0c0f657ce311d841bf8273b6"}, - {file = "frozendict-2.4.4-cp38-cp38-win_amd64.whl", hash = 
"sha256:e1b941132d79ce72d562a13341d38fc217bc1ee24d8c35a20d754e79ff99e038"}, - {file = "frozendict-2.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:dc2228874eacae390e63fd4f2bb513b3144066a977dc192163c9f6c7f6de6474"}, - {file = "frozendict-2.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63aa49f1919af7d45fb8fd5dec4c0859bc09f46880bd6297c79bb2db2969b63d"}, - {file = "frozendict-2.4.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6bf9260018d653f3cab9bd147bd8592bf98a5c6e338be0491ced3c196c034a3"}, - {file = "frozendict-2.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6eb716e6a6d693c03b1d53280a1947716129f5ef9bcdd061db5c17dea44b80fe"}, - {file = "frozendict-2.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d13b4310db337f4d2103867c5a05090b22bc4d50ca842093779ef541ea9c9eea"}, - {file = "frozendict-2.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:b3b967d5065872e27b06f785a80c0ed0a45d1f7c9b85223da05358e734d858ca"}, - {file = "frozendict-2.4.4-cp39-cp39-win_arm64.whl", hash = "sha256:4ae8d05c8d0b6134bfb6bfb369d5fa0c4df21eabb5ca7f645af95fdc6689678e"}, - {file = "frozendict-2.4.4-py311-none-any.whl", hash = "sha256:705efca8d74d3facbb6ace80ab3afdd28eb8a237bfb4063ed89996b024bc443d"}, - {file = "frozendict-2.4.4-py312-none-any.whl", hash = "sha256:d9647563e76adb05b7cde2172403123380871360a114f546b4ae1704510801e5"}, - {file = "frozendict-2.4.4.tar.gz", hash = "sha256:3f7c031b26e4ee6a3f786ceb5e3abf1181c4ade92dce1f847da26ea2c96008c7"}, + {file = "frozendict-2.4.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c3a05c0a50cab96b4bb0ea25aa752efbfceed5ccb24c007612bc63e51299336f"}, + {file = "frozendict-2.4.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f5b94d5b07c00986f9e37a38dd83c13f5fe3bf3f1ccc8e88edea8fe15d6cd88c"}, + {file = "frozendict-2.4.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4c789fd70879ccb6289a603cdebdc4953e7e5dea047d30c1b180529b28257b5"}, + {file = "frozendict-2.4.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da6a10164c8a50b34b9ab508a9420df38f4edf286b9ca7b7df8a91767baecb34"}, + {file = "frozendict-2.4.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9a8a43036754a941601635ea9c788ebd7a7efbed2becba01b54a887b41b175b9"}, + {file = "frozendict-2.4.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c9905dcf7aa659e6a11b8051114c9fa76dfde3a6e50e6dc129d5aece75b449a2"}, + {file = "frozendict-2.4.6-cp310-cp310-win_amd64.whl", hash = "sha256:323f1b674a2cc18f86ab81698e22aba8145d7a755e0ac2cccf142ee2db58620d"}, + {file = "frozendict-2.4.6-cp310-cp310-win_arm64.whl", hash = "sha256:eabd21d8e5db0c58b60d26b4bb9839cac13132e88277e1376970172a85ee04b3"}, + {file = "frozendict-2.4.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:eddabeb769fab1e122d3a6872982c78179b5bcc909fdc769f3cf1964f55a6d20"}, + {file = "frozendict-2.4.6-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:377a65be0a700188fc21e669c07de60f4f6d35fae8071c292b7df04776a1c27b"}, + {file = "frozendict-2.4.6-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce1e9217b85eec6ba9560d520d5089c82dbb15f977906eb345d81459723dd7e3"}, + {file = "frozendict-2.4.6-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:7291abacf51798d5ffe632771a69c14fb423ab98d63c4ccd1aa382619afe2f89"}, + {file = "frozendict-2.4.6-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:e72fb86e48811957d66ffb3e95580af7b1af1e6fbd760ad63d7bd79b2c9a07f8"}, + {file = 
"frozendict-2.4.6-cp36-cp36m-win_amd64.whl", hash = "sha256:622301b1c29c4f9bba633667d592a3a2b093cb408ba3ce578b8901ace3931ef3"}, + {file = "frozendict-2.4.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a4e3737cb99ed03200cd303bdcd5514c9f34b29ee48f405c1184141bd68611c9"}, + {file = "frozendict-2.4.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:49ffaf09241bc1417daa19362a2241a4aa435f758fd4375c39ce9790443a39cd"}, + {file = "frozendict-2.4.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d69418479bfb834ba75b0e764f058af46ceee3d655deb6a0dd0c0c1a5e82f09"}, + {file = "frozendict-2.4.6-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:c131f10c4d3906866454c4e89b87a7e0027d533cce8f4652aa5255112c4d6677"}, + {file = "frozendict-2.4.6-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:fc67cbb3c96af7a798fab53d52589752c1673027e516b702ab355510ddf6bdff"}, + {file = "frozendict-2.4.6-cp37-cp37m-win_amd64.whl", hash = "sha256:7730f8ebe791d147a1586cbf6a42629351d4597773317002181b66a2da0d509e"}, + {file = "frozendict-2.4.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:807862e14b0e9665042458fde692c4431d660c4219b9bb240817f5b918182222"}, + {file = "frozendict-2.4.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9647c74efe3d845faa666d4853cfeabbaee403b53270cabfc635b321f770e6b8"}, + {file = "frozendict-2.4.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:665fad3f0f815aa41294e561d98dbedba4b483b3968e7e8cab7d728d64b96e33"}, + {file = "frozendict-2.4.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f42e6b75254ea2afe428ad6d095b62f95a7ae6d4f8272f0bd44a25dddd20f67"}, + {file = "frozendict-2.4.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:02331541611f3897f260900a1815b63389654951126e6e65545e529b63c08361"}, + {file = "frozendict-2.4.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:18d50a2598350b89189da9150058191f55057581e40533e470db46c942373acf"}, + {file = "frozendict-2.4.6-cp38-cp38-win_amd64.whl", hash = "sha256:1b4a3f8f6dd51bee74a50995c39b5a606b612847862203dd5483b9cd91b0d36a"}, + {file = "frozendict-2.4.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a76cee5c4be2a5d1ff063188232fffcce05dde6fd5edd6afe7b75b247526490e"}, + {file = "frozendict-2.4.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba5ef7328706db857a2bdb2c2a17b4cd37c32a19c017cff1bb7eeebc86b0f411"}, + {file = "frozendict-2.4.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:669237c571856be575eca28a69e92a3d18f8490511eff184937283dc6093bd67"}, + {file = "frozendict-2.4.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0aaa11e7c472150efe65adbcd6c17ac0f586896096ab3963775e1c5c58ac0098"}, + {file = "frozendict-2.4.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b8f2829048f29fe115da4a60409be2130e69402e29029339663fac39c90e6e2b"}, + {file = "frozendict-2.4.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:94321e646cc39bebc66954a31edd1847d3a2a3483cf52ff051cd0996e7db07db"}, + {file = "frozendict-2.4.6-cp39-cp39-win_amd64.whl", hash = "sha256:74b6b26c15dddfefddeb89813e455b00ebf78d0a3662b89506b4d55c6445a9f4"}, + {file = "frozendict-2.4.6-cp39-cp39-win_arm64.whl", hash = "sha256:7088102345d1606450bd1801a61139bbaa2cb0d805b9b692f8d81918ea835da6"}, + {file = "frozendict-2.4.6-py311-none-any.whl", hash = "sha256:d065db6a44db2e2375c23eac816f1a022feb2fa98cbb50df44a9e83700accbea"}, + {file = "frozendict-2.4.6-py312-none-any.whl", hash = 
"sha256:49344abe90fb75f0f9fdefe6d4ef6d4894e640fadab71f11009d52ad97f370b9"}, + {file = "frozendict-2.4.6-py313-none-any.whl", hash = "sha256:7134a2bb95d4a16556bb5f2b9736dceb6ea848fa5b6f3f6c2d6dba93b44b4757"}, + {file = "frozendict-2.4.6.tar.gz", hash = "sha256:df7cd16470fbd26fc4969a208efadc46319334eb97def1ddf48919b351192b8e"}, ] [[package]] name = "frozenlist" -version = "1.4.1" +version = "1.5.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false python-versions = ">=3.8" files = [ - {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, - {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, - {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, - {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, - {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, - {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, - {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, - {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, - {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, - {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, - {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, - {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, - {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, - {file = 
"frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, - {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, - {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, - {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, - {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, - {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, - {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, - {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, - {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, - {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, - {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, - {file = 
"frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, - {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, - {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, - {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, - {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, + {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"}, + {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"}, + {file = "frozenlist-1.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:15538c0cbf0e4fa11d1e3a71f823524b0c46299aed6e10ebb4c2089abd8c3bec"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e79225373c317ff1e35f210dd5f1344ff31066ba8067c307ab60254cd3a78ad5"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9272fa73ca71266702c4c3e2d4a28553ea03418e591e377a03b8e3659d94fa76"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:498524025a5b8ba81695761d78c8dd7382ac0b052f34e66939c42df860b8ff17"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92b5278ed9d50fe610185ecd23c55d8b307d75ca18e94c0e7de328089ac5dcba"}, + {file = "frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f3c8c1dacd037df16e85227bac13cca58c30da836c6f936ba1df0c05d046d8d"}, + 
{file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f2ac49a9bedb996086057b75bf93538240538c6d9b38e57c82d51f75a73409d2"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e66cc454f97053b79c2ab09c17fbe3c825ea6b4de20baf1be28919460dd7877f"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5a3ba5f9a0dfed20337d3e966dc359784c9f96503674c2faf015f7fe8e96798c"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6321899477db90bdeb9299ac3627a6a53c7399c8cd58d25da094007402b039ab"}, + {file = "frozenlist-1.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:76e4753701248476e6286f2ef492af900ea67d9706a0155335a40ea21bf3b2f5"}, + {file = "frozenlist-1.5.0-cp310-cp310-win32.whl", hash = "sha256:977701c081c0241d0955c9586ffdd9ce44f7a7795df39b9151cd9a6fd0ce4cfb"}, + {file = "frozenlist-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:189f03b53e64144f90990d29a27ec4f7997d91ed3d01b51fa39d2dbe77540fd4"}, + {file = "frozenlist-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fd74520371c3c4175142d02a976aee0b4cb4a7cc912a60586ffd8d5929979b30"}, + {file = "frozenlist-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2f3f7a0fbc219fb4455264cae4d9f01ad41ae6ee8524500f381de64ffaa077d5"}, + {file = "frozenlist-1.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f47c9c9028f55a04ac254346e92977bf0f166c483c74b4232bee19a6697e4778"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0996c66760924da6e88922756d99b47512a71cfd45215f3570bf1e0b694c206a"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2fe128eb4edeabe11896cb6af88fca5346059f6c8d807e3b910069f39157869"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a8ea951bbb6cacd492e3948b8da8c502a3f814f5d20935aae74b5df2b19cf3d"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de537c11e4aa01d37db0d403b57bd6f0546e71a82347a97c6a9f0dcc532b3a45"}, + {file = "frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c2623347b933fcb9095841f1cc5d4ff0b278addd743e0e966cb3d460278840d"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cee6798eaf8b1416ef6909b06f7dc04b60755206bddc599f52232606e18179d3"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f5f9da7f5dbc00a604fe74aa02ae7c98bcede8a3b8b9666f9f86fc13993bc71a"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:90646abbc7a5d5c7c19461d2e3eeb76eb0b204919e6ece342feb6032c9325ae9"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:bdac3c7d9b705d253b2ce370fde941836a5f8b3c5c2b8fd70940a3ea3af7f4f2"}, + {file = "frozenlist-1.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03d33c2ddbc1816237a67f66336616416e2bbb6beb306e5f890f2eb22b959cdf"}, + {file = "frozenlist-1.5.0-cp311-cp311-win32.whl", hash = "sha256:237f6b23ee0f44066219dae14c70ae38a63f0440ce6750f868ee08775073f942"}, + {file = "frozenlist-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:0cc974cc93d32c42e7b0f6cf242a6bd941c57c61b618e78b6c0a96cb72788c1d"}, + {file = "frozenlist-1.5.0-cp312-cp312-macosx_10_13_universal2.whl", hash = 
"sha256:31115ba75889723431aa9a4e77d5f398f5cf976eea3bdf61749731f62d4a4a21"}, + {file = "frozenlist-1.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7437601c4d89d070eac8323f121fcf25f88674627505334654fd027b091db09d"}, + {file = "frozenlist-1.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7948140d9f8ece1745be806f2bfdf390127cf1a763b925c4a805c603df5e697e"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feeb64bc9bcc6b45c6311c9e9b99406660a9c05ca8a5b30d14a78555088b0b3a"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:683173d371daad49cffb8309779e886e59c2f369430ad28fe715f66d08d4ab1a"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7d57d8f702221405a9d9b40f9da8ac2e4a1a8b5285aac6100f3393675f0a85ee"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c72000fbcc35b129cb09956836c7d7abf78ab5416595e4857d1cae8d6251a6"}, + {file = "frozenlist-1.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000a77d6034fbad9b6bb880f7ec073027908f1b40254b5d6f26210d2dab1240e"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5d7f5a50342475962eb18b740f3beecc685a15b52c91f7d975257e13e029eca9"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:87f724d055eb4785d9be84e9ebf0f24e392ddfad00b3fe036e43f489fafc9039"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6e9080bb2fb195a046e5177f10d9d82b8a204c0736a97a153c2466127de87784"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9b93d7aaa36c966fa42efcaf716e6b3900438632a626fb09c049f6a2f09fc631"}, + {file = "frozenlist-1.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:52ef692a4bc60a6dd57f507429636c2af8b6046db8b31b18dac02cbc8f507f7f"}, + {file = "frozenlist-1.5.0-cp312-cp312-win32.whl", hash = "sha256:29d94c256679247b33a3dc96cce0f93cbc69c23bf75ff715919332fdbb6a32b8"}, + {file = "frozenlist-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:8969190d709e7c48ea386db202d708eb94bdb29207a1f269bab1196ce0dcca1f"}, + {file = "frozenlist-1.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7a1a048f9215c90973402e26c01d1cff8a209e1f1b53f72b95c13db61b00f953"}, + {file = "frozenlist-1.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:dd47a5181ce5fcb463b5d9e17ecfdb02b678cca31280639255ce9d0e5aa67af0"}, + {file = "frozenlist-1.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1431d60b36d15cda188ea222033eec8e0eab488f39a272461f2e6d9e1a8e63c2"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6482a5851f5d72767fbd0e507e80737f9c8646ae7fd303def99bfe813f76cf7f"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44c49271a937625619e862baacbd037a7ef86dd1ee215afc298a417ff3270608"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12f78f98c2f1c2429d42e6a485f433722b0061d5c0b0139efa64f396efb5886b"}, + {file = "frozenlist-1.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce3aa154c452d2467487765e3adc730a8c153af77ad84096bc19ce19a2400840"}, + {file = 
"frozenlist-1.5.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b7dc0c4338e6b8b091e8faf0db3168a37101943e687f373dce00959583f7439"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:45e0896250900b5aa25180f9aec243e84e92ac84bd4a74d9ad4138ef3f5c97de"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:561eb1c9579d495fddb6da8959fd2a1fca2c6d060d4113f5844b433fc02f2641"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:df6e2f325bfee1f49f81aaac97d2aa757c7646534a06f8f577ce184afe2f0a9e"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:140228863501b44b809fb39ec56b5d4071f4d0aa6d216c19cbb08b8c5a7eadb9"}, + {file = "frozenlist-1.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7707a25d6a77f5d27ea7dc7d1fc608aa0a478193823f88511ef5e6b8a48f9d03"}, + {file = "frozenlist-1.5.0-cp313-cp313-win32.whl", hash = "sha256:31a9ac2b38ab9b5a8933b693db4939764ad3f299fcaa931a3e605bc3460e693c"}, + {file = "frozenlist-1.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:11aabdd62b8b9c4b84081a3c246506d1cddd2dd93ff0ad53ede5defec7886b28"}, + {file = "frozenlist-1.5.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:dd94994fc91a6177bfaafd7d9fd951bc8689b0a98168aa26b5f543868548d3ca"}, + {file = "frozenlist-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2d0da8bbec082bf6bf18345b180958775363588678f64998c2b7609e34719b10"}, + {file = "frozenlist-1.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:73f2e31ea8dd7df61a359b731716018c2be196e5bb3b74ddba107f694fbd7604"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:828afae9f17e6de596825cf4228ff28fbdf6065974e5ac1410cecc22f699d2b3"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1577515d35ed5649d52ab4319db757bb881ce3b2b796d7283e6634d99ace307"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2150cc6305a2c2ab33299453e2968611dacb970d2283a14955923062c8d00b10"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a72b7a6e3cd2725eff67cd64c8f13335ee18fc3c7befc05aed043d24c7b9ccb9"}, + {file = "frozenlist-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c16d2fa63e0800723139137d667e1056bee1a1cf7965153d2d104b62855e9b99"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:17dcc32fc7bda7ce5875435003220a457bcfa34ab7924a49a1c19f55b6ee185c"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:97160e245ea33d8609cd2b8fd997c850b56db147a304a262abc2b3be021a9171"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f1e6540b7fa044eee0bb5111ada694cf3dc15f2b0347ca125ee9ca984d5e9e6e"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:91d6c171862df0a6c61479d9724f22efb6109111017c87567cfeb7b5d1449fdf"}, + {file = "frozenlist-1.5.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c1fac3e2ace2eb1052e9f7c7db480818371134410e1f5c55d65e8f3ac6d1407e"}, + {file = "frozenlist-1.5.0-cp38-cp38-win32.whl", hash = "sha256:b97f7b575ab4a8af9b7bc1d2ef7f29d3afee2226bd03ca3875c16451ad5a7723"}, + {file = "frozenlist-1.5.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:374ca2dabdccad8e2a76d40b1d037f5bd16824933bf7bcea3e59c891fd4a0923"}, + {file = "frozenlist-1.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9bbcdfaf4af7ce002694a4e10a0159d5a8d20056a12b05b45cea944a4953f972"}, + {file = "frozenlist-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1893f948bf6681733aaccf36c5232c231e3b5166d607c5fa77773611df6dc336"}, + {file = "frozenlist-1.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2b5e23253bb709ef57a8e95e6ae48daa9ac5f265637529e4ce6b003a37b2621f"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f253985bb515ecd89629db13cb58d702035ecd8cfbca7d7a7e29a0e6d39af5f"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04a5c6babd5e8fb7d3c871dc8b321166b80e41b637c31a995ed844a6139942b6"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9fe0f1c29ba24ba6ff6abf688cb0b7cf1efab6b6aa6adc55441773c252f7411"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226d72559fa19babe2ccd920273e767c96a49b9d3d38badd7c91a0fdeda8ea08"}, + {file = "frozenlist-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15b731db116ab3aedec558573c1a5eec78822b32292fe4f2f0345b7f697745c2"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:366d8f93e3edfe5a918c874702f78faac300209a4d5bf38352b2c1bdc07a766d"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1b96af8c582b94d381a1c1f51ffaedeb77c821c690ea5f01da3d70a487dd0a9b"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c03eff4a41bd4e38415cbed054bbaff4a075b093e2394b6915dca34a40d1e38b"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:50cf5e7ee9b98f22bdecbabf3800ae78ddcc26e4a435515fc72d97903e8488e0"}, + {file = "frozenlist-1.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1e76bfbc72353269c44e0bc2cfe171900fbf7f722ad74c9a7b638052afe6a00c"}, + {file = "frozenlist-1.5.0-cp39-cp39-win32.whl", hash = "sha256:666534d15ba8f0fda3f53969117383d5dc021266b3c1a42c9ec4855e4b58b9d3"}, + {file = "frozenlist-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:5c28f4b5dbef8a0d8aad0d4de24d1e9e981728628afaf4ea0792f5d0939372f0"}, + {file = "frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3"}, + {file = "frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817"}, ] [[package]] name = "fsspec" -version = "2024.6.1" +version = "2024.10.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.6.1-py3-none-any.whl", hash = "sha256:3cb443f8bcd2efb31295a5b9fdb02aee81d8452c80d28f97a6d0959e6cee101e"}, - {file = "fsspec-2024.6.1.tar.gz", hash = "sha256:fad7d7e209dd4c1208e3bbfda706620e0da5142bebbd9c384afb95b07e798e49"}, + {file = "fsspec-2024.10.0-py3-none-any.whl", hash = "sha256:03b9a6785766a4de40368b88906366755e2819e758b83705c88cd7cb5fe81871"}, + {file = "fsspec-2024.10.0.tar.gz", hash = "sha256:eda2d8a4116d4f2429db8550f2457da57279247dd930bb12f821b58391359493"}, ] [package.extras] @@ -2702,6 +2983,17 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe, test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", 
"cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "future" +version = "1.0.0" +description = "Clean single-source support for Python 3 and 2" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "future-1.0.0-py3-none-any.whl", hash = "sha256:929292d34f5872e70396626ef385ec22355a1fae8ad29e1a734c3e43f9fbc216"}, + {file = "future-1.0.0.tar.gz", hash = "sha256:bd2968309307861edae1458a4f8a4f3598c03be43b97521076aebf5d94c07b05"}, +] + [[package]] name = "gevent" version = "23.9.1" @@ -2829,22 +3121,36 @@ files = [ docs = ["sphinx (>=4)", "sphinx-rtd-theme (>=1)"] tests = ["cython", "hypothesis", "mpmath", "pytest", "setuptools"] +[[package]] +name = "google" +version = "3.0.0" +description = "Python bindings to the Google search engine." +optional = false +python-versions = "*" +files = [ + {file = "google-3.0.0-py2.py3-none-any.whl", hash = "sha256:889cf695f84e4ae2c55fbc0cfdaf4c1e729417fa52ab1db0485202ba173e4935"}, + {file = "google-3.0.0.tar.gz", hash = "sha256:143530122ee5130509ad5e989f0512f7cb218b2d4eddbafbad40fd10e8d8ccbe"}, +] + +[package.dependencies] +beautifulsoup4 = "*" + [[package]] name = "google-ai-generativelanguage" -version = "0.6.1" +version = "0.6.9" description = "Google Ai Generativelanguage API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-ai-generativelanguage-0.6.1.tar.gz", hash = "sha256:4abf37000718b20c43f4b90672b3ab8850738b02457efffd11f3184e03272ed2"}, - {file = "google_ai_generativelanguage-0.6.1-py3-none-any.whl", hash = "sha256:d2afc991c47663bdf65bd4aabcd89723550b81ad0a6d0be8bfb0160755da4cf0"}, + {file = "google_ai_generativelanguage-0.6.9-py3-none-any.whl", hash = "sha256:50360cd80015d1a8cc70952e98560f32fa06ddee2e8e9f4b4b98e431dc561e0b"}, + {file = "google_ai_generativelanguage-0.6.9.tar.gz", hash = "sha256:899f1d3a06efa9739f1cd9d2788070178db33c89d4a76f2e8f4da76f649155fa"}, ] [package.dependencies] google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" proto-plus = ">=1.22.3,<2.0.0dev" -protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" [[package]] name = "google-api-core" @@ -2982,30 +3288,30 @@ xai = ["tensorflow (>=2.3.0,<3.0.0dev)"] [[package]] name = "google-cloud-bigquery" -version = "3.25.0" +version = "3.26.0" description = "Google BigQuery API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google-cloud-bigquery-3.25.0.tar.gz", hash = "sha256:5b2aff3205a854481117436836ae1403f11f2594e6810a98886afd57eda28509"}, - {file = "google_cloud_bigquery-3.25.0-py2.py3-none-any.whl", hash = "sha256:7f0c371bc74d2a7fb74dacbc00ac0f90c8c2bec2289b51dd6685a275873b1ce9"}, + {file = 
"google_cloud_bigquery-3.26.0-py2.py3-none-any.whl", hash = "sha256:e0e9ad28afa67a18696e624cbccab284bf2c0a3f6eeb9eeb0426c69b943793a8"}, + {file = "google_cloud_bigquery-3.26.0.tar.gz", hash = "sha256:edbdc788beea659e04c0af7fe4dcd6d9155344b98951a0d5055bd2f15da4ba23"}, ] [package.dependencies] -google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-api-core = {version = ">=2.11.1,<3.0.0dev", extras = ["grpc"]} google-auth = ">=2.14.1,<3.0.0dev" -google-cloud-core = ">=1.6.0,<3.0.0dev" -google-resumable-media = ">=0.6.0,<3.0dev" +google-cloud-core = ">=2.4.1,<3.0.0dev" +google-resumable-media = ">=2.0.0,<3.0dev" packaging = ">=20.0.0" -python-dateutil = ">=2.7.2,<3.0dev" +python-dateutil = ">=2.7.3,<3.0dev" requests = ">=2.21.0,<3.0.0dev" [package.extras] -all = ["Shapely (>=1.8.4,<3.0.0dev)", "db-dtypes (>=0.3.0,<2.0.0dev)", "geopandas (>=0.9.0,<1.0dev)", "google-cloud-bigquery-storage (>=2.6.0,<3.0.0dev)", "grpcio (>=1.47.0,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "importlib-metadata (>=1.0.0)", "ipykernel (>=6.0.0)", "ipython (>=7.23.1,!=8.1.0)", "ipywidgets (>=7.7.0)", "opentelemetry-api (>=1.1.0)", "opentelemetry-instrumentation (>=0.20b0)", "opentelemetry-sdk (>=1.1.0)", "pandas (>=1.1.0)", "proto-plus (>=1.15.0,<2.0.0dev)", "protobuf (>=3.19.5,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev)", "pyarrow (>=3.0.0)", "tqdm (>=4.7.4,<5.0.0dev)"] -bigquery-v2 = ["proto-plus (>=1.15.0,<2.0.0dev)", "protobuf (>=3.19.5,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev)"] +all = ["Shapely (>=1.8.4,<3.0.0dev)", "bigquery-magics (>=0.1.0)", "db-dtypes (>=0.3.0,<2.0.0dev)", "geopandas (>=0.9.0,<1.0dev)", "google-cloud-bigquery-storage (>=2.6.0,<3.0.0dev)", "grpcio (>=1.47.0,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "importlib-metadata (>=1.0.0)", "ipykernel (>=6.0.0)", "ipywidgets (>=7.7.0)", "opentelemetry-api (>=1.1.0)", "opentelemetry-instrumentation (>=0.20b0)", "opentelemetry-sdk (>=1.1.0)", "pandas (>=1.1.0)", "proto-plus (>=1.22.3,<2.0.0dev)", "protobuf (>=3.20.2,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev)", "pyarrow (>=3.0.0)", "tqdm (>=4.7.4,<5.0.0dev)"] +bigquery-v2 = ["proto-plus (>=1.22.3,<2.0.0dev)", "protobuf (>=3.20.2,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev)"] bqstorage = ["google-cloud-bigquery-storage (>=2.6.0,<3.0.0dev)", "grpcio (>=1.47.0,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "pyarrow (>=3.0.0)"] geopandas = ["Shapely (>=1.8.4,<3.0.0dev)", "geopandas (>=0.9.0,<1.0dev)"] -ipython = ["ipykernel (>=6.0.0)", "ipython (>=7.23.1,!=8.1.0)"] +ipython = ["bigquery-magics (>=0.1.0)"] ipywidgets = ["ipykernel (>=6.0.0)", "ipywidgets (>=7.7.0)"] opentelemetry = ["opentelemetry-api (>=1.1.0)", "opentelemetry-instrumentation (>=0.20b0)", "opentelemetry-sdk (>=1.1.0)"] pandas = ["db-dtypes (>=0.3.0,<2.0.0dev)", "importlib-metadata (>=1.0.0)", "pandas (>=1.1.0)", "pyarrow (>=3.0.0)"] @@ -3031,13 +3337,13 @@ grpc = ["grpcio (>=1.38.0,<2.0dev)", "grpcio-status (>=1.38.0,<2.0.dev0)"] [[package]] name = "google-cloud-resource-manager" -version = "1.12.5" +version = "1.13.0" description = "Google Cloud Resource Manager API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google_cloud_resource_manager-1.12.5-py2.py3-none-any.whl", hash = "sha256:2708a718b45c79464b7b21559c701b5c92e6b0b1ab2146d0a256277a623dc175"}, - {file = "google_cloud_resource_manager-1.12.5.tar.gz", hash = 
"sha256:b7af4254401ed4efa3aba3a929cb3ddb803fa6baf91a78485e45583597de5891"}, + {file = "google_cloud_resource_manager-1.13.0-py2.py3-none-any.whl", hash = "sha256:33beb4528c2b7aee7a97ed843710581a7b4a27f3dd1fa41a0bf3359b3d68853f"}, + {file = "google_cloud_resource_manager-1.13.0.tar.gz", hash = "sha256:ae4bf69443f14b37007d4d84150115b0942e8b01650fd7a1fc6ff4dc1760e5c4"}, ] [package.dependencies] @@ -3071,79 +3377,38 @@ protobuf = ["protobuf (<5.0.0dev)"] [[package]] name = "google-crc32c" -version = "1.5.0" +version = "1.6.0" description = "A python wrapper of the C library 'Google CRC32C'" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" files = [ - {file = "google-crc32c-1.5.0.tar.gz", hash = "sha256:89284716bc6a5a415d4eaa11b1726d2d60a0cd12aadf5439828353662ede9dd7"}, - {file = "google_crc32c-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:596d1f98fc70232fcb6590c439f43b350cb762fb5d61ce7b0e9db4539654cc13"}, - {file = "google_crc32c-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:be82c3c8cfb15b30f36768797a640e800513793d6ae1724aaaafe5bf86f8f346"}, - {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:461665ff58895f508e2866824a47bdee72497b091c730071f2b7575d5762ab65"}, - {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2096eddb4e7c7bdae4bd69ad364e55e07b8316653234a56552d9c988bd2d61b"}, - {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:116a7c3c616dd14a3de8c64a965828b197e5f2d121fedd2f8c5585c547e87b02"}, - {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5829b792bf5822fd0a6f6eb34c5f81dd074f01d570ed7f36aa101d6fc7a0a6e4"}, - {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:64e52e2b3970bd891309c113b54cf0e4384762c934d5ae56e283f9a0afcd953e"}, - {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:02ebb8bf46c13e36998aeaad1de9b48f4caf545e91d14041270d9dca767b780c"}, - {file = "google_crc32c-1.5.0-cp310-cp310-win32.whl", hash = "sha256:2e920d506ec85eb4ba50cd4228c2bec05642894d4c73c59b3a2fe20346bd00ee"}, - {file = "google_crc32c-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:07eb3c611ce363c51a933bf6bd7f8e3878a51d124acfc89452a75120bc436289"}, - {file = "google_crc32c-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cae0274952c079886567f3f4f685bcaf5708f0a23a5f5216fdab71f81a6c0273"}, - {file = "google_crc32c-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1034d91442ead5a95b5aaef90dbfaca8633b0247d1e41621d1e9f9db88c36298"}, - {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c42c70cd1d362284289c6273adda4c6af8039a8ae12dc451dcd61cdabb8ab57"}, - {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8485b340a6a9e76c62a7dce3c98e5f102c9219f4cfbf896a00cf48caf078d438"}, - {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77e2fd3057c9d78e225fa0a2160f96b64a824de17840351b26825b0848022906"}, - {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f583edb943cf2e09c60441b910d6a20b4d9d626c75a36c8fcac01a6c96c01183"}, - {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:a1fd716e7a01f8e717490fbe2e431d2905ab8aa598b9b12f8d10abebb36b04dd"}, - {file = 
"google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:72218785ce41b9cfd2fc1d6a017dc1ff7acfc4c17d01053265c41a2c0cc39b8c"}, - {file = "google_crc32c-1.5.0-cp311-cp311-win32.whl", hash = "sha256:66741ef4ee08ea0b2cc3c86916ab66b6aef03768525627fd6a1b34968b4e3709"}, - {file = "google_crc32c-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:ba1eb1843304b1e5537e1fca632fa894d6f6deca8d6389636ee5b4797affb968"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:98cb4d057f285bd80d8778ebc4fde6b4d509ac3f331758fb1528b733215443ae"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd8536e902db7e365f49e7d9029283403974ccf29b13fc7028b97e2295b33556"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19e0a019d2c4dcc5e598cd4a4bc7b008546b0358bd322537c74ad47a5386884f"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c65b9817512edc6a4ae7c7e987fea799d2e0ee40c53ec573a692bee24de876"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6ac08d24c1f16bd2bf5eca8eaf8304812f44af5cfe5062006ec676e7e1d50afc"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3359fc442a743e870f4588fcf5dcbc1bf929df1fad8fb9905cd94e5edb02e84c"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1e986b206dae4476f41bcec1faa057851f3889503a70e1bdb2378d406223994a"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:de06adc872bcd8c2a4e0dc51250e9e65ef2ca91be023b9d13ebd67c2ba552e1e"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-win32.whl", hash = "sha256:d3515f198eaa2f0ed49f8819d5732d70698c3fa37384146079b3799b97667a94"}, - {file = "google_crc32c-1.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:67b741654b851abafb7bc625b6d1cdd520a379074e64b6a128e3b688c3c04740"}, - {file = "google_crc32c-1.5.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c02ec1c5856179f171e032a31d6f8bf84e5a75c45c33b2e20a3de353b266ebd8"}, - {file = "google_crc32c-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:edfedb64740750e1a3b16152620220f51d58ff1b4abceb339ca92e934775c27a"}, - {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84e6e8cd997930fc66d5bb4fde61e2b62ba19d62b7abd7a69920406f9ecca946"}, - {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:024894d9d3cfbc5943f8f230e23950cd4906b2fe004c72e29b209420a1e6b05a"}, - {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:998679bf62b7fb599d2878aa3ed06b9ce688b8974893e7223c60db155f26bd8d"}, - {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:83c681c526a3439b5cf94f7420471705bbf96262f49a6fe546a6db5f687a3d4a"}, - {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4c6fdd4fccbec90cc8a01fc00773fcd5fa28db683c116ee3cb35cd5da9ef6c37"}, - {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5ae44e10a8e3407dbe138984f21e536583f2bba1be9491239f942c2464ac0894"}, - {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:37933ec6e693e51a5b07505bd05de57eee12f3e8c32b07da7e73669398e6630a"}, - {file = "google_crc32c-1.5.0-cp38-cp38-win32.whl", hash = 
"sha256:fe70e325aa68fa4b5edf7d1a4b6f691eb04bbccac0ace68e34820d283b5f80d4"}, - {file = "google_crc32c-1.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:74dea7751d98034887dbd821b7aae3e1d36eda111d6ca36c206c44478035709c"}, - {file = "google_crc32c-1.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c6c777a480337ac14f38564ac88ae82d4cd238bf293f0a22295b66eb89ffced7"}, - {file = "google_crc32c-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:759ce4851a4bb15ecabae28f4d2e18983c244eddd767f560165563bf9aefbc8d"}, - {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f13cae8cc389a440def0c8c52057f37359014ccbc9dc1f0827936bcd367c6100"}, - {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e560628513ed34759456a416bf86b54b2476c59144a9138165c9a1575801d0d9"}, - {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1674e4307fa3024fc897ca774e9c7562c957af85df55efe2988ed9056dc4e57"}, - {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:278d2ed7c16cfc075c91378c4f47924c0625f5fc84b2d50d921b18b7975bd210"}, - {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d5280312b9af0976231f9e317c20e4a61cd2f9629b7bfea6a693d1878a264ebd"}, - {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8b87e1a59c38f275c0e3676fc2ab6d59eccecfd460be267ac360cc31f7bcde96"}, - {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7c074fece789b5034b9b1404a1f8208fc2d4c6ce9decdd16e8220c5a793e6f61"}, - {file = "google_crc32c-1.5.0-cp39-cp39-win32.whl", hash = "sha256:7f57f14606cd1dd0f0de396e1e53824c371e9544a822648cd76c034d209b559c"}, - {file = "google_crc32c-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:a2355cba1f4ad8b6988a4ca3feed5bff33f6af2d7f134852cf279c2aebfde541"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f314013e7dcd5cf45ab1945d92e713eec788166262ae8deb2cfacd53def27325"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b747a674c20a67343cb61d43fdd9207ce5da6a99f629c6e2541aa0e89215bcd"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f24ed114432de109aa9fd317278518a5af2d31ac2ea6b952b2f7782b43da091"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8667b48e7a7ef66afba2c81e1094ef526388d35b873966d8a9a447974ed9178"}, - {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:1c7abdac90433b09bad6c43a43af253e688c9cfc1c86d332aed13f9a7c7f65e2"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6f998db4e71b645350b9ac28a2167e6632c239963ca9da411523bb439c5c514d"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c99616c853bb585301df6de07ca2cadad344fd1ada6d62bb30aec05219c45d2"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ad40e31093a4af319dadf503b2467ccdc8f67c72e4bcba97f8c10cb078207b5"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd67cf24a553339d5062eff51013780a00d6f97a39ca062781d06b3a73b15462"}, - {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = 
"sha256:398af5e3ba9cf768787eef45c803ff9614cc3e22a5b2f7d7ae116df8b11e3314"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b1f8133c9a275df5613a451e73f36c2aea4fe13c5c8997e22cf355ebd7bd0728"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ba053c5f50430a3fcfd36f75aff9caeba0440b2d076afdb79a318d6ca245f88"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:272d3892a1e1a2dbc39cc5cde96834c236d5327e2122d3aaa19f6614531bb6eb"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:635f5d4dd18758a1fbd1049a8e8d2fee4ffed124462d837d1a02a0e009c3ab31"}, - {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c672d99a345849301784604bfeaeba4db0c7aae50b95be04dd651fd2a7310b93"}, + {file = "google_crc32c-1.6.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:5bcc90b34df28a4b38653c36bb5ada35671ad105c99cfe915fb5bed7ad6924aa"}, + {file = "google_crc32c-1.6.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:d9e9913f7bd69e093b81da4535ce27af842e7bf371cde42d1ae9e9bd382dc0e9"}, + {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a184243544811e4a50d345838a883733461e67578959ac59964e43cca2c791e7"}, + {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:236c87a46cdf06384f614e9092b82c05f81bd34b80248021f729396a78e55d7e"}, + {file = "google_crc32c-1.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebab974b1687509e5c973b5c4b8b146683e101e102e17a86bd196ecaa4d099fc"}, + {file = "google_crc32c-1.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:50cf2a96da226dcbff8671233ecf37bf6e95de98b2a2ebadbfdf455e6d05df42"}, + {file = "google_crc32c-1.6.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f7a1fc29803712f80879b0806cb83ab24ce62fc8daf0569f2204a0cfd7f68ed4"}, + {file = "google_crc32c-1.6.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:40b05ab32a5067525670880eb5d169529089a26fe35dce8891127aeddc1950e8"}, + {file = "google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9e4b426c3702f3cd23b933436487eb34e01e00327fac20c9aebb68ccf34117d"}, + {file = "google_crc32c-1.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51c4f54dd8c6dfeb58d1df5e4f7f97df8abf17a36626a217f169893d1d7f3e9f"}, + {file = "google_crc32c-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:bb8b3c75bd157010459b15222c3fd30577042a7060e29d42dabce449c087f2b3"}, + {file = "google_crc32c-1.6.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ed767bf4ba90104c1216b68111613f0d5926fb3780660ea1198fc469af410e9d"}, + {file = "google_crc32c-1.6.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:62f6d4a29fea082ac4a3c9be5e415218255cf11684ac6ef5488eea0c9132689b"}, + {file = "google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c87d98c7c4a69066fd31701c4e10d178a648c2cac3452e62c6b24dc51f9fcc00"}, + {file = "google_crc32c-1.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd5e7d2445d1a958c266bfa5d04c39932dc54093fa391736dbfdb0f1929c1fb3"}, + {file = "google_crc32c-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7aec8e88a3583515f9e0957fe4f5f6d8d4997e36d0f61624e70469771584c760"}, + {file = "google_crc32c-1.6.0-cp39-cp39-macosx_12_0_arm64.whl", hash = 
"sha256:e2806553238cd076f0a55bddab37a532b53580e699ed8e5606d0de1f856b5205"}, + {file = "google_crc32c-1.6.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:bb0966e1c50d0ef5bc743312cc730b533491d60585a9a08f897274e57c3f70e0"}, + {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:386122eeaaa76951a8196310432c5b0ef3b53590ef4c317ec7588ec554fec5d2"}, + {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2952396dc604544ea7476b33fe87faedc24d666fb0c2d5ac971a2b9576ab871"}, + {file = "google_crc32c-1.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35834855408429cecf495cac67ccbab802de269e948e27478b1e47dfb6465e57"}, + {file = "google_crc32c-1.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8797406499f28b5ef791f339594b0b5fdedf54e203b5066675c406ba69d705c"}, + {file = "google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48abd62ca76a2cbe034542ed1b6aee851b6f28aaca4e6551b5599b6f3ef175cc"}, + {file = "google_crc32c-1.6.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18e311c64008f1f1379158158bb3f0c8d72635b9eb4f9545f8cf990c5668e59d"}, + {file = "google_crc32c-1.6.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05e2d8c9a2f853ff116db9706b4a27350587f341eda835f46db3c0a8c8ce2f24"}, + {file = "google_crc32c-1.6.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91ca8145b060679ec9176e6de4f89b07363d6805bd4760631ef254905503598d"}, + {file = "google_crc32c-1.6.0.tar.gz", hash = "sha256:6eceb6ad197656a1ff49ebfbbfa870678c75be4344feb35ac1edf694309413dc"}, ] [package.extras] @@ -3151,16 +3416,16 @@ testing = ["pytest"] [[package]] name = "google-generativeai" -version = "0.5.0" +version = "0.8.1" description = "Google Generative AI High level API client library and tools." 
optional = false python-versions = ">=3.9" files = [ - {file = "google_generativeai-0.5.0-py3-none-any.whl", hash = "sha256:207ed12c6a2eeab549a45abbf5373c82077f62b16030bdb502556c78f6d1b5d2"}, + {file = "google_generativeai-0.8.1-py3-none-any.whl", hash = "sha256:b031877f24d51af0945207657c085896a0a886eceec7a1cb7029327b0aa6e2f6"}, ] [package.dependencies] -google-ai-generativelanguage = "0.6.1" +google-ai-generativelanguage = "0.6.9" google-api-core = "*" google-api-python-client = "*" google-auth = ">=2.15.0" @@ -3173,18 +3438,33 @@ typing-extensions = "*" dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] [[package]] -name = "google-resumable-media" -version = "2.7.1" -description = "Utilities for Google Media Downloads and Resumable Uploads" +name = "google-pasta" +version = "0.2.0" +description = "pasta is an AST-based Python refactoring library" optional = false -python-versions = ">=3.7" +python-versions = "*" files = [ - {file = "google-resumable-media-2.7.1.tar.gz", hash = "sha256:eae451a7b2e2cdbaaa0fd2eb00cc8a1ee5e95e16b55597359cbc3d27d7d90e33"}, - {file = "google_resumable_media-2.7.1-py2.py3-none-any.whl", hash = "sha256:103ebc4ba331ab1bfdac0250f8033627a2cd7cde09e7ccff9181e31ba4315b2c"}, + {file = "google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e"}, + {file = "google_pasta-0.2.0-py2-none-any.whl", hash = "sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954"}, + {file = "google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed"}, ] [package.dependencies] -google-crc32c = ">=1.0,<2.0dev" +six = "*" + +[[package]] +name = "google-resumable-media" +version = "2.7.2" +description = "Utilities for Google Media Downloads and Resumable Uploads" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_resumable_media-2.7.2-py2.py3-none-any.whl", hash = "sha256:3ce7551e9fe6d99e9a126101d2536612bb73486721951e9562fee0f90c6ababa"}, + {file = "google_resumable_media-2.7.2.tar.gz", hash = "sha256:5280aed4629f2b60b847b0d42f9857fd4935c11af266744df33d8074cae92fe0"}, +] + +[package.dependencies] +google-crc32c = ">=1.0,<2.0dev" [package.extras] aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "google-auth (>=1.22.0,<2.0dev)"] @@ -3208,71 +3488,101 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +[[package]] +name = "gotrue" +version = "2.9.3" +description = "Python Client Library for Supabase Auth" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "gotrue-2.9.3-py3-none-any.whl", hash = "sha256:9d2e9c74405d879f4828e0a7b94daf167a6e109c10ae6e5c59a0e21446f6e423"}, + {file = "gotrue-2.9.3.tar.gz", hash = "sha256:051551d80e642bdd2ab42cac78207745d89a2a08f429a1512d82624e675d8255"}, +] + +[package.dependencies] +httpx = {version = ">=0.26,<0.28", extras = ["http2"]} +pydantic = ">=1.10,<3" + [[package]] name = "greenlet" -version = "3.0.3" +version = "3.1.1" description = "Lightweight in-process concurrent programming" optional = false python-versions = ">=3.7" files = [ - {file = "greenlet-3.0.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9da2bd29ed9e4f15955dd1595ad7bc9320308a3b766ef7f837e23ad4b4aac31a"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d353cadd6083fdb056bb46ed07e4340b0869c305c8ca54ef9da3421acbdf6881"}, - {file = 
"greenlet-3.0.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dca1e2f3ca00b84a396bc1bce13dd21f680f035314d2379c4160c98153b2059b"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ed7fb269f15dc662787f4119ec300ad0702fa1b19d2135a37c2c4de6fadfd4a"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd4f49ae60e10adbc94b45c0b5e6a179acc1736cf7a90160b404076ee283cf83"}, - {file = "greenlet-3.0.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:73a411ef564e0e097dbe7e866bb2dda0f027e072b04da387282b02c308807405"}, - {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7f362975f2d179f9e26928c5b517524e89dd48530a0202570d55ad6ca5d8a56f"}, - {file = "greenlet-3.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:649dde7de1a5eceb258f9cb00bdf50e978c9db1b996964cd80703614c86495eb"}, - {file = "greenlet-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:68834da854554926fbedd38c76e60c4a2e3198c6fbed520b106a8986445caaf9"}, - {file = "greenlet-3.0.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:b1b5667cced97081bf57b8fa1d6bfca67814b0afd38208d52538316e9422fc61"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52f59dd9c96ad2fc0d5724107444f76eb20aaccb675bf825df6435acb7703559"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:afaff6cf5200befd5cec055b07d1c0a5a06c040fe5ad148abcd11ba6ab9b114e"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe754d231288e1e64323cfad462fcee8f0288654c10bdf4f603a39ed923bef33"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2797aa5aedac23af156bbb5a6aa2cd3427ada2972c828244eb7d1b9255846379"}, - {file = "greenlet-3.0.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7f009caad047246ed379e1c4dbcb8b020f0a390667ea74d2387be2998f58a22"}, - {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c5e1536de2aad7bf62e27baf79225d0d64360d4168cf2e6becb91baf1ed074f3"}, - {file = "greenlet-3.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:894393ce10ceac937e56ec00bb71c4c2f8209ad516e96033e4b3b1de270e200d"}, - {file = "greenlet-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:1ea188d4f49089fc6fb283845ab18a2518d279c7cd9da1065d7a84e991748728"}, - {file = "greenlet-3.0.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:70fb482fdf2c707765ab5f0b6655e9cfcf3780d8d87355a063547b41177599be"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4d1ac74f5c0c0524e4a24335350edad7e5f03b9532da7ea4d3c54d527784f2e"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:149e94a2dd82d19838fe4b2259f1b6b9957d5ba1b25640d2380bea9c5df37676"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15d79dd26056573940fcb8c7413d84118086f2ec1a8acdfa854631084393efcc"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b7db1ebff4ba09aaaeae6aa491daeb226c8150fc20e836ad00041bcb11230"}, - {file = "greenlet-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcd2469d6a2cf298f198f0487e0a5b1a47a42ca0fa4dfd1b6862c999f018ebbf"}, - {file = 
"greenlet-3.0.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f672519db1796ca0d8753f9e78ec02355e862d0998193038c7073045899f305"}, - {file = "greenlet-3.0.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2516a9957eed41dd8f1ec0c604f1cdc86758b587d964668b5b196a9db5bfcde6"}, - {file = "greenlet-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:bba5387a6975598857d86de9eac14210a49d554a77eb8261cc68b7d082f78ce2"}, - {file = "greenlet-3.0.3-cp37-cp37m-macosx_11_0_universal2.whl", hash = "sha256:5b51e85cb5ceda94e79d019ed36b35386e8c37d22f07d6a751cb659b180d5274"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daf3cb43b7cf2ba96d614252ce1684c1bccee6b2183a01328c98d36fcd7d5cb0"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99bf650dc5d69546e076f413a87481ee1d2d09aaaaaca058c9251b6d8c14783f"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dd6e660effd852586b6a8478a1d244b8dc90ab5b1321751d2ea15deb49ed414"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3391d1e16e2a5a1507d83e4a8b100f4ee626e8eca43cf2cadb543de69827c4c"}, - {file = "greenlet-3.0.3-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1f145462f1fa6e4a4ae3c0f782e580ce44d57c8f2c7aae1b6fa88c0b2efdb41"}, - {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1a7191e42732df52cb5f39d3527217e7ab73cae2cb3694d241e18f53d84ea9a7"}, - {file = "greenlet-3.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:0448abc479fab28b00cb472d278828b3ccca164531daab4e970a0458786055d6"}, - {file = "greenlet-3.0.3-cp37-cp37m-win32.whl", hash = "sha256:b542be2440edc2d48547b5923c408cbe0fc94afb9f18741faa6ae970dbcb9b6d"}, - {file = "greenlet-3.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:01bc7ea167cf943b4c802068e178bbf70ae2e8c080467070d01bfa02f337ee67"}, - {file = "greenlet-3.0.3-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:1996cb9306c8595335bb157d133daf5cf9f693ef413e7673cb07e3e5871379ca"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ddc0f794e6ad661e321caa8d2f0a55ce01213c74722587256fb6566049a8b04"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c9db1c18f0eaad2f804728c67d6c610778456e3e1cc4ab4bbd5eeb8e6053c6fc"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7170375bcc99f1a2fbd9c306f5be8764eaf3ac6b5cb968862cad4c7057756506"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b66c9c1e7ccabad3a7d037b2bcb740122a7b17a53734b7d72a344ce39882a1b"}, - {file = "greenlet-3.0.3-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:098d86f528c855ead3479afe84b49242e174ed262456c342d70fc7f972bc13c4"}, - {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:81bb9c6d52e8321f09c3d165b2a78c680506d9af285bfccbad9fb7ad5a5da3e5"}, - {file = "greenlet-3.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fd096eb7ffef17c456cfa587523c5f92321ae02427ff955bebe9e3c63bc9f0da"}, - {file = "greenlet-3.0.3-cp38-cp38-win32.whl", hash = "sha256:d46677c85c5ba00a9cb6f7a00b2bfa6f812192d2c9f7d9c4f6a55b60216712f3"}, - {file = "greenlet-3.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:419b386f84949bf0e7c73e6032e3457b82a787c1ab4a0e43732898a761cc9dbf"}, - {file = 
"greenlet-3.0.3-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:da70d4d51c8b306bb7a031d5cff6cc25ad253affe89b70352af5f1cb68e74b53"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:086152f8fbc5955df88382e8a75984e2bb1c892ad2e3c80a2508954e52295257"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d73a9fe764d77f87f8ec26a0c85144d6a951a6c438dfe50487df5595c6373eac"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b7dcbe92cc99f08c8dd11f930de4d99ef756c3591a5377d1d9cd7dd5e896da71"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1551a8195c0d4a68fac7a4325efac0d541b48def35feb49d803674ac32582f61"}, - {file = "greenlet-3.0.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:64d7675ad83578e3fc149b617a444fab8efdafc9385471f868eb5ff83e446b8b"}, - {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b37eef18ea55f2ffd8f00ff8fe7c8d3818abd3e25fb73fae2ca3b672e333a7a6"}, - {file = "greenlet-3.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:77457465d89b8263bca14759d7c1684df840b6811b2499838cc5b040a8b5b113"}, - {file = "greenlet-3.0.3-cp39-cp39-win32.whl", hash = "sha256:57e8974f23e47dac22b83436bdcf23080ade568ce77df33159e019d161ce1d1e"}, - {file = "greenlet-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:c5ee858cfe08f34712f548c3c363e807e7186f03ad7a5039ebadb29e8c6be067"}, - {file = "greenlet-3.0.3.tar.gz", hash = "sha256:43374442353259554ce33599da8b692d5aa96f8976d567d4badf263371fbe491"}, + {file = "greenlet-3.1.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:0bbae94a29c9e5c7e4a2b7f0aae5c17e8e90acbfd3bf6270eeba60c39fce3563"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fde093fb93f35ca72a556cf72c92ea3ebfda3d79fc35bb19fbe685853869a83"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:36b89d13c49216cadb828db8dfa6ce86bbbc476a82d3a6c397f0efae0525bdd0"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94b6150a85e1b33b40b1464a3f9988dcc5251d6ed06842abff82e42632fac120"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93147c513fac16385d1036b7e5b102c7fbbdb163d556b791f0f11eada7ba65dc"}, + {file = "greenlet-3.1.1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da7a9bff22ce038e19bf62c4dd1ec8391062878710ded0a845bcf47cc0200617"}, + {file = "greenlet-3.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b2795058c23988728eec1f36a4e5e4ebad22f8320c85f3587b539b9ac84128d7"}, + {file = "greenlet-3.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:ed10eac5830befbdd0c32f83e8aa6288361597550ba669b04c48f0f9a2c843c6"}, + {file = "greenlet-3.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:77c386de38a60d1dfb8e55b8c1101d68c79dfdd25c7095d51fec2dd800892b80"}, + {file = "greenlet-3.1.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e4d333e558953648ca09d64f13e6d8f0523fa705f51cae3f03b5983489958c70"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fc016b73c94e98e29af67ab7b9a879c307c6731a2c9da0db5a7d9b7edd1159"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:d5e975ca70269d66d17dd995dafc06f1b06e8cb1ec1e9ed54c1d1e4a7c4cf26e"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2813dc3de8c1ee3f924e4d4227999285fd335d1bcc0d2be6dc3f1f6a318ec1"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e347b3bfcf985a05e8c0b7d462ba6f15b1ee1c909e2dcad795e49e91b152c383"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e8f8c9cb53cdac7ba9793c276acd90168f416b9ce36799b9b885790f8ad6c0a"}, + {file = "greenlet-3.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62ee94988d6b4722ce0028644418d93a52429e977d742ca2ccbe1c4f4a792511"}, + {file = "greenlet-3.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1776fd7f989fc6b8d8c8cb8da1f6b82c5814957264d1f6cf818d475ec2bf6395"}, + {file = "greenlet-3.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:48ca08c771c268a768087b408658e216133aecd835c0ded47ce955381105ba39"}, + {file = "greenlet-3.1.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:4afe7ea89de619adc868e087b4d2359282058479d7cfb94970adf4b55284574d"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f406b22b7c9a9b4f8aa9d2ab13d6ae0ac3e85c9a809bd590ad53fed2bf70dc79"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c3a701fe5a9695b238503ce5bbe8218e03c3bcccf7e204e455e7462d770268aa"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2846930c65b47d70b9d178e89c7e1a69c95c1f68ea5aa0a58646b7a96df12441"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99cfaa2110534e2cf3ba31a7abcac9d328d1d9f1b95beede58294a60348fba36"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1443279c19fca463fc33e65ef2a935a5b09bb90f978beab37729e1c3c6c25fe9"}, + {file = "greenlet-3.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b7cede291382a78f7bb5f04a529cb18e068dd29e0fb27376074b6d0317bf4dd0"}, + {file = "greenlet-3.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:23f20bb60ae298d7d8656c6ec6db134bca379ecefadb0b19ce6f19d1f232a942"}, + {file = "greenlet-3.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:7124e16b4c55d417577c2077be379514321916d5790fa287c9ed6f23bd2ffd01"}, + {file = "greenlet-3.1.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:05175c27cb459dcfc05d026c4232f9de8913ed006d42713cb8a5137bd49375f1"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:935e943ec47c4afab8965954bf49bfa639c05d4ccf9ef6e924188f762145c0ff"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:667a9706c970cb552ede35aee17339a18e8f2a87a51fba2ed39ceeeb1004798a"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b8a678974d1f3aa55f6cc34dc480169d58f2e6d8958895d68845fa4ab566509e"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efc0f674aa41b92da8c49e0346318c6075d734994c3c4e4430b1c3f853e498e4"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0153404a4bb921f0ff1abeb5ce8a5131da56b953eda6e14b88dc6bbc04d2049e"}, + {file = "greenlet-3.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = 
"sha256:275f72decf9932639c1c6dd1013a1bc266438eb32710016a1c742df5da6e60a1"}, + {file = "greenlet-3.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c4aab7f6381f38a4b42f269057aee279ab0fc7bf2e929e3d4abfae97b682a12c"}, + {file = "greenlet-3.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:b42703b1cf69f2aa1df7d1030b9d77d3e584a70755674d60e710f0af570f3761"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1695e76146579f8c06c1509c7ce4dfe0706f49c6831a817ac04eebb2fd02011"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7876452af029456b3f3549b696bb36a06db7c90747740c5302f74a9e9fa14b13"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ead44c85f8ab905852d3de8d86f6f8baf77109f9da589cb4fa142bd3b57b475"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8320f64b777d00dd7ccdade271eaf0cad6636343293a25074cc5566160e4de7b"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6510bf84a6b643dabba74d3049ead221257603a253d0a9873f55f6a59a65f822"}, + {file = "greenlet-3.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:04b013dc07c96f83134b1e99888e7a79979f1a247e2a9f59697fa14b5862ed01"}, + {file = "greenlet-3.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:411f015496fec93c1c8cd4e5238da364e1da7a124bcb293f085bf2860c32c6f6"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47da355d8687fd65240c364c90a31569a133b7b60de111c255ef5b606f2ae291"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:98884ecf2ffb7d7fe6bd517e8eb99d31ff7855a840fa6d0d63cd07c037f6a981"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1d4aeb8891338e60d1ab6127af1fe45def5259def8094b9c7e34690c8858803"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db32b5348615a04b82240cc67983cb315309e88d444a288934ee6ceaebcad6cc"}, + {file = "greenlet-3.1.1-cp37-cp37m-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dcc62f31eae24de7f8dce72134c8651c58000d3b1868e01392baea7c32c247de"}, + {file = "greenlet-3.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:1d3755bcb2e02de341c55b4fca7a745a24a9e7212ac953f6b3a48d117d7257aa"}, + {file = "greenlet-3.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b8da394b34370874b4572676f36acabac172602abf054cbc4ac910219f3340af"}, + {file = "greenlet-3.1.1-cp37-cp37m-win32.whl", hash = "sha256:a0dfc6c143b519113354e780a50381508139b07d2177cb6ad6a08278ec655798"}, + {file = "greenlet-3.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:54558ea205654b50c438029505def3834e80f0869a70fb15b871c29b4575ddef"}, + {file = "greenlet-3.1.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:346bed03fe47414091be4ad44786d1bd8bef0c3fcad6ed3dee074a032ab408a9"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfc59d69fc48664bc693842bd57acfdd490acafda1ab52c7836e3fc75c90a111"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21e10da6ec19b457b82636209cbe2331ff4306b54d06fa04b7c138ba18c8a81"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:37b9de5a96111fc15418819ab4c4432e4f3c2ede61e660b1e33971eba26ef9ba"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef9ea3f137e5711f0dbe5f9263e8c009b7069d8a1acea822bd5e9dae0ae49c8"}, + {file = "greenlet-3.1.1-cp38-cp38-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85f3ff71e2e60bd4b4932a043fbbe0f499e263c628390b285cb599154a3b03b1"}, + {file = "greenlet-3.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:95ffcf719966dd7c453f908e208e14cde192e09fde6c7186c8f1896ef778d8cd"}, + {file = "greenlet-3.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:03a088b9de532cbfe2ba2034b2b85e82df37874681e8c470d6fb2f8c04d7e4b7"}, + {file = "greenlet-3.1.1-cp38-cp38-win32.whl", hash = "sha256:8b8b36671f10ba80e159378df9c4f15c14098c4fd73a36b9ad715f057272fbef"}, + {file = "greenlet-3.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:7017b2be767b9d43cc31416aba48aab0d2309ee31b4dbf10a1d38fb7972bdf9d"}, + {file = "greenlet-3.1.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:396979749bd95f018296af156201d6211240e7a23090f50a8d5d18c370084dc3"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca9d0ff5ad43e785350894d97e13633a66e2b50000e8a183a50a88d834752d42"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f6ff3b14f2df4c41660a7dec01045a045653998784bf8cfcb5a525bdffffbc8f"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94ebba31df2aa506d7b14866fed00ac141a867e63143fe5bca82a8e503b36437"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73aaad12ac0ff500f62cebed98d8789198ea0e6f233421059fa68a5aa7220145"}, + {file = "greenlet-3.1.1-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63e4844797b975b9af3a3fb8f7866ff08775f5426925e1e0bbcfe7932059a12c"}, + {file = "greenlet-3.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7939aa3ca7d2a1593596e7ac6d59391ff30281ef280d8632fa03d81f7c5f955e"}, + {file = "greenlet-3.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d0028e725ee18175c6e422797c407874da24381ce0690d6b9396c204c7f7276e"}, + {file = "greenlet-3.1.1-cp39-cp39-win32.whl", hash = "sha256:5e06afd14cbaf9e00899fae69b24a32f2196c19de08fcb9f4779dd4f004e5e7c"}, + {file = "greenlet-3.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:3319aa75e0e0639bc15ff54ca327e8dc7a6fe404003496e3c6925cd3142e0e22"}, + {file = "greenlet-3.1.1.tar.gz", hash = "sha256:4ce3ac6cdb6adf7946475d7ef31777c26d94bccc377e070a7986bd2d5c515467"}, ] [package.extras] @@ -3297,61 +3607,70 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [[package]] name = "grpcio" -version = "1.63.0" +version = "1.67.1" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.63.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2e93aca840c29d4ab5db93f94ed0a0ca899e241f2e8aec6334ab3575dc46125c"}, - {file = "grpcio-1.63.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:91b73d3f1340fefa1e1716c8c1ec9930c676d6b10a3513ab6c26004cb02d8b3f"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:b3afbd9d6827fa6f475a4f91db55e441113f6d3eb9b7ebb8fb806e5bb6d6bd0d"}, - {file = "grpcio-1.63.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f3f6883ce54a7a5f47db43289a0a4c776487912de1a0e2cc83fdaec9685cc9f"}, - {file = 
"grpcio-1.63.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf8dae9cc0412cb86c8de5a8f3be395c5119a370f3ce2e69c8b7d46bb9872c8d"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:08e1559fd3b3b4468486b26b0af64a3904a8dbc78d8d936af9c1cf9636eb3e8b"}, - {file = "grpcio-1.63.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5c039ef01516039fa39da8a8a43a95b64e288f79f42a17e6c2904a02a319b357"}, - {file = "grpcio-1.63.0-cp310-cp310-win32.whl", hash = "sha256:ad2ac8903b2eae071055a927ef74121ed52d69468e91d9bcbd028bd0e554be6d"}, - {file = "grpcio-1.63.0-cp310-cp310-win_amd64.whl", hash = "sha256:b2e44f59316716532a993ca2966636df6fbe7be4ab6f099de6815570ebe4383a"}, - {file = "grpcio-1.63.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:f28f8b2db7b86c77916829d64ab21ff49a9d8289ea1564a2b2a3a8ed9ffcccd3"}, - {file = "grpcio-1.63.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:65bf975639a1f93bee63ca60d2e4951f1b543f498d581869922910a476ead2f5"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:b5194775fec7dc3dbd6a935102bb156cd2c35efe1685b0a46c67b927c74f0cfb"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4cbb2100ee46d024c45920d16e888ee5d3cf47c66e316210bc236d5bebc42b3"}, - {file = "grpcio-1.63.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ff737cf29b5b801619f10e59b581869e32f400159e8b12d7a97e7e3bdeee6a2"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd1e68776262dd44dedd7381b1a0ad09d9930ffb405f737d64f505eb7f77d6c7"}, - {file = "grpcio-1.63.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:93f45f27f516548e23e4ec3fbab21b060416007dbe768a111fc4611464cc773f"}, - {file = "grpcio-1.63.0-cp311-cp311-win32.whl", hash = "sha256:878b1d88d0137df60e6b09b74cdb73db123f9579232c8456f53e9abc4f62eb3c"}, - {file = "grpcio-1.63.0-cp311-cp311-win_amd64.whl", hash = "sha256:756fed02dacd24e8f488f295a913f250b56b98fb793f41d5b2de6c44fb762434"}, - {file = "grpcio-1.63.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:93a46794cc96c3a674cdfb59ef9ce84d46185fe9421baf2268ccb556f8f81f57"}, - {file = "grpcio-1.63.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:a7b19dfc74d0be7032ca1eda0ed545e582ee46cd65c162f9e9fc6b26ef827dc6"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8064d986d3a64ba21e498b9a376cbc5d6ab2e8ab0e288d39f266f0fca169b90d"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:219bb1848cd2c90348c79ed0a6b0ea51866bc7e72fa6e205e459fedab5770172"}, - {file = "grpcio-1.63.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2d60cd1d58817bc5985fae6168d8b5655c4981d448d0f5b6194bbcc038090d2"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9e350cb096e5c67832e9b6e018cf8a0d2a53b2a958f6251615173165269a91b0"}, - {file = "grpcio-1.63.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:56cdf96ff82e3cc90dbe8bac260352993f23e8e256e063c327b6cf9c88daf7a9"}, - {file = "grpcio-1.63.0-cp312-cp312-win32.whl", hash = "sha256:3a6d1f9ea965e750db7b4ee6f9fdef5fdf135abe8a249e75d84b0a3e0c668a1b"}, - {file = "grpcio-1.63.0-cp312-cp312-win_amd64.whl", hash = "sha256:d2497769895bb03efe3187fb1888fc20e98a5f18b3d14b606167dacda5789434"}, - {file = "grpcio-1.63.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:fdf348ae69c6ff484402cfdb14e18c1b0054ac2420079d575c53a60b9b2853ae"}, - {file = 
"grpcio-1.63.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a3abfe0b0f6798dedd2e9e92e881d9acd0fdb62ae27dcbbfa7654a57e24060c0"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:6ef0ad92873672a2a3767cb827b64741c363ebaa27e7f21659e4e31f4d750280"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b416252ac5588d9dfb8a30a191451adbf534e9ce5f56bb02cd193f12d8845b7f"}, - {file = "grpcio-1.63.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3b77eaefc74d7eb861d3ffbdf91b50a1bb1639514ebe764c47773b833fa2d91"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b005292369d9c1f80bf70c1db1c17c6c342da7576f1c689e8eee4fb0c256af85"}, - {file = "grpcio-1.63.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cdcda1156dcc41e042d1e899ba1f5c2e9f3cd7625b3d6ebfa619806a4c1aadda"}, - {file = "grpcio-1.63.0-cp38-cp38-win32.whl", hash = "sha256:01799e8649f9e94ba7db1aeb3452188048b0019dc37696b0f5ce212c87c560c3"}, - {file = "grpcio-1.63.0-cp38-cp38-win_amd64.whl", hash = "sha256:6a1a3642d76f887aa4009d92f71eb37809abceb3b7b5a1eec9c554a246f20e3a"}, - {file = "grpcio-1.63.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:75f701ff645858a2b16bc8c9fc68af215a8bb2d5a9b647448129de6e85d52bce"}, - {file = "grpcio-1.63.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cacdef0348a08e475a721967f48206a2254a1b26ee7637638d9e081761a5ba86"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:0697563d1d84d6985e40ec5ec596ff41b52abb3fd91ec240e8cb44a63b895094"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426e1fb92d006e47476d42b8f240c1d916a6d4423c5258ccc5b105e43438f61"}, - {file = "grpcio-1.63.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e48cee31bc5f5a31fb2f3b573764bd563aaa5472342860edcc7039525b53e46a"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:50344663068041b34a992c19c600236e7abb42d6ec32567916b87b4c8b8833b3"}, - {file = "grpcio-1.63.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:259e11932230d70ef24a21b9fb5bb947eb4703f57865a404054400ee92f42f5d"}, - {file = "grpcio-1.63.0-cp39-cp39-win32.whl", hash = "sha256:a44624aad77bf8ca198c55af811fd28f2b3eaf0a50ec5b57b06c034416ef2d0a"}, - {file = "grpcio-1.63.0-cp39-cp39-win_amd64.whl", hash = "sha256:166e5c460e5d7d4656ff9e63b13e1f6029b122104c1633d5f37eaea348d7356d"}, - {file = "grpcio-1.63.0.tar.gz", hash = "sha256:f3023e14805c61bc439fb40ca545ac3d5740ce66120a678a3c6c2c55b70343d1"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.63.0)"] + {file = "grpcio-1.67.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:8b0341d66a57f8a3119b77ab32207072be60c9bf79760fa609c5609f2deb1f3f"}, + {file = "grpcio-1.67.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:f5a27dddefe0e2357d3e617b9079b4bfdc91341a91565111a21ed6ebbc51b22d"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:43112046864317498a33bdc4797ae6a268c36345a910de9b9c17159d8346602f"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9b929f13677b10f63124c1a410994a401cdd85214ad83ab67cc077fc7e480f0"}, + {file = "grpcio-1.67.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7d1797a8a3845437d327145959a2c0c47c05947c9eef5ff1a4c80e499dcc6fa"}, + {file = "grpcio-1.67.1-cp310-cp310-musllinux_1_1_i686.whl", hash = 
"sha256:0489063974d1452436139501bf6b180f63d4977223ee87488fe36858c5725292"}, + {file = "grpcio-1.67.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9fd042de4a82e3e7aca44008ee2fb5da01b3e5adb316348c21980f7f58adc311"}, + {file = "grpcio-1.67.1-cp310-cp310-win32.whl", hash = "sha256:638354e698fd0c6c76b04540a850bf1db27b4d2515a19fcd5cf645c48d3eb1ed"}, + {file = "grpcio-1.67.1-cp310-cp310-win_amd64.whl", hash = "sha256:608d87d1bdabf9e2868b12338cd38a79969eaf920c89d698ead08f48de9c0f9e"}, + {file = "grpcio-1.67.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:7818c0454027ae3384235a65210bbf5464bd715450e30a3d40385453a85a70cb"}, + {file = "grpcio-1.67.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ea33986b70f83844cd00814cee4451055cd8cab36f00ac64a31f5bb09b31919e"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:c7a01337407dd89005527623a4a72c5c8e2894d22bead0895306b23c6695698f"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80b866f73224b0634f4312a4674c1be21b2b4afa73cb20953cbbb73a6b36c3cc"}, + {file = "grpcio-1.67.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9fff78ba10d4250bfc07a01bd6254a6d87dc67f9627adece85c0b2ed754fa96"}, + {file = "grpcio-1.67.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8a23cbcc5bb11ea7dc6163078be36c065db68d915c24f5faa4f872c573bb400f"}, + {file = "grpcio-1.67.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1a65b503d008f066e994f34f456e0647e5ceb34cfcec5ad180b1b44020ad4970"}, + {file = "grpcio-1.67.1-cp311-cp311-win32.whl", hash = "sha256:e29ca27bec8e163dca0c98084040edec3bc49afd10f18b412f483cc68c712744"}, + {file = "grpcio-1.67.1-cp311-cp311-win_amd64.whl", hash = "sha256:786a5b18544622bfb1e25cc08402bd44ea83edfb04b93798d85dca4d1a0b5be5"}, + {file = "grpcio-1.67.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:267d1745894200e4c604958da5f856da6293f063327cb049a51fe67348e4f953"}, + {file = "grpcio-1.67.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:85f69fdc1d28ce7cff8de3f9c67db2b0ca9ba4449644488c1e0303c146135ddb"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:f26b0b547eb8d00e195274cdfc63ce64c8fc2d3e2d00b12bf468ece41a0423a0"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4422581cdc628f77302270ff839a44f4c24fdc57887dc2a45b7e53d8fc2376af"}, + {file = "grpcio-1.67.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d7616d2ded471231c701489190379e0c311ee0a6c756f3c03e6a62b95a7146e"}, + {file = "grpcio-1.67.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8a00efecde9d6fcc3ab00c13f816313c040a28450e5e25739c24f432fc6d3c75"}, + {file = "grpcio-1.67.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:699e964923b70f3101393710793289e42845791ea07565654ada0969522d0a38"}, + {file = "grpcio-1.67.1-cp312-cp312-win32.whl", hash = "sha256:4e7b904484a634a0fff132958dabdb10d63e0927398273917da3ee103e8d1f78"}, + {file = "grpcio-1.67.1-cp312-cp312-win_amd64.whl", hash = "sha256:5721e66a594a6c4204458004852719b38f3d5522082be9061d6510b455c90afc"}, + {file = "grpcio-1.67.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:aa0162e56fd10a5547fac8774c4899fc3e18c1aa4a4759d0ce2cd00d3696ea6b"}, + {file = "grpcio-1.67.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:beee96c8c0b1a75d556fe57b92b58b4347c77a65781ee2ac749d550f2a365dc1"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = 
"sha256:a93deda571a1bf94ec1f6fcda2872dad3ae538700d94dc283c672a3b508ba3af"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e6f255980afef598a9e64a24efce87b625e3e3c80a45162d111a461a9f92955"}, + {file = "grpcio-1.67.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e838cad2176ebd5d4a8bb03955138d6589ce9e2ce5d51c3ada34396dbd2dba8"}, + {file = "grpcio-1.67.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:a6703916c43b1d468d0756c8077b12017a9fcb6a1ef13faf49e67d20d7ebda62"}, + {file = "grpcio-1.67.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:917e8d8994eed1d86b907ba2a61b9f0aef27a2155bca6cbb322430fc7135b7bb"}, + {file = "grpcio-1.67.1-cp313-cp313-win32.whl", hash = "sha256:e279330bef1744040db8fc432becc8a727b84f456ab62b744d3fdb83f327e121"}, + {file = "grpcio-1.67.1-cp313-cp313-win_amd64.whl", hash = "sha256:fa0c739ad8b1996bd24823950e3cb5152ae91fca1c09cc791190bf1627ffefba"}, + {file = "grpcio-1.67.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:178f5db771c4f9a9facb2ab37a434c46cb9be1a75e820f187ee3d1e7805c4f65"}, + {file = "grpcio-1.67.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0f3e49c738396e93b7ba9016e153eb09e0778e776df6090c1b8c91877cc1c426"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:24e8a26dbfc5274d7474c27759b54486b8de23c709d76695237515bc8b5baeab"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b6c16489326d79ead41689c4b84bc40d522c9a7617219f4ad94bc7f448c5085"}, + {file = "grpcio-1.67.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:60e6a4dcf5af7bbc36fd9f81c9f372e8ae580870a9e4b6eafe948cd334b81cf3"}, + {file = "grpcio-1.67.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:95b5f2b857856ed78d72da93cd7d09b6db8ef30102e5e7fe0961fe4d9f7d48e8"}, + {file = "grpcio-1.67.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b49359977c6ec9f5d0573ea4e0071ad278ef905aa74e420acc73fd28ce39e9ce"}, + {file = "grpcio-1.67.1-cp38-cp38-win32.whl", hash = "sha256:f5b76ff64aaac53fede0cc93abf57894ab2a7362986ba22243d06218b93efe46"}, + {file = "grpcio-1.67.1-cp38-cp38-win_amd64.whl", hash = "sha256:804c6457c3cd3ec04fe6006c739579b8d35c86ae3298ffca8de57b493524b771"}, + {file = "grpcio-1.67.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:a25bdea92b13ff4d7790962190bf6bf5c4639876e01c0f3dda70fc2769616335"}, + {file = "grpcio-1.67.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cdc491ae35a13535fd9196acb5afe1af37c8237df2e54427be3eecda3653127e"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:85f862069b86a305497e74d0dc43c02de3d1d184fc2c180993aa8aa86fbd19b8"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ec74ef02010186185de82cc594058a3ccd8d86821842bbac9873fd4a2cf8be8d"}, + {file = "grpcio-1.67.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01f616a964e540638af5130469451cf580ba8c7329f45ca998ab66e0c7dcdb04"}, + {file = "grpcio-1.67.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:299b3d8c4f790c6bcca485f9963b4846dd92cf6f1b65d3697145d005c80f9fe8"}, + {file = "grpcio-1.67.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:60336bff760fbb47d7e86165408126f1dded184448e9a4c892189eb7c9d3f90f"}, + {file = "grpcio-1.67.1-cp39-cp39-win32.whl", hash = "sha256:5ed601c4c6008429e3d247ddb367fe8c7259c355757448d7c1ef7bd4a6739e8e"}, + {file = "grpcio-1.67.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:5db70d32d6703b89912af16d6d45d78406374a8b8ef0d28140351dd0ec610e98"}, + {file = "grpcio-1.67.1.tar.gz", hash = "sha256:3dc2ed4cabea4dc14d5e708c2b426205956077cc5de419b4d4079315017e9732"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.67.1)"] [[package]] name = "grpcio-status" @@ -3615,13 +3934,13 @@ lxml = ["lxml"] [[package]] name = "httpcore" -version = "1.0.5" +version = "1.0.6" description = "A minimal low-level HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, - {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, + {file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"}, + {file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"}, ] [package.dependencies] @@ -3632,7 +3951,7 @@ h11 = ">=0.13,<0.15" asyncio = ["anyio (>=4.0,<5.0)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] -trio = ["trio (>=0.22.0,<0.26.0)"] +trio = ["trio (>=0.22.0,<1.0)"] [[package]] name = "httplib2" @@ -3650,61 +3969,68 @@ pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0 [[package]] name = "httptools" -version = "0.6.1" +version = "0.6.4" description = "A collection of framework independent HTTP protocol utils." optional = false python-versions = ">=3.8.0" files = [ - {file = "httptools-0.6.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d2f6c3c4cb1948d912538217838f6e9960bc4a521d7f9b323b3da579cd14532f"}, - {file = "httptools-0.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:00d5d4b68a717765b1fabfd9ca755bd12bf44105eeb806c03d1962acd9b8e563"}, - {file = "httptools-0.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:639dc4f381a870c9ec860ce5c45921db50205a37cc3334e756269736ff0aac58"}, - {file = "httptools-0.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e57997ac7fb7ee43140cc03664de5f268813a481dff6245e0075925adc6aa185"}, - {file = "httptools-0.6.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0ac5a0ae3d9f4fe004318d64b8a854edd85ab76cffbf7ef5e32920faef62f142"}, - {file = "httptools-0.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3f30d3ce413088a98b9db71c60a6ada2001a08945cb42dd65a9a9fe228627658"}, - {file = "httptools-0.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:1ed99a373e327f0107cb513b61820102ee4f3675656a37a50083eda05dc9541b"}, - {file = "httptools-0.6.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7a7ea483c1a4485c71cb5f38be9db078f8b0e8b4c4dc0210f531cdd2ddac1ef1"}, - {file = "httptools-0.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:85ed077c995e942b6f1b07583e4eb0a8d324d418954fc6af913d36db7c05a5a0"}, - {file = "httptools-0.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b0bb634338334385351a1600a73e558ce619af390c2b38386206ac6a27fecfc"}, - {file = "httptools-0.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d9ceb2c957320def533671fc9c715a80c47025139c8d1f3797477decbc6edd2"}, - {file = "httptools-0.6.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4f0f8271c0a4db459f9dc807acd0eadd4839934a4b9b892f6f160e94da309837"}, - {file = 
"httptools-0.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6a4f5ccead6d18ec072ac0b84420e95d27c1cdf5c9f1bc8fbd8daf86bd94f43d"}, - {file = "httptools-0.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:5cceac09f164bcba55c0500a18fe3c47df29b62353198e4f37bbcc5d591172c3"}, - {file = "httptools-0.6.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:75c8022dca7935cba14741a42744eee13ba05db00b27a4b940f0d646bd4d56d0"}, - {file = "httptools-0.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:48ed8129cd9a0d62cf4d1575fcf90fb37e3ff7d5654d3a5814eb3d55f36478c2"}, - {file = "httptools-0.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f58e335a1402fb5a650e271e8c2d03cfa7cea46ae124649346d17bd30d59c90"}, - {file = "httptools-0.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93ad80d7176aa5788902f207a4e79885f0576134695dfb0fefc15b7a4648d503"}, - {file = "httptools-0.6.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9bb68d3a085c2174c2477eb3ffe84ae9fb4fde8792edb7bcd09a1d8467e30a84"}, - {file = "httptools-0.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b512aa728bc02354e5ac086ce76c3ce635b62f5fbc32ab7082b5e582d27867bb"}, - {file = "httptools-0.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:97662ce7fb196c785344d00d638fc9ad69e18ee4bfb4000b35a52efe5adcc949"}, - {file = "httptools-0.6.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:8e216a038d2d52ea13fdd9b9c9c7459fb80d78302b257828285eca1c773b99b3"}, - {file = "httptools-0.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3e802e0b2378ade99cd666b5bffb8b2a7cc8f3d28988685dc300469ea8dd86cb"}, - {file = "httptools-0.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4bd3e488b447046e386a30f07af05f9b38d3d368d1f7b4d8f7e10af85393db97"}, - {file = "httptools-0.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe467eb086d80217b7584e61313ebadc8d187a4d95bb62031b7bab4b205c3ba3"}, - {file = "httptools-0.6.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:3c3b214ce057c54675b00108ac42bacf2ab8f85c58e3f324a4e963bbc46424f4"}, - {file = "httptools-0.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8ae5b97f690badd2ca27cbf668494ee1b6d34cf1c464271ef7bfa9ca6b83ffaf"}, - {file = "httptools-0.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:405784577ba6540fa7d6ff49e37daf104e04f4b4ff2d1ac0469eaa6a20fde084"}, - {file = "httptools-0.6.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:95fb92dd3649f9cb139e9c56604cc2d7c7bf0fc2e7c8d7fbd58f96e35eddd2a3"}, - {file = "httptools-0.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dcbab042cc3ef272adc11220517278519adf8f53fd3056d0e68f0a6f891ba94e"}, - {file = "httptools-0.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cf2372e98406efb42e93bfe10f2948e467edfd792b015f1b4ecd897903d3e8d"}, - {file = "httptools-0.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:678fcbae74477a17d103b7cae78b74800d795d702083867ce160fc202104d0da"}, - {file = "httptools-0.6.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e0b281cf5a125c35f7f6722b65d8542d2e57331be573e9e88bc8b0115c4a7a81"}, - {file = "httptools-0.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:95658c342529bba4e1d3d2b1a874db16c7cca435e8827422154c9da76ac4e13a"}, - {file = "httptools-0.6.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:7ebaec1bf683e4bf5e9fbb49b8cc36da482033596a415b3e4ebab5a4c0d7ec5e"}, - {file = "httptools-0.6.1.tar.gz", hash = "sha256:c6e26c30455600b95d94b1b836085138e82f177351454ee841c148f93a9bad5a"}, -] - -[package.extras] -test = ["Cython (>=0.29.24,<0.30.0)"] + {file = "httptools-0.6.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3c73ce323711a6ffb0d247dcd5a550b8babf0f757e86a52558fe5b86d6fefcc0"}, + {file = "httptools-0.6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:345c288418f0944a6fe67be8e6afa9262b18c7626c3ef3c28adc5eabc06a68da"}, + {file = "httptools-0.6.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deee0e3343f98ee8047e9f4c5bc7cedbf69f5734454a94c38ee829fb2d5fa3c1"}, + {file = "httptools-0.6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca80b7485c76f768a3bc83ea58373f8db7b015551117375e4918e2aa77ea9b50"}, + {file = "httptools-0.6.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:90d96a385fa941283ebd231464045187a31ad932ebfa541be8edf5b3c2328959"}, + {file = "httptools-0.6.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:59e724f8b332319e2875efd360e61ac07f33b492889284a3e05e6d13746876f4"}, + {file = "httptools-0.6.4-cp310-cp310-win_amd64.whl", hash = "sha256:c26f313951f6e26147833fc923f78f95604bbec812a43e5ee37f26dc9e5a686c"}, + {file = "httptools-0.6.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f47f8ed67cc0ff862b84a1189831d1d33c963fb3ce1ee0c65d3b0cbe7b711069"}, + {file = "httptools-0.6.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0614154d5454c21b6410fdf5262b4a3ddb0f53f1e1721cfd59d55f32138c578a"}, + {file = "httptools-0.6.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8787367fbdfccae38e35abf7641dafc5310310a5987b689f4c32cc8cc3ee975"}, + {file = "httptools-0.6.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b0f7fe4fd38e6a507bdb751db0379df1e99120c65fbdc8ee6c1d044897a636"}, + {file = "httptools-0.6.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:40a5ec98d3f49904b9fe36827dcf1aadfef3b89e2bd05b0e35e94f97c2b14721"}, + {file = "httptools-0.6.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dacdd3d10ea1b4ca9df97a0a303cbacafc04b5cd375fa98732678151643d4988"}, + {file = "httptools-0.6.4-cp311-cp311-win_amd64.whl", hash = "sha256:288cd628406cc53f9a541cfaf06041b4c71d751856bab45e3702191f931ccd17"}, + {file = "httptools-0.6.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:df017d6c780287d5c80601dafa31f17bddb170232d85c066604d8558683711a2"}, + {file = "httptools-0.6.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:85071a1e8c2d051b507161f6c3e26155b5c790e4e28d7f236422dbacc2a9cc44"}, + {file = "httptools-0.6.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69422b7f458c5af875922cdb5bd586cc1f1033295aa9ff63ee196a87519ac8e1"}, + {file = "httptools-0.6.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16e603a3bff50db08cd578d54f07032ca1631450ceb972c2f834c2b860c28ea2"}, + {file = "httptools-0.6.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ec4f178901fa1834d4a060320d2f3abc5c9e39766953d038f1458cb885f47e81"}, + {file = "httptools-0.6.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f9eb89ecf8b290f2e293325c646a211ff1c2493222798bb80a530c5e7502494f"}, + {file = "httptools-0.6.4-cp312-cp312-win_amd64.whl", hash = 
"sha256:db78cb9ca56b59b016e64b6031eda5653be0589dba2b1b43453f6e8b405a0970"}, + {file = "httptools-0.6.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ade273d7e767d5fae13fa637f4d53b6e961fb7fd93c7797562663f0171c26660"}, + {file = "httptools-0.6.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:856f4bc0478ae143bad54a4242fccb1f3f86a6e1be5548fecfd4102061b3a083"}, + {file = "httptools-0.6.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:322d20ea9cdd1fa98bd6a74b77e2ec5b818abdc3d36695ab402a0de8ef2865a3"}, + {file = "httptools-0.6.4-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d87b29bd4486c0093fc64dea80231f7c7f7eb4dc70ae394d70a495ab8436071"}, + {file = "httptools-0.6.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:342dd6946aa6bda4b8f18c734576106b8a31f2fe31492881a9a160ec84ff4bd5"}, + {file = "httptools-0.6.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b36913ba52008249223042dca46e69967985fb4051951f94357ea681e1f5dc0"}, + {file = "httptools-0.6.4-cp313-cp313-win_amd64.whl", hash = "sha256:28908df1b9bb8187393d5b5db91435ccc9c8e891657f9cbb42a2541b44c82fc8"}, + {file = "httptools-0.6.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:d3f0d369e7ffbe59c4b6116a44d6a8eb4783aae027f2c0b366cf0aa964185dba"}, + {file = "httptools-0.6.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:94978a49b8f4569ad607cd4946b759d90b285e39c0d4640c6b36ca7a3ddf2efc"}, + {file = "httptools-0.6.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40dc6a8e399e15ea525305a2ddba998b0af5caa2566bcd79dcbe8948181eeaff"}, + {file = "httptools-0.6.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab9ba8dcf59de5181f6be44a77458e45a578fc99c31510b8c65b7d5acc3cf490"}, + {file = "httptools-0.6.4-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:fc411e1c0a7dcd2f902c7c48cf079947a7e65b5485dea9decb82b9105ca71a43"}, + {file = "httptools-0.6.4-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:d54efd20338ac52ba31e7da78e4a72570cf729fac82bc31ff9199bedf1dc7440"}, + {file = "httptools-0.6.4-cp38-cp38-win_amd64.whl", hash = "sha256:df959752a0c2748a65ab5387d08287abf6779ae9165916fe053e68ae1fbdc47f"}, + {file = "httptools-0.6.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:85797e37e8eeaa5439d33e556662cc370e474445d5fab24dcadc65a8ffb04003"}, + {file = "httptools-0.6.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:db353d22843cf1028f43c3651581e4bb49374d85692a85f95f7b9a130e1b2cab"}, + {file = "httptools-0.6.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1ffd262a73d7c28424252381a5b854c19d9de5f56f075445d33919a637e3547"}, + {file = "httptools-0.6.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:703c346571fa50d2e9856a37d7cd9435a25e7fd15e236c397bf224afaa355fe9"}, + {file = "httptools-0.6.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:aafe0f1918ed07b67c1e838f950b1c1fabc683030477e60b335649b8020e1076"}, + {file = "httptools-0.6.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0e563e54979e97b6d13f1bbc05a96109923e76b901f786a5eae36e99c01237bd"}, + {file = "httptools-0.6.4-cp39-cp39-win_amd64.whl", hash = "sha256:b799de31416ecc589ad79dd85a0b2657a8fe39327944998dea368c1d4c9e55e6"}, + {file = "httptools-0.6.4.tar.gz", hash = "sha256:4e93eee4add6493b59a5c514da98c939b244fce4a0d8879cd3f466562f4b7d5c"}, +] + +[package.extras] +test = 
["Cython (>=0.29.24)"] [[package]] name = "httpx" -version = "0.27.0" +version = "0.27.2" description = "The next generation HTTP client." optional = false python-versions = ">=3.8" files = [ - {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, - {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, + {file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"}, + {file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"}, ] [package.dependencies] @@ -3721,6 +4047,7 @@ brotli = ["brotli", "brotlicffi"] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] [[package]] name = "huggingface-hub" @@ -3781,48 +4108,55 @@ files = [ [[package]] name = "idna" -version = "3.7" +version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false -python-versions = ">=3.5" +python-versions = ">=3.6" files = [ - {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, - {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, + {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, + {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, ] +[package.extras] +all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] + [[package]] name = "importlib-metadata" -version = "8.0.0" +version = "6.11.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_metadata-8.0.0-py3-none-any.whl", hash = "sha256:15584cf2b1bf449d98ff8a6ff1abef57bf20f3ac6454f431736cd3e660921b2f"}, - {file = "importlib_metadata-8.0.0.tar.gz", hash = "sha256:188bd24e4c346d3f0a933f275c2fec67050326a856b9a359881d7c2a697e8812"}, + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" -version = "6.4.0" +version = "6.4.5" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = 
"importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, - {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, + {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, + {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] +type = ["pytest-mypy"] [[package]] name = "iniconfig" @@ -3837,18 +4171,15 @@ files = [ [[package]] name = "isodate" -version = "0.6.1" +version = "0.7.2" description = "An ISO 8601 date/time/duration parser and formatter" optional = false -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "isodate-0.6.1-py2.py3-none-any.whl", hash = "sha256:0751eece944162659049d35f4f549ed815792b38793f07cf73381c1c87cbed96"}, - {file = "isodate-0.6.1.tar.gz", hash = "sha256:48c5881de7e8b0a0d648cb024c8062dc84e7b840ed81e864c7614fd3c127bde9"}, + {file = "isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15"}, + {file = "isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6"}, ] -[package.dependencies] -six = "*" - [[package]] name = "itsdangerous" version = "2.2.0" @@ -3897,6 +4228,88 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jiter" +version = "0.6.1" +description = "Fast iterable JSON parser." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.6.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:d08510593cb57296851080018006dfc394070178d238b767b1879dc1013b106c"}, + {file = "jiter-0.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:adef59d5e2394ebbad13b7ed5e0306cceb1df92e2de688824232a91588e77aa7"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3e02f7a27f2bcc15b7d455c9df05df8ffffcc596a2a541eeda9a3110326e7a3"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed69a7971d67b08f152c17c638f0e8c2aa207e9dd3a5fcd3cba294d39b5a8d2d"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2019d966e98f7c6df24b3b8363998575f47d26471bfb14aade37630fae836a1"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:36c0b51a285b68311e207a76c385650322734c8717d16c2eb8af75c9d69506e7"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:220e0963b4fb507c525c8f58cde3da6b1be0bfddb7ffd6798fb8f2531226cdb1"}, + {file = "jiter-0.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:aa25c7a9bf7875a141182b9c95aed487add635da01942ef7ca726e42a0c09058"}, + {file = "jiter-0.6.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e90552109ca8ccd07f47ca99c8a1509ced93920d271bb81780a973279974c5ab"}, + {file = "jiter-0.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:67723a011964971864e0b484b0ecfee6a14de1533cff7ffd71189e92103b38a8"}, + {file = "jiter-0.6.1-cp310-none-win32.whl", hash = "sha256:33af2b7d2bf310fdfec2da0177eab2fedab8679d1538d5b86a633ebfbbac4edd"}, + {file = "jiter-0.6.1-cp310-none-win_amd64.whl", hash = "sha256:7cea41c4c673353799906d940eee8f2d8fd1d9561d734aa921ae0f75cb9732f4"}, + {file = "jiter-0.6.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:b03c24e7da7e75b170c7b2b172d9c5e463aa4b5c95696a368d52c295b3f6847f"}, + {file = "jiter-0.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:47fee1be677b25d0ef79d687e238dc6ac91a8e553e1a68d0839f38c69e0ee491"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25f0d2f6e01a8a0fb0eab6d0e469058dab2be46ff3139ed2d1543475b5a1d8e7"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0b809e39e342c346df454b29bfcc7bca3d957f5d7b60e33dae42b0e5ec13e027"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9ac7c2f092f231f5620bef23ce2e530bd218fc046098747cc390b21b8738a7a"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e51a2d80d5fe0ffb10ed2c82b6004458be4a3f2b9c7d09ed85baa2fbf033f54b"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3343d4706a2b7140e8bd49b6c8b0a82abf9194b3f0f5925a78fc69359f8fc33c"}, + {file = "jiter-0.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:82521000d18c71e41c96960cb36e915a357bc83d63a8bed63154b89d95d05ad1"}, + {file = "jiter-0.6.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3c843e7c1633470708a3987e8ce617ee2979ee18542d6eb25ae92861af3f1d62"}, + {file = "jiter-0.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a2e861658c3fe849efc39b06ebb98d042e4a4c51a8d7d1c3ddc3b1ea091d0784"}, + {file = "jiter-0.6.1-cp311-none-win32.whl", hash = 
"sha256:7d72fc86474862c9c6d1f87b921b70c362f2b7e8b2e3c798bb7d58e419a6bc0f"}, + {file = "jiter-0.6.1-cp311-none-win_amd64.whl", hash = "sha256:3e36a320634f33a07794bb15b8da995dccb94f944d298c8cfe2bd99b1b8a574a"}, + {file = "jiter-0.6.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1fad93654d5a7dcce0809aff66e883c98e2618b86656aeb2129db2cd6f26f867"}, + {file = "jiter-0.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4e6e340e8cd92edab7f6a3a904dbbc8137e7f4b347c49a27da9814015cc0420c"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:691352e5653af84ed71763c3c427cff05e4d658c508172e01e9c956dfe004aba"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:defee3949313c1f5b55e18be45089970cdb936eb2a0063f5020c4185db1b63c9"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26d2bdd5da097e624081c6b5d416d3ee73e5b13f1703bcdadbb1881f0caa1933"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18aa9d1626b61c0734b973ed7088f8a3d690d0b7f5384a5270cd04f4d9f26c86"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a3567c8228afa5ddcce950631c6b17397ed178003dc9ee7e567c4c4dcae9fa0"}, + {file = "jiter-0.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e5c0507131c922defe3f04c527d6838932fcdfd69facebafd7d3574fa3395314"}, + {file = "jiter-0.6.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:540fcb224d7dc1bcf82f90f2ffb652df96f2851c031adca3c8741cb91877143b"}, + {file = "jiter-0.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e7b75436d4fa2032b2530ad989e4cb0ca74c655975e3ff49f91a1a3d7f4e1df2"}, + {file = "jiter-0.6.1-cp312-none-win32.whl", hash = "sha256:883d2ced7c21bf06874fdeecab15014c1c6d82216765ca6deef08e335fa719e0"}, + {file = "jiter-0.6.1-cp312-none-win_amd64.whl", hash = "sha256:91e63273563401aadc6c52cca64a7921c50b29372441adc104127b910e98a5b6"}, + {file = "jiter-0.6.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:852508a54fe3228432e56019da8b69208ea622a3069458252f725d634e955b31"}, + {file = "jiter-0.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f491cc69ff44e5a1e8bc6bf2b94c1f98d179e1aaf4a554493c171a5b2316b701"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc56c8f0b2a28ad4d8047f3ae62d25d0e9ae01b99940ec0283263a04724de1f3"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:51b58f7a0d9e084a43b28b23da2b09fc5e8df6aa2b6a27de43f991293cab85fd"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5f79ce15099154c90ef900d69c6b4c686b64dfe23b0114e0971f2fecd306ec6c"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:03a025b52009f47e53ea619175d17e4ded7c035c6fbd44935cb3ada11e1fd592"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c74a8d93718137c021d9295248a87c2f9fdc0dcafead12d2930bc459ad40f885"}, + {file = "jiter-0.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:40b03b75f903975f68199fc4ec73d546150919cb7e534f3b51e727c4d6ccca5a"}, + {file = "jiter-0.6.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:825651a3f04cf92a661d22cad61fc913400e33aa89b3e3ad9a6aa9dc8a1f5a71"}, + {file = "jiter-0.6.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = 
"sha256:928bf25eb69ddb292ab8177fe69d3fbf76c7feab5fce1c09265a7dccf25d3991"}, + {file = "jiter-0.6.1-cp313-none-win32.whl", hash = "sha256:352cd24121e80d3d053fab1cc9806258cad27c53cad99b7a3cac57cf934b12e4"}, + {file = "jiter-0.6.1-cp313-none-win_amd64.whl", hash = "sha256:be7503dd6f4bf02c2a9bacb5cc9335bc59132e7eee9d3e931b13d76fd80d7fda"}, + {file = "jiter-0.6.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:31d8e00e1fb4c277df8ab6f31a671f509ebc791a80e5c61fdc6bc8696aaa297c"}, + {file = "jiter-0.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:77c296d65003cd7ee5d7b0965f6acbe6cffaf9d1fa420ea751f60ef24e85fed5"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aeeb0c0325ef96c12a48ea7e23e2e86fe4838e6e0a995f464cf4c79fa791ceeb"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a31c6fcbe7d6c25d6f1cc6bb1cba576251d32795d09c09961174fe461a1fb5bd"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59e2b37f3b9401fc9e619f4d4badcab2e8643a721838bcf695c2318a0475ae42"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bae5ae4853cb9644144e9d0755854ce5108d470d31541d83f70ca7ecdc2d1637"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9df588e9c830b72d8db1dd7d0175af6706b0904f682ea9b1ca8b46028e54d6e9"}, + {file = "jiter-0.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:15f8395e835cf561c85c1adee72d899abf2733d9df72e9798e6d667c9b5c1f30"}, + {file = "jiter-0.6.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a99d4e0b5fc3b05ea732d67eb2092fe894e95a90e6e413f2ea91387e228a307"}, + {file = "jiter-0.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a311df1fa6be0ccd64c12abcd85458383d96e542531bafbfc0a16ff6feda588f"}, + {file = "jiter-0.6.1-cp38-none-win32.whl", hash = "sha256:81116a6c272a11347b199f0e16b6bd63f4c9d9b52bc108991397dd80d3c78aba"}, + {file = "jiter-0.6.1-cp38-none-win_amd64.whl", hash = "sha256:13f9084e3e871a7c0b6e710db54444088b1dd9fbefa54d449b630d5e73bb95d0"}, + {file = "jiter-0.6.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:f1c53615fcfec3b11527c08d19cff6bc870da567ce4e57676c059a3102d3a082"}, + {file = "jiter-0.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f791b6a4da23238c17a81f44f5b55d08a420c5692c1fda84e301a4b036744eb1"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c97e90fec2da1d5f68ef121444c2c4fa72eabf3240829ad95cf6bbeca42a301"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3cbc1a66b4e41511209e97a2866898733c0110b7245791ac604117b7fb3fedb7"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e4e85f9e12cd8418ab10e1fcf0e335ae5bb3da26c4d13a0fd9e6a17a674783b6"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08be33db6dcc374c9cc19d3633af5e47961a7b10d4c61710bd39e48d52a35824"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:677be9550004f5e010d673d3b2a2b815a8ea07a71484a57d3f85dde7f14cf132"}, + {file = "jiter-0.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e8bd065be46c2eecc328e419d6557bbc37844c88bb07b7a8d2d6c91c7c4dedc9"}, + {file = "jiter-0.6.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = 
"sha256:bd95375ce3609ec079a97c5d165afdd25693302c071ca60c7ae1cf826eb32022"}, + {file = "jiter-0.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db459ed22d0208940d87f614e1f0ea5a946d29a3cfef71f7e1aab59b6c6b2afb"}, + {file = "jiter-0.6.1-cp39-none-win32.whl", hash = "sha256:d71c962f0971347bd552940ab96aa42ceefcd51b88c4ced8a27398182efa8d80"}, + {file = "jiter-0.6.1-cp39-none-win_amd64.whl", hash = "sha256:d465db62d2d10b489b7e7a33027c4ae3a64374425d757e963f86df5b5f2e7fc5"}, + {file = "jiter-0.6.1.tar.gz", hash = "sha256:e19cd21221fc139fb032e4112986656cb2739e9fe6d84c13956ab30ccc7d4449"}, +] + [[package]] name = "jmespath" version = "0.10.0" @@ -3919,6 +4332,20 @@ files = [ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, ] +[[package]] +name = "jsonlines" +version = "4.0.0" +description = "Library with helpers for the jsonlines file format" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonlines-4.0.0-py3-none-any.whl", hash = "sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55"}, + {file = "jsonlines-4.0.0.tar.gz", hash = "sha256:0c6d2c09117550c089995247f605ae4cf77dd1533041d366351f6f298822ea74"}, +] + +[package.dependencies] +attrs = ">=19.2.0" + [[package]] name = "jsonpath-ng" version = "1.6.1" @@ -3933,6 +4360,52 @@ files = [ [package.dependencies] ply = "*" +[[package]] +name = "jsonpath-python" +version = "1.0.6" +description = "A more powerful JSONPath implementation in modern python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "jsonpath-python-1.0.6.tar.gz", hash = "sha256:dd5be4a72d8a2995c3f583cf82bf3cd1a9544cfdabf2d22595b67aff07349666"}, + {file = "jsonpath_python-1.0.6-py3-none-any.whl", hash = "sha256:1e3b78df579f5efc23565293612decee04214609208a2335884b3ee3f786b575"}, +] + +[[package]] +name = "jsonschema" +version = "4.23.0" +description = "An implementation of JSON Schema validation for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566"}, + {file = "jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +jsonschema-specifications = ">=2023.03.6" +referencing = ">=0.28.4" +rpds-py = ">=0.7.1" + +[package.extras] +format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"] +format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=24.6.0)"] + +[[package]] +name = "jsonschema-specifications" +version = "2024.10.1" +description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +optional = false +python-versions = ">=3.9" +files = [ + {file = "jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf"}, + {file = "jsonschema_specifications-2024.10.1.tar.gz", hash = "sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272"}, +] + +[package.dependencies] +referencing = ">=0.31.0" + [[package]] name = "kaleido" version = "0.2.1" @@ -3950,130 +4423,141 @@ files = [ [[package]] name = "kiwisolver" -version = "1.4.5" +version = "1.4.7" description = "A fast implementation of the Cassowary constraint solver" optional = false -python-versions 
= ">=3.7" +python-versions = ">=3.8" files = [ - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:05703cf211d585109fcd72207a31bb170a0f22144d68298dc5e61b3c946518af"}, - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:146d14bebb7f1dc4d5fbf74f8a6cb15ac42baadee8912eb84ac0b3b2a3dc6ac3"}, - {file = "kiwisolver-1.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ef7afcd2d281494c0a9101d5c571970708ad911d028137cd558f02b851c08b4"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9eaa8b117dc8337728e834b9c6e2611f10c79e38f65157c4c38e9400286f5cb1"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec20916e7b4cbfb1f12380e46486ec4bcbaa91a9c448b97023fde0d5bbf9e4ff"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b42c68602539407884cf70d6a480a469b93b81b7701378ba5e2328660c847a"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa12042de0171fad672b6c59df69106d20d5596e4f87b5e8f76df757a7c399aa"}, - {file = "kiwisolver-1.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a40773c71d7ccdd3798f6489aaac9eee213d566850a9533f8d26332d626b82c"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:19df6e621f6d8b4b9c4d45f40a66839294ff2bb235e64d2178f7522d9170ac5b"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:83d78376d0d4fd884e2c114d0621624b73d2aba4e2788182d286309ebdeed770"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e391b1f0a8a5a10ab3b9bb6afcfd74f2175f24f8975fb87ecae700d1503cdee0"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:852542f9481f4a62dbb5dd99e8ab7aedfeb8fb6342349a181d4036877410f525"}, - {file = "kiwisolver-1.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:59edc41b24031bc25108e210c0def6f6c2191210492a972d585a06ff246bb79b"}, - {file = "kiwisolver-1.4.5-cp310-cp310-win32.whl", hash = "sha256:a6aa6315319a052b4ee378aa171959c898a6183f15c1e541821c5c59beaa0238"}, - {file = "kiwisolver-1.4.5-cp310-cp310-win_amd64.whl", hash = "sha256:d0ef46024e6a3d79c01ff13801cb19d0cad7fd859b15037aec74315540acc276"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:11863aa14a51fd6ec28688d76f1735f8f69ab1fabf388851a595d0721af042f5"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8ab3919a9997ab7ef2fbbed0cc99bb28d3c13e6d4b1ad36e97e482558a91be90"}, - {file = "kiwisolver-1.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fcc700eadbbccbf6bc1bcb9dbe0786b4b1cb91ca0dcda336eef5c2beed37b797"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfdd7c0b105af050eb3d64997809dc21da247cf44e63dc73ff0fd20b96be55a9"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c6a5964640638cdeaa0c359382e5703e9293030fe730018ca06bc2010c4437"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbea0db94288e29afcc4c28afbf3a7ccaf2d7e027489c449cf7e8f83c6346eb9"}, - {file = "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ceec1a6bc6cab1d6ff5d06592a91a692f90ec7505d6463a88a52cc0eb58545da"}, - {file 
= "kiwisolver-1.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:040c1aebeda72197ef477a906782b5ab0d387642e93bda547336b8957c61022e"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f91de7223d4c7b793867797bacd1ee53bfe7359bd70d27b7b58a04efbb9436c8"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:faae4860798c31530dd184046a900e652c95513796ef51a12bc086710c2eec4d"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:b0157420efcb803e71d1b28e2c287518b8808b7cf1ab8af36718fd0a2c453eb0"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:06f54715b7737c2fecdbf140d1afb11a33d59508a47bf11bb38ecf21dc9ab79f"}, - {file = "kiwisolver-1.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fdb7adb641a0d13bdcd4ef48e062363d8a9ad4a182ac7647ec88f695e719ae9f"}, - {file = "kiwisolver-1.4.5-cp311-cp311-win32.whl", hash = "sha256:bb86433b1cfe686da83ce32a9d3a8dd308e85c76b60896d58f082136f10bffac"}, - {file = "kiwisolver-1.4.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c08e1312a9cf1074d17b17728d3dfce2a5125b2d791527f33ffbe805200a355"}, - {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:32d5cf40c4f7c7b3ca500f8985eb3fb3a7dfc023215e876f207956b5ea26632a"}, - {file = "kiwisolver-1.4.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f846c260f483d1fd217fe5ed7c173fb109efa6b1fc8381c8b7552c5781756192"}, - {file = "kiwisolver-1.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ff5cf3571589b6d13bfbfd6bcd7a3f659e42f96b5fd1c4830c4cf21d4f5ef45"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7269d9e5f1084a653d575c7ec012ff57f0c042258bf5db0954bf551c158466e7"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da802a19d6e15dffe4b0c24b38b3af68e6c1a68e6e1d8f30148c83864f3881db"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3aba7311af82e335dd1e36ffff68aaca609ca6290c2cb6d821a39aa075d8e3ff"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763773d53f07244148ccac5b084da5adb90bfaee39c197554f01b286cf869228"}, - {file = "kiwisolver-1.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2270953c0d8cdab5d422bee7d2007f043473f9d2999631c86a223c9db56cbd16"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d099e745a512f7e3bbe7249ca835f4d357c586d78d79ae8f1dcd4d8adeb9bda9"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:74db36e14a7d1ce0986fa104f7d5637aea5c82ca6326ed0ec5694280942d1162"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e5bab140c309cb3a6ce373a9e71eb7e4873c70c2dda01df6820474f9889d6d4"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0f114aa76dc1b8f636d077979c0ac22e7cd8f3493abbab152f20eb8d3cda71f3"}, - {file = "kiwisolver-1.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:88a2df29d4724b9237fc0c6eaf2a1adae0cdc0b3e9f4d8e7dc54b16812d2d81a"}, - {file = "kiwisolver-1.4.5-cp312-cp312-win32.whl", hash = "sha256:72d40b33e834371fd330fb1472ca19d9b8327acb79a5821d4008391db8e29f20"}, - {file = "kiwisolver-1.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:2c5674c4e74d939b9d91dda0fae10597ac7521768fec9e399c70a1f27e2ea2d9"}, 
- {file = "kiwisolver-1.4.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3a2b053a0ab7a3960c98725cfb0bf5b48ba82f64ec95fe06f1d06c99b552e130"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd32d6c13807e5c66a7cbb79f90b553642f296ae4518a60d8d76243b0ad2898"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59ec7b7c7e1a61061850d53aaf8e93db63dce0c936db1fda2658b70e4a1be709"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da4cfb373035def307905d05041c1d06d8936452fe89d464743ae7fb8371078b"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2400873bccc260b6ae184b2b8a4fec0e4082d30648eadb7c3d9a13405d861e89"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:1b04139c4236a0f3aff534479b58f6f849a8b351e1314826c2d230849ed48985"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:4e66e81a5779b65ac21764c295087de82235597a2293d18d943f8e9e32746265"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7931d8f1f67c4be9ba1dd9c451fb0eeca1a25b89e4d3f89e828fe12a519b782a"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:b3f7e75f3015df442238cca659f8baa5f42ce2a8582727981cbfa15fee0ee205"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:bbf1d63eef84b2e8c89011b7f2235b1e0bf7dacc11cac9431fc6468e99ac77fb"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4c380469bd3f970ef677bf2bcba2b6b0b4d5c75e7a020fb863ef75084efad66f"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-win32.whl", hash = "sha256:9408acf3270c4b6baad483865191e3e582b638b1654a007c62e3efe96f09a9a3"}, - {file = "kiwisolver-1.4.5-cp37-cp37m-win_amd64.whl", hash = "sha256:5b94529f9b2591b7af5f3e0e730a4e0a41ea174af35a4fd067775f9bdfeee01a"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:11c7de8f692fc99816e8ac50d1d1aef4f75126eefc33ac79aac02c099fd3db71"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:53abb58632235cd154176ced1ae8f0d29a6657aa1aa9decf50b899b755bc2b93"}, - {file = "kiwisolver-1.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:88b9f257ca61b838b6f8094a62418421f87ac2a1069f7e896c36a7d86b5d4c29"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3195782b26fc03aa9c6913d5bad5aeb864bdc372924c093b0f1cebad603dd712"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc579bf0f502e54926519451b920e875f433aceb4624a3646b3252b5caa9e0b6"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a580c91d686376f0f7c295357595c5a026e6cbc3d77b7c36e290201e7c11ecb"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cfe6ab8da05c01ba6fbea630377b5da2cd9bcbc6338510116b01c1bc939a2c18"}, - {file = "kiwisolver-1.4.5-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d2e5a98f0ec99beb3c10e13b387f8db39106d53993f498b295f0c914328b1333"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a51a263952b1429e429ff236d2f5a21c5125437861baeed77f5e1cc2d2c7c6da"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_i686.whl", hash = 
"sha256:3edd2fa14e68c9be82c5b16689e8d63d89fe927e56debd6e1dbce7a26a17f81b"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:74d1b44c6cfc897df648cc9fdaa09bc3e7679926e6f96df05775d4fb3946571c"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:76d9289ed3f7501012e05abb8358bbb129149dbd173f1f57a1bf1c22d19ab7cc"}, - {file = "kiwisolver-1.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:92dea1ffe3714fa8eb6a314d2b3c773208d865a0e0d35e713ec54eea08a66250"}, - {file = "kiwisolver-1.4.5-cp38-cp38-win32.whl", hash = "sha256:5c90ae8c8d32e472be041e76f9d2f2dbff4d0b0be8bd4041770eddb18cf49a4e"}, - {file = "kiwisolver-1.4.5-cp38-cp38-win_amd64.whl", hash = "sha256:c7940c1dc63eb37a67721b10d703247552416f719c4188c54e04334321351ced"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:9407b6a5f0d675e8a827ad8742e1d6b49d9c1a1da5d952a67d50ef5f4170b18d"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15568384086b6df3c65353820a4473575dbad192e35010f622c6ce3eebd57af9"}, - {file = "kiwisolver-1.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0dc9db8e79f0036e8173c466d21ef18e1befc02de8bf8aa8dc0813a6dc8a7046"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cdc8a402aaee9a798b50d8b827d7ecf75edc5fb35ea0f91f213ff927c15f4ff0"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:6c3bd3cde54cafb87d74d8db50b909705c62b17c2099b8f2e25b461882e544ff"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:955e8513d07a283056b1396e9a57ceddbd272d9252c14f154d450d227606eb54"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:346f5343b9e3f00b8db8ba359350eb124b98c99efd0b408728ac6ebf38173958"}, - {file = "kiwisolver-1.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9098e0049e88c6a24ff64545cdfc50807818ba6c1b739cae221bbbcbc58aad3"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:00bd361b903dc4bbf4eb165f24d1acbee754fce22ded24c3d56eec268658a5cf"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7b8b454bac16428b22560d0a1cf0a09875339cab69df61d7805bf48919415901"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f1d072c2eb0ad60d4c183f3fb44ac6f73fb7a8f16a2694a91f988275cbf352f9"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:31a82d498054cac9f6d0b53d02bb85811185bcb477d4b60144f915f3b3126342"}, - {file = "kiwisolver-1.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6512cb89e334e4700febbffaaa52761b65b4f5a3cf33f960213d5656cea36a77"}, - {file = "kiwisolver-1.4.5-cp39-cp39-win32.whl", hash = "sha256:9db8ea4c388fdb0f780fe91346fd438657ea602d58348753d9fb265ce1bca67f"}, - {file = "kiwisolver-1.4.5-cp39-cp39-win_amd64.whl", hash = "sha256:59415f46a37f7f2efeec758353dd2eae1b07640d8ca0f0c42548ec4125492635"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5c7b3b3a728dc6faf3fc372ef24f21d1e3cee2ac3e9596691d746e5a536de920"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:620ced262a86244e2be10a676b646f29c34537d0d9cc8eb26c08f53d98013390"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = 
"sha256:378a214a1e3bbf5ac4a8708304318b4f890da88c9e6a07699c4ae7174c09a68d"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaf7be1207676ac608a50cd08f102f6742dbfc70e8d60c4db1c6897f62f71523"}, - {file = "kiwisolver-1.4.5-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:ba55dce0a9b8ff59495ddd050a0225d58bd0983d09f87cfe2b6aec4f2c1234e4"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:fd32ea360bcbb92d28933fc05ed09bffcb1704ba3fc7942e81db0fd4f81a7892"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5e7139af55d1688f8b960ee9ad5adafc4ac17c1c473fe07133ac092310d76544"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dced8146011d2bc2e883f9bd68618b8247387f4bbec46d7392b3c3b032640126"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9bf3325c47b11b2e51bca0824ea217c7cd84491d8ac4eefd1e409705ef092bd"}, - {file = "kiwisolver-1.4.5-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5794cf59533bc3f1b1c821f7206a3617999db9fbefc345360aafe2e067514929"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e368f200bbc2e4f905b8e71eb38b3c04333bddaa6a2464a6355487b02bb7fb09"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d706eba36b4c4d5bc6c6377bb6568098765e990cfc21ee16d13963fab7b3e7"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85267bd1aa8880a9c88a8cb71e18d3d64d2751a790e6ca6c27b8ccc724bcd5ad"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210ef2c3a1f03272649aff1ef992df2e724748918c4bc2d5a90352849eb40bea"}, - {file = "kiwisolver-1.4.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:11d011a7574eb3b82bcc9c1a1d35c1d7075677fdd15de527d91b46bd35e935ee"}, - {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407"}, + {file = 
"kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win32.whl", hash = "sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_amd64.whl", hash = "sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_arm64.whl", hash = "sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win32.whl", hash = "sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_amd64.whl", hash = "sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b"}, + {file = 
"kiwisolver-1.4.7-cp311-cp311-win_arm64.whl", hash = "sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win32.whl", hash = "sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_amd64.whl", hash = "sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_arm64.whl", hash = "sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379"}, + {file = 
"kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win32.whl", hash = "sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_amd64.whl", hash = "sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_arm64.whl", hash = "sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d"}, + {file = 
"kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win32.whl", hash = "sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win_amd64.whl", hash = "sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win32.whl", hash = "sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_amd64.whl", hash = "sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_arm64.whl", hash = "sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2"}, + 
{file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0"}, + {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, ] [[package]] name = "kombu" -version = "5.4.0" +version = "5.4.2" description = "Messaging library for Python." 
optional = false python-versions = ">=3.8" files = [ - {file = "kombu-5.4.0-py3-none-any.whl", hash = "sha256:c8dd99820467610b4febbc7a9e8a0d3d7da2d35116b67184418b51cc520ea6b6"}, - {file = "kombu-5.4.0.tar.gz", hash = "sha256:ad200a8dbdaaa2bbc5f26d2ee7d707d9a1fded353a0f4bd751ce8c7d9f449c60"}, + {file = "kombu-5.4.2-py3-none-any.whl", hash = "sha256:14212f5ccf022fc0a70453bb025a1dcc32782a588c49ea866884047d66e14763"}, + {file = "kombu-5.4.2.tar.gz", hash = "sha256:eef572dd2fd9fc614b37580e3caeafdd5af46c1eff31e7fba89138cdb406f2cf"}, ] [package.dependencies] amqp = ">=5.1.1,<6.0.0" +tzdata = {version = "*", markers = "python_version >= \"3.9\""} vine = "5.1.0" [package.extras] @@ -4083,7 +4567,7 @@ confluentkafka = ["confluent-kafka (>=2.2.0)"] consul = ["python-consul2 (==0.1.5)"] librabbitmq = ["librabbitmq (>=2.0.0)"] mongodb = ["pymongo (>=4.1.1)"] -msgpack = ["msgpack (==1.0.8)"] +msgpack = ["msgpack (==1.1.0)"] pyro = ["pyro4 (==4.82)"] qpid = ["qpid-python (>=0.26)", "qpid-tools (>=0.26)"] redis = ["redis (>=4.5.2,!=4.5.5,!=5.0.2)"] @@ -4095,17 +4579,18 @@ zookeeper = ["kazoo (>=2.8.0)"] [[package]] name = "kubernetes" -version = "30.1.0" +version = "31.0.0" description = "Kubernetes python client" optional = false python-versions = ">=3.6" files = [ - {file = "kubernetes-30.1.0-py2.py3-none-any.whl", hash = "sha256:e212e8b7579031dd2e512168b617373bc1e03888d41ac4e04039240a292d478d"}, - {file = "kubernetes-30.1.0.tar.gz", hash = "sha256:41e4c77af9f28e7a6c314e3bd06a8c6229ddd787cad684e0ab9f69b498e98ebc"}, + {file = "kubernetes-31.0.0-py2.py3-none-any.whl", hash = "sha256:bf141e2d380c8520eada8b351f4e319ffee9636328c137aa432bc486ca1200e1"}, + {file = "kubernetes-31.0.0.tar.gz", hash = "sha256:28945de906c8c259c1ebe62703b56a03b714049372196f854105afe4e6d014c0"}, ] [package.dependencies] certifi = ">=14.05.14" +durationpy = ">=0.7" google-auth = ">=1.0.1" oauthlib = ">=3.2.2" python-dateutil = ">=2.5.3" @@ -4135,13 +4620,13 @@ six = "*" [[package]] name = "langfuse" -version = "2.42.1" +version = "2.51.5" description = "A client library for accessing langfuse" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langfuse-2.42.1-py3-none-any.whl", hash = "sha256:8895d9645aea91815db51565f90e110a76d5e157a7b12eaf1cd6959e7aaa2263"}, - {file = "langfuse-2.42.1.tar.gz", hash = "sha256:f89faf1c14308d488c90f8b7d0368fff3d259f80ffe34d169b9cfc3f0dbfab82"}, + {file = "langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb"}, + {file = "langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b"}, ] [package.dependencies] @@ -4149,7 +4634,7 @@ anyio = ">=4.4.0,<5.0.0" backoff = ">=1.10.0" httpx = ">=0.15.4,<1.0" idna = ">=3.7,<4.0" -packaging = ">=23.2,<24.0" +packaging = ">=23.2,<25.0" pydantic = ">=1.10.7,<3.0" wrapt = ">=1.14,<2.0" @@ -4160,22 +4645,24 @@ openai = ["openai (>=0.27.8)"] [[package]] name = "langsmith" -version = "0.1.98" +version = "0.1.138" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.98-py3-none-any.whl", hash = "sha256:f79e8a128652bbcee4606d10acb6236973b5cd7dde76e3741186d3b97b5698e9"}, - {file = "langsmith-0.1.98.tar.gz", hash = "sha256:e07678219a0502e8f26d35294e72127a39d25e32fafd091af5a7bb661e9a6bd1"}, + {file = "langsmith-0.1.138-py3-none-any.whl", hash = "sha256:5c2bd5c11c75f7b3d06a0f06b115186e7326ca969fd26d66ffc65a0669012aee"}, + {file = "langsmith-0.1.138.tar.gz", hash = "sha256:1ecf613bb52f6bf17f1510e24ad8b70d4b0259bc9d3dbfd69b648c66d4644f0b"}, ] [package.dependencies] +httpx = ">=0.23.0,<1" orjson = ">=3.9.14,<4.0.0" pydantic = [ {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, ] requests = ">=2,<3" +requests-toolbelt = ">=1.0.0,<2.0.0" [[package]] name = "llvmlite" @@ -4207,155 +4694,169 @@ files = [ {file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"}, ] +[[package]] +name = "loguru" +version = "0.7.2" +description = "Python logging made (stupidly) simple" +optional = false +python-versions = ">=3.5" +files = [ + {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, + {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, +] + +[package.dependencies] +colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} +win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} + +[package.extras] +dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] + [[package]] name = "lxml" -version = "5.2.2" +version = "5.3.0" description = "Powerful and Pythonic XML processing library combining libxml2/libxslt with the ElementTree API." 
optional = false python-versions = ">=3.6" files = [ - {file = "lxml-5.2.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:364d03207f3e603922d0d3932ef363d55bbf48e3647395765f9bfcbdf6d23632"}, - {file = "lxml-5.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:50127c186f191b8917ea2fb8b206fbebe87fd414a6084d15568c27d0a21d60db"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74e4f025ef3db1c6da4460dd27c118d8cd136d0391da4e387a15e48e5c975147"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:981a06a3076997adf7c743dcd0d7a0415582661e2517c7d961493572e909aa1d"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aef5474d913d3b05e613906ba4090433c515e13ea49c837aca18bde190853dff"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1e275ea572389e41e8b039ac076a46cb87ee6b8542df3fff26f5baab43713bca"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5b65529bb2f21ac7861a0e94fdbf5dc0daab41497d18223b46ee8515e5ad297"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bcc98f911f10278d1daf14b87d65325851a1d29153caaf146877ec37031d5f36"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_28_ppc64le.whl", hash = "sha256:b47633251727c8fe279f34025844b3b3a3e40cd1b198356d003aa146258d13a2"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_28_s390x.whl", hash = "sha256:fbc9d316552f9ef7bba39f4edfad4a734d3d6f93341232a9dddadec4f15d425f"}, - {file = "lxml-5.2.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:13e69be35391ce72712184f69000cda04fc89689429179bc4c0ae5f0b7a8c21b"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3b6a30a9ab040b3f545b697cb3adbf3696c05a3a68aad172e3fd7ca73ab3c835"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:a233bb68625a85126ac9f1fc66d24337d6e8a0f9207b688eec2e7c880f012ec0"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:dfa7c241073d8f2b8e8dbc7803c434f57dbb83ae2a3d7892dd068d99e96efe2c"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1a7aca7964ac4bb07680d5c9d63b9d7028cace3e2d43175cb50bba8c5ad33316"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ae4073a60ab98529ab8a72ebf429f2a8cc612619a8c04e08bed27450d52103c0"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:ffb2be176fed4457e445fe540617f0252a72a8bc56208fd65a690fdb1f57660b"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e290d79a4107d7d794634ce3e985b9ae4f920380a813717adf61804904dc4393"}, - {file = "lxml-5.2.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:96e85aa09274955bb6bd483eaf5b12abadade01010478154b0ec70284c1b1526"}, - {file = "lxml-5.2.2-cp310-cp310-win32.whl", hash = "sha256:f956196ef61369f1685d14dad80611488d8dc1ef00be57c0c5a03064005b0f30"}, - {file = "lxml-5.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:875a3f90d7eb5c5d77e529080d95140eacb3c6d13ad5b616ee8095447b1d22e7"}, - {file = "lxml-5.2.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:45f9494613160d0405682f9eee781c7e6d1bf45f819654eb249f8f46a2c22545"}, - {file = "lxml-5.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b0b3f2df149efb242cee2ffdeb6674b7f30d23c9a7af26595099afaf46ef4e88"}, - {file = 
"lxml-5.2.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d28cb356f119a437cc58a13f8135ab8a4c8ece18159eb9194b0d269ec4e28083"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:657a972f46bbefdbba2d4f14413c0d079f9ae243bd68193cb5061b9732fa54c1"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b9ea10063efb77a965a8d5f4182806fbf59ed068b3c3fd6f30d2ac7bee734"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:07542787f86112d46d07d4f3c4e7c760282011b354d012dc4141cc12a68cef5f"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:303f540ad2dddd35b92415b74b900c749ec2010e703ab3bfd6660979d01fd4ed"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2eb2227ce1ff998faf0cd7fe85bbf086aa41dfc5af3b1d80867ecfe75fb68df3"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_28_ppc64le.whl", hash = "sha256:1d8a701774dfc42a2f0b8ccdfe7dbc140500d1049e0632a611985d943fcf12df"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_28_s390x.whl", hash = "sha256:56793b7a1a091a7c286b5f4aa1fe4ae5d1446fe742d00cdf2ffb1077865db10d"}, - {file = "lxml-5.2.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:eb00b549b13bd6d884c863554566095bf6fa9c3cecb2e7b399c4bc7904cb33b5"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1a2569a1f15ae6c8c64108a2cd2b4a858fc1e13d25846be0666fc144715e32ab"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:8cf85a6e40ff1f37fe0f25719aadf443686b1ac7652593dc53c7ef9b8492b115"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:d237ba6664b8e60fd90b8549a149a74fcc675272e0e95539a00522e4ca688b04"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0b3f5016e00ae7630a4b83d0868fca1e3d494c78a75b1c7252606a3a1c5fc2ad"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:23441e2b5339bc54dc949e9e675fa35efe858108404ef9aa92f0456929ef6fe8"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2fb0ba3e8566548d6c8e7dd82a8229ff47bd8fb8c2da237607ac8e5a1b8312e5"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:79d1fb9252e7e2cfe4de6e9a6610c7cbb99b9708e2c3e29057f487de5a9eaefa"}, - {file = "lxml-5.2.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6dcc3d17eac1df7859ae01202e9bb11ffa8c98949dcbeb1069c8b9a75917e01b"}, - {file = "lxml-5.2.2-cp311-cp311-win32.whl", hash = "sha256:4c30a2f83677876465f44c018830f608fa3c6a8a466eb223535035fbc16f3438"}, - {file = "lxml-5.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:49095a38eb333aaf44c06052fd2ec3b8f23e19747ca7ec6f6c954ffea6dbf7be"}, - {file = "lxml-5.2.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:7429e7faa1a60cad26ae4227f4dd0459efde239e494c7312624ce228e04f6391"}, - {file = "lxml-5.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:50ccb5d355961c0f12f6cf24b7187dbabd5433f29e15147a67995474f27d1776"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc911208b18842a3a57266d8e51fc3cfaccee90a5351b92079beed912a7914c2"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33ce9e786753743159799fdf8e92a5da351158c4bfb6f2db0bf31e7892a1feb5"}, - {file = 
"lxml-5.2.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec87c44f619380878bd49ca109669c9f221d9ae6883a5bcb3616785fa8f94c97"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08ea0f606808354eb8f2dfaac095963cb25d9d28e27edcc375d7b30ab01abbf6"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75a9632f1d4f698b2e6e2e1ada40e71f369b15d69baddb8968dcc8e683839b18"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:74da9f97daec6928567b48c90ea2c82a106b2d500f397eeb8941e47d30b1ca85"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:0969e92af09c5687d769731e3f39ed62427cc72176cebb54b7a9d52cc4fa3b73"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:9164361769b6ca7769079f4d426a41df6164879f7f3568be9086e15baca61466"}, - {file = "lxml-5.2.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:d26a618ae1766279f2660aca0081b2220aca6bd1aa06b2cf73f07383faf48927"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab67ed772c584b7ef2379797bf14b82df9aa5f7438c5b9a09624dd834c1c1aaf"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:3d1e35572a56941b32c239774d7e9ad724074d37f90c7a7d499ab98761bd80cf"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:8268cbcd48c5375f46e000adb1390572c98879eb4f77910c6053d25cc3ac2c67"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e282aedd63c639c07c3857097fc0e236f984ceb4089a8b284da1c526491e3f3d"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfdc2bfe69e9adf0df4915949c22a25b39d175d599bf98e7ddf620a13678585"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4aefd911793b5d2d7a921233a54c90329bf3d4a6817dc465f12ffdfe4fc7b8fe"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:8b8df03a9e995b6211dafa63b32f9d405881518ff1ddd775db4e7b98fb545e1c"}, - {file = "lxml-5.2.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f11ae142f3a322d44513de1018b50f474f8f736bc3cd91d969f464b5bfef8836"}, - {file = "lxml-5.2.2-cp312-cp312-win32.whl", hash = "sha256:16a8326e51fcdffc886294c1e70b11ddccec836516a343f9ed0f82aac043c24a"}, - {file = "lxml-5.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:bbc4b80af581e18568ff07f6395c02114d05f4865c2812a1f02f2eaecf0bfd48"}, - {file = "lxml-5.2.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e3d9d13603410b72787579769469af730c38f2f25505573a5888a94b62b920f8"}, - {file = "lxml-5.2.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:38b67afb0a06b8575948641c1d6d68e41b83a3abeae2ca9eed2ac59892b36706"}, - {file = "lxml-5.2.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c689d0d5381f56de7bd6966a4541bff6e08bf8d3871bbd89a0c6ab18aa699573"}, - {file = "lxml-5.2.2-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:cf2a978c795b54c539f47964ec05e35c05bd045db5ca1e8366988c7f2fe6b3ce"}, - {file = "lxml-5.2.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:739e36ef7412b2bd940f75b278749106e6d025e40027c0b94a17ef7968d55d56"}, - {file = "lxml-5.2.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d8bbcd21769594dbba9c37d3c819e2d5847656ca99c747ddb31ac1701d0c0ed9"}, - {file = "lxml-5.2.2-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = 
"sha256:2304d3c93f2258ccf2cf7a6ba8c761d76ef84948d87bf9664e14d203da2cd264"}, - {file = "lxml-5.2.2-cp36-cp36m-win32.whl", hash = "sha256:02437fb7308386867c8b7b0e5bc4cd4b04548b1c5d089ffb8e7b31009b961dc3"}, - {file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"}, - {file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"}, - {file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"}, - {file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"}, - {file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"}, - {file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"}, - {file = "lxml-5.2.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7ed07b3062b055d7a7f9d6557a251cc655eed0b3152b76de619516621c56f5d3"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f60fdd125d85bf9c279ffb8e94c78c51b3b6a37711464e1f5f31078b45002421"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a7e24cb69ee5f32e003f50e016d5fde438010c1022c96738b04fc2423e61706"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23cfafd56887eaed93d07bc4547abd5e09d837a002b791e9767765492a75883f"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:19b4e485cd07b7d83e3fe3b72132e7df70bfac22b14fe4bf7a23822c3a35bff5"}, - {file = "lxml-5.2.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7ce7ad8abebe737ad6143d9d3bf94b88b93365ea30a5b81f6877ec9c0dee0a48"}, - {file = "lxml-5.2.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e49b052b768bb74f58c7dda4e0bdf7b79d43a9204ca584ffe1fb48a6f3c84c66"}, - {file = "lxml-5.2.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d14a0d029a4e176795cef99c056d58067c06195e0c7e2dbb293bf95c08f772a3"}, - {file = "lxml-5.2.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:be49ad33819d7dcc28a309b86d4ed98e1a65f3075c6acd3cd4fe32103235222b"}, - {file = 
"lxml-5.2.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:a6d17e0370d2516d5bb9062c7b4cb731cff921fc875644c3d751ad857ba9c5b1"}, - {file = "lxml-5.2.2-cp38-cp38-win32.whl", hash = "sha256:5b8c041b6265e08eac8a724b74b655404070b636a8dd6d7a13c3adc07882ef30"}, - {file = "lxml-5.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:f61efaf4bed1cc0860e567d2ecb2363974d414f7f1f124b1df368bbf183453a6"}, - {file = "lxml-5.2.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:fb91819461b1b56d06fa4bcf86617fac795f6a99d12239fb0c68dbeba41a0a30"}, - {file = "lxml-5.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d4ed0c7cbecde7194cd3228c044e86bf73e30a23505af852857c09c24e77ec5d"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54401c77a63cc7d6dc4b4e173bb484f28a5607f3df71484709fe037c92d4f0ed"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:625e3ef310e7fa3a761d48ca7ea1f9d8718a32b1542e727d584d82f4453d5eeb"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:519895c99c815a1a24a926d5b60627ce5ea48e9f639a5cd328bda0515ea0f10c"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c7079d5eb1c1315a858bbf180000757db8ad904a89476653232db835c3114001"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:343ab62e9ca78094f2306aefed67dcfad61c4683f87eee48ff2fd74902447726"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:cd9e78285da6c9ba2d5c769628f43ef66d96ac3085e59b10ad4f3707980710d3"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_28_ppc64le.whl", hash = "sha256:546cf886f6242dff9ec206331209db9c8e1643ae642dea5fdbecae2453cb50fd"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_28_s390x.whl", hash = "sha256:02f6a8eb6512fdc2fd4ca10a49c341c4e109aa6e9448cc4859af5b949622715a"}, - {file = "lxml-5.2.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:339ee4a4704bc724757cd5dd9dc8cf4d00980f5d3e6e06d5847c1b594ace68ab"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0a028b61a2e357ace98b1615fc03f76eb517cc028993964fe08ad514b1e8892d"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:f90e552ecbad426eab352e7b2933091f2be77115bb16f09f78404861c8322981"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:d83e2d94b69bf31ead2fa45f0acdef0757fa0458a129734f59f67f3d2eb7ef32"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a02d3c48f9bb1e10c7788d92c0c7db6f2002d024ab6e74d6f45ae33e3d0288a3"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:6d68ce8e7b2075390e8ac1e1d3a99e8b6372c694bbe612632606d1d546794207"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:453d037e09a5176d92ec0fd282e934ed26d806331a8b70ab431a81e2fbabf56d"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:3b019d4ee84b683342af793b56bb35034bd749e4cbdd3d33f7d1107790f8c472"}, - {file = "lxml-5.2.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:cb3942960f0beb9f46e2a71a3aca220d1ca32feb5a398656be934320804c0df9"}, - {file = "lxml-5.2.2-cp39-cp39-win32.whl", hash = "sha256:ac6540c9fff6e3813d29d0403ee7a81897f1d8ecc09a8ff84d2eea70ede1cdbf"}, - {file = "lxml-5.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:610b5c77428a50269f38a534057444c249976433f40f53e3b47e68349cca1425"}, - {file = 
"lxml-5.2.2-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b537bd04d7ccd7c6350cdaaaad911f6312cbd61e6e6045542f781c7f8b2e99d2"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4820c02195d6dfb7b8508ff276752f6b2ff8b64ae5d13ebe02e7667e035000b9"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a09f6184f17a80897172863a655467da2b11151ec98ba8d7af89f17bf63dae"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:76acba4c66c47d27c8365e7c10b3d8016a7da83d3191d053a58382311a8bf4e1"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:b128092c927eaf485928cec0c28f6b8bead277e28acf56800e972aa2c2abd7a2"}, - {file = "lxml-5.2.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ae791f6bd43305aade8c0e22f816b34f3b72b6c820477aab4d18473a37e8090b"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a2f6a1bc2460e643785a2cde17293bd7a8f990884b822f7bca47bee0a82fc66b"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e8d351ff44c1638cb6e980623d517abd9f580d2e53bfcd18d8941c052a5a009"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bec4bd9133420c5c52d562469c754f27c5c9e36ee06abc169612c959bd7dbb07"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:55ce6b6d803890bd3cc89975fca9de1dff39729b43b73cb15ddd933b8bc20484"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8ab6a358d1286498d80fe67bd3d69fcbc7d1359b45b41e74c4a26964ca99c3f8"}, - {file = "lxml-5.2.2-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:06668e39e1f3c065349c51ac27ae430719d7806c026fec462e5693b08b95696b"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:9cd5323344d8ebb9fb5e96da5de5ad4ebab993bbf51674259dbe9d7a18049525"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89feb82ca055af0fe797a2323ec9043b26bc371365847dbe83c7fd2e2f181c34"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e481bba1e11ba585fb06db666bfc23dbe181dbafc7b25776156120bf12e0d5a6"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9d6c6ea6a11ca0ff9cd0390b885984ed31157c168565702959c25e2191674a14"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:3d98de734abee23e61f6b8c2e08a88453ada7d6486dc7cdc82922a03968928db"}, - {file = "lxml-5.2.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:69ab77a1373f1e7563e0fb5a29a8440367dec051da6c7405333699d07444f511"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:34e17913c431f5ae01d8658dbf792fdc457073dcdfbb31dc0cc6ab256e664a8d"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f8757b03208c3f50097761be2dea0aba02e94f0dc7023ed73a7bb14ff11eb0"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a520b4f9974b0a0a6ed73c2154de57cdfd0c8800f4f15ab2b73238ffed0b36e"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:5e097646944b66207023bc3c634827de858aebc226d5d4d6d16f0b77566ea182"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = 
"sha256:b5e4ef22ff25bfd4ede5f8fb30f7b24446345f3e79d9b7455aef2836437bc38a"}, - {file = "lxml-5.2.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:ff69a9a0b4b17d78170c73abe2ab12084bdf1691550c5629ad1fe7849433f324"}, - {file = "lxml-5.2.2.tar.gz", hash = "sha256:bb2dc4898180bea79863d5487e5f9c7c34297414bad54bcd0f0852aee9cfdb87"}, + {file = "lxml-5.3.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd36439be765e2dde7660212b5275641edbc813e7b24668831a5c8ac91180656"}, + {file = "lxml-5.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ae5fe5c4b525aa82b8076c1a59d642c17b6e8739ecf852522c6321852178119d"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:501d0d7e26b4d261fca8132854d845e4988097611ba2531408ec91cf3fd9d20a"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66442c2546446944437df74379e9cf9e9db353e61301d1a0e26482f43f0dd8"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e41506fec7a7f9405b14aa2d5c8abbb4dbbd09d88f9496958b6d00cb4d45330"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f7d4a670107d75dfe5ad080bed6c341d18c4442f9378c9f58e5851e86eb79965"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41ce1f1e2c7755abfc7e759dc34d7d05fd221723ff822947132dc934d122fe22"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:44264ecae91b30e5633013fb66f6ddd05c006d3e0e884f75ce0b4755b3e3847b"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_ppc64le.whl", hash = "sha256:3c174dc350d3ec52deb77f2faf05c439331d6ed5e702fc247ccb4e6b62d884b7"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_s390x.whl", hash = "sha256:2dfab5fa6a28a0b60a20638dc48e6343c02ea9933e3279ccb132f555a62323d8"}, + {file = "lxml-5.3.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b1c8c20847b9f34e98080da785bb2336ea982e7f913eed5809e5a3c872900f32"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2c86bf781b12ba417f64f3422cfc302523ac9cd1d8ae8c0f92a1c66e56ef2e86"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c162b216070f280fa7da844531169be0baf9ccb17263cf5a8bf876fcd3117fa5"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:36aef61a1678cb778097b4a6eeae96a69875d51d1e8f4d4b491ab3cfb54b5a03"}, + {file = "lxml-5.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f65e5120863c2b266dbcc927b306c5b78e502c71edf3295dfcb9501ec96e5fc7"}, + {file = "lxml-5.3.0-cp310-cp310-win32.whl", hash = "sha256:ef0c1fe22171dd7c7c27147f2e9c3e86f8bdf473fed75f16b0c2e84a5030ce80"}, + {file = "lxml-5.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:052d99051e77a4f3e8482c65014cf6372e61b0a6f4fe9edb98503bb5364cfee3"}, + {file = "lxml-5.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:74bcb423462233bc5d6066e4e98b0264e7c1bed7541fff2f4e34fe6b21563c8b"}, + {file = "lxml-5.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a3d819eb6f9b8677f57f9664265d0a10dd6551d227afb4af2b9cd7bdc2ccbf18"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b8f5db71b28b8c404956ddf79575ea77aa8b1538e8b2ef9ec877945b3f46442"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:2c3406b63232fc7e9b8783ab0b765d7c59e7c59ff96759d8ef9632fca27c7ee4"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2ecdd78ab768f844c7a1d4a03595038c166b609f6395e25af9b0f3f26ae1230f"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:168f2dfcfdedf611eb285efac1516c8454c8c99caf271dccda8943576b67552e"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa617107a410245b8660028a7483b68e7914304a6d4882b5ff3d2d3eb5948d8c"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:69959bd3167b993e6e710b99051265654133a98f20cec1d9b493b931942e9c16"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_ppc64le.whl", hash = "sha256:bd96517ef76c8654446fc3db9242d019a1bb5fe8b751ba414765d59f99210b79"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_s390x.whl", hash = "sha256:ab6dd83b970dc97c2d10bc71aa925b84788c7c05de30241b9e96f9b6d9ea3080"}, + {file = "lxml-5.3.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:eec1bb8cdbba2925bedc887bc0609a80e599c75b12d87ae42ac23fd199445654"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6a7095eeec6f89111d03dabfe5883a1fd54da319c94e0fb104ee8f23616b572d"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6f651ebd0b21ec65dfca93aa629610a0dbc13dbc13554f19b0113da2e61a4763"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:f422a209d2455c56849442ae42f25dbaaba1c6c3f501d58761c619c7836642ec"}, + {file = "lxml-5.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:62f7fdb0d1ed2065451f086519865b4c90aa19aed51081979ecd05a21eb4d1be"}, + {file = "lxml-5.3.0-cp311-cp311-win32.whl", hash = "sha256:c6379f35350b655fd817cd0d6cbeef7f265f3ae5fedb1caae2eb442bbeae9ab9"}, + {file = "lxml-5.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:9c52100e2c2dbb0649b90467935c4b0de5528833c76a35ea1a2691ec9f1ee7a1"}, + {file = "lxml-5.3.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:e99f5507401436fdcc85036a2e7dc2e28d962550afe1cbfc07c40e454256a859"}, + {file = "lxml-5.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:384aacddf2e5813a36495233b64cb96b1949da72bef933918ba5c84e06af8f0e"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:874a216bf6afaf97c263b56371434e47e2c652d215788396f60477540298218f"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65ab5685d56914b9a2a34d67dd5488b83213d680b0c5d10b47f81da5a16b0b0e"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aac0bbd3e8dd2d9c45ceb82249e8bdd3ac99131a32b4d35c8af3cc9db1657179"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b369d3db3c22ed14c75ccd5af429086f166a19627e84a8fdade3f8f31426e52a"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24037349665434f375645fa9d1f5304800cec574d0310f618490c871fd902b3"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:62d172f358f33a26d6b41b28c170c63886742f5b6772a42b59b4f0fa10526cb1"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:c1f794c02903c2824fccce5b20c339a1a14b114e83b306ff11b597c5f71a1c8d"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_s390x.whl", hash = 
"sha256:5d6a6972b93c426ace71e0be9a6f4b2cfae9b1baed2eed2006076a746692288c"}, + {file = "lxml-5.3.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:3879cc6ce938ff4eb4900d901ed63555c778731a96365e53fadb36437a131a99"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:74068c601baff6ff021c70f0935b0c7bc528baa8ea210c202e03757c68c5a4ff"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ecd4ad8453ac17bc7ba3868371bffb46f628161ad0eefbd0a855d2c8c32dd81a"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7e2f58095acc211eb9d8b5771bf04df9ff37d6b87618d1cbf85f92399c98dae8"}, + {file = "lxml-5.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e63601ad5cd8f860aa99d109889b5ac34de571c7ee902d6812d5d9ddcc77fa7d"}, + {file = "lxml-5.3.0-cp312-cp312-win32.whl", hash = "sha256:17e8d968d04a37c50ad9c456a286b525d78c4a1c15dd53aa46c1d8e06bf6fa30"}, + {file = "lxml-5.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:c1a69e58a6bb2de65902051d57fde951febad631a20a64572677a1052690482f"}, + {file = "lxml-5.3.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c72e9563347c7395910de6a3100a4840a75a6f60e05af5e58566868d5eb2d6a"}, + {file = "lxml-5.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e92ce66cd919d18d14b3856906a61d3f6b6a8500e0794142338da644260595cd"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d04f064bebdfef9240478f7a779e8c5dc32b8b7b0b2fc6a62e39b928d428e51"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c2fb570d7823c2bbaf8b419ba6e5662137f8166e364a8b2b91051a1fb40ab8b"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0c120f43553ec759f8de1fee2f4794452b0946773299d44c36bfe18e83caf002"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:562e7494778a69086f0312ec9689f6b6ac1c6b65670ed7d0267e49f57ffa08c4"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:423b121f7e6fa514ba0c7918e56955a1d4470ed35faa03e3d9f0e3baa4c7e492"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:c00f323cc00576df6165cc9d21a4c21285fa6b9989c5c39830c3903dc4303ef3"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_ppc64le.whl", hash = "sha256:1fdc9fae8dd4c763e8a31e7630afef517eab9f5d5d31a278df087f307bf601f4"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_s390x.whl", hash = "sha256:658f2aa69d31e09699705949b5fc4719cbecbd4a97f9656a232e7d6c7be1a367"}, + {file = "lxml-5.3.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:1473427aff3d66a3fa2199004c3e601e6c4500ab86696edffdbc84954c72d832"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a87de7dd873bf9a792bf1e58b1c3887b9264036629a5bf2d2e6579fe8e73edff"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0d7b36afa46c97875303a94e8f3ad932bf78bace9e18e603f2085b652422edcd"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:cf120cce539453ae086eacc0130a324e7026113510efa83ab42ef3fcfccac7fb"}, + {file = "lxml-5.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:df5c7333167b9674aa8ae1d4008fa4bc17a313cc490b2cca27838bbdcc6bb15b"}, + {file = "lxml-5.3.0-cp313-cp313-win32.whl", hash = "sha256:c802e1c2ed9f0c06a65bc4ed0189d000ada8049312cfeab6ca635e39c9608957"}, + {file = "lxml-5.3.0-cp313-cp313-win_amd64.whl", 
hash = "sha256:406246b96d552e0503e17a1006fd27edac678b3fcc9f1be71a2f94b4ff61528d"}, + {file = "lxml-5.3.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:8f0de2d390af441fe8b2c12626d103540b5d850d585b18fcada58d972b74a74e"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1afe0a8c353746e610bd9031a630a95bcfb1a720684c3f2b36c4710a0a96528f"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56b9861a71575f5795bde89256e7467ece3d339c9b43141dbdd54544566b3b94"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_28_x86_64.whl", hash = "sha256:9fb81d2824dff4f2e297a276297e9031f46d2682cafc484f49de182aa5e5df99"}, + {file = "lxml-5.3.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:2c226a06ecb8cdef28845ae976da407917542c5e6e75dcac7cc33eb04aaeb237"}, + {file = "lxml-5.3.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:7d3d1ca42870cdb6d0d29939630dbe48fa511c203724820fc0fd507b2fb46577"}, + {file = "lxml-5.3.0-cp36-cp36m-win32.whl", hash = "sha256:094cb601ba9f55296774c2d57ad68730daa0b13dc260e1f941b4d13678239e70"}, + {file = "lxml-5.3.0-cp36-cp36m-win_amd64.whl", hash = "sha256:eafa2c8658f4e560b098fe9fc54539f86528651f61849b22111a9b107d18910c"}, + {file = "lxml-5.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:cb83f8a875b3d9b458cada4f880fa498646874ba4011dc974e071a0a84a1b033"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:25f1b69d41656b05885aa185f5fdf822cb01a586d1b32739633679699f220391"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23e0553b8055600b3bf4a00b255ec5c92e1e4aebf8c2c09334f8368e8bd174d6"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ada35dd21dc6c039259596b358caab6b13f4db4d4a7f8665764d616daf9cc1d"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:81b4e48da4c69313192d8c8d4311e5d818b8be1afe68ee20f6385d0e96fc9512"}, + {file = "lxml-5.3.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:2bc9fd5ca4729af796f9f59cd8ff160fe06a474da40aca03fcc79655ddee1a8b"}, + {file = "lxml-5.3.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:07da23d7ee08577760f0a71d67a861019103e4812c87e2fab26b039054594cc5"}, + {file = "lxml-5.3.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:ea2e2f6f801696ad7de8aec061044d6c8c0dd4037608c7cab38a9a4d316bfb11"}, + {file = "lxml-5.3.0-cp37-cp37m-win32.whl", hash = "sha256:5c54afdcbb0182d06836cc3d1be921e540be3ebdf8b8a51ee3ef987537455f84"}, + {file = "lxml-5.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:f2901429da1e645ce548bf9171784c0f74f0718c3f6150ce166be39e4dd66c3e"}, + {file = "lxml-5.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c56a1d43b2f9ee4786e4658c7903f05da35b923fb53c11025712562d5cc02753"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ee8c39582d2652dcd516d1b879451500f8db3fe3607ce45d7c5957ab2596040"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdf3a3059611f7585a78ee10399a15566356116a4288380921a4b598d807a22"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:146173654d79eb1fc97498b4280c1d3e1e5d58c398fa530905c9ea50ea849b22"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = 
"sha256:0a7056921edbdd7560746f4221dca89bb7a3fe457d3d74267995253f46343f15"}, + {file = "lxml-5.3.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:9e4b47ac0f5e749cfc618efdf4726269441014ae1d5583e047b452a32e221920"}, + {file = "lxml-5.3.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f914c03e6a31deb632e2daa881fe198461f4d06e57ac3d0e05bbcab8eae01945"}, + {file = "lxml-5.3.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:213261f168c5e1d9b7535a67e68b1f59f92398dd17a56d934550837143f79c42"}, + {file = "lxml-5.3.0-cp38-cp38-win32.whl", hash = "sha256:218c1b2e17a710e363855594230f44060e2025b05c80d1f0661258142b2add2e"}, + {file = "lxml-5.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:315f9542011b2c4e1d280e4a20ddcca1761993dda3afc7a73b01235f8641e903"}, + {file = "lxml-5.3.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1ffc23010330c2ab67fac02781df60998ca8fe759e8efde6f8b756a20599c5de"}, + {file = "lxml-5.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2b3778cb38212f52fac9fe913017deea2fdf4eb1a4f8e4cfc6b009a13a6d3fcc"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4b0c7a688944891086ba192e21c5229dea54382f4836a209ff8d0a660fac06be"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:747a3d3e98e24597981ca0be0fd922aebd471fa99d0043a3842d00cdcad7ad6a"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86a6b24b19eaebc448dc56b87c4865527855145d851f9fc3891673ff97950540"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b11a5d918a6216e521c715b02749240fb07ae5a1fefd4b7bf12f833bc8b4fe70"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68b87753c784d6acb8a25b05cb526c3406913c9d988d51f80adecc2b0775d6aa"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:109fa6fede314cc50eed29e6e56c540075e63d922455346f11e4d7a036d2b8cf"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_ppc64le.whl", hash = "sha256:02ced472497b8362c8e902ade23e3300479f4f43e45f4105c85ef43b8db85229"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_s390x.whl", hash = "sha256:6b038cc86b285e4f9fea2ba5ee76e89f21ed1ea898e287dc277a25884f3a7dfe"}, + {file = "lxml-5.3.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:7437237c6a66b7ca341e868cda48be24b8701862757426852c9b3186de1da8a2"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:7f41026c1d64043a36fda21d64c5026762d53a77043e73e94b71f0521939cc71"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:482c2f67761868f0108b1743098640fbb2a28a8e15bf3f47ada9fa59d9fe08c3"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:1483fd3358963cc5c1c9b122c80606a3a79ee0875bcac0204149fa09d6ff2727"}, + {file = "lxml-5.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2dec2d1130a9cda5b904696cec33b2cfb451304ba9081eeda7f90f724097300a"}, + {file = "lxml-5.3.0-cp39-cp39-win32.whl", hash = "sha256:a0eabd0a81625049c5df745209dc7fcef6e2aea7793e5f003ba363610aa0a3ff"}, + {file = "lxml-5.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:89e043f1d9d341c52bf2af6d02e6adde62e0a46e6755d5eb60dc6e4f0b8aeca2"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7b1cd427cb0d5f7393c31b7496419da594fe600e6fdc4b105a54f82405e6626c"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:51806cfe0279e06ed8500ce19479d757db42a30fd509940b1701be9c86a5ff9a"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee70d08fd60c9565ba8190f41a46a54096afa0eeb8f76bd66f2c25d3b1b83005"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:8dc2c0395bea8254d8daebc76dcf8eb3a95ec2a46fa6fae5eaccee366bfe02ce"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6ba0d3dcac281aad8a0e5b14c7ed6f9fa89c8612b47939fc94f80b16e2e9bc83"}, + {file = "lxml-5.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:6e91cf736959057f7aac7adfc83481e03615a8e8dd5758aa1d95ea69e8931dba"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:94d6c3782907b5e40e21cadf94b13b0842ac421192f26b84c45f13f3c9d5dc27"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c300306673aa0f3ed5ed9372b21867690a17dba38c68c44b287437c362ce486b"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78d9b952e07aed35fe2e1a7ad26e929595412db48535921c5013edc8aa4a35ce"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:01220dca0d066d1349bd6a1726856a78f7929f3878f7e2ee83c296c69495309e"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2d9b8d9177afaef80c53c0a9e30fa252ff3036fb1c6494d427c066a4ce6a282f"}, + {file = "lxml-5.3.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:20094fc3f21ea0a8669dc4c61ed7fa8263bd37d97d93b90f28fc613371e7a875"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ace2c2326a319a0bb8a8b0e5b570c764962e95818de9f259ce814ee666603f19"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92e67a0be1639c251d21e35fe74df6bcc40cba445c2cda7c4a967656733249e2"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd5350b55f9fecddc51385463a4f67a5da829bc741e38cf689f38ec9023f54ab"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:4c1fefd7e3d00921c44dc9ca80a775af49698bbfd92ea84498e56acffd4c5469"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:71a8dd38fbd2f2319136d4ae855a7078c69c9a38ae06e0c17c73fd70fc6caad8"}, + {file = "lxml-5.3.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:97acf1e1fd66ab53dacd2c35b319d7e548380c2e9e8c54525c6e76d21b1ae3b1"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:68934b242c51eb02907c5b81d138cb977b2129a0a75a8f8b60b01cb8586c7b21"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b710bc2b8292966b23a6a0121f7a6c51d45d2347edcc75f016ac123b8054d3f2"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18feb4b93302091b1541221196a2155aa296c363fd233814fa11e181adebc52f"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:3eb44520c4724c2e1a57c0af33a379eee41792595023f367ba3952a2d96c2aab"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:609251a0ca4770e5a8768ff902aa02bf636339c5a93f9349b48eb1f606f7f3e9"}, + {file = "lxml-5.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:516f491c834eb320d6c843156440fe7fc0d50b33e44387fcec5b02f0bc118a4c"}, + {file = "lxml-5.3.0.tar.gz", hash = 
"sha256:4e109ca30d1edec1ac60cdbe341905dc3b8f55b16855e03a54aaf59e51ec8c6f"}, ] [package.extras] @@ -4363,7 +4864,7 @@ cssselect = ["cssselect (>=0.7)"] html-clean = ["lxml-html-clean"] html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] -source = ["Cython (>=3.0.10)"] +source = ["Cython (>=3.0.11)"] [[package]] name = "lz4" @@ -4434,13 +4935,13 @@ urllib3 = ">=1.23" [[package]] name = "mako" -version = "1.3.5" +version = "1.3.6" description = "A super-fast templating language that borrows the best ideas from the existing templating languages." optional = false python-versions = ">=3.8" files = [ - {file = "Mako-1.3.5-py3-none-any.whl", hash = "sha256:260f1dbc3a519453a9c856dedfe4beb4e50bd5a26d96386cb6c80856556bb91a"}, - {file = "Mako-1.3.5.tar.gz", hash = "sha256:48dbc20568c1d276a2698b36d968fa76161bf127194907ea6fc594fa81f943bc"}, + {file = "Mako-1.3.6-py3-none-any.whl", hash = "sha256:a91198468092a2f1a0de86ca92690fb0cfc43ca90ee17e15d93662b4c04b241a"}, + {file = "mako-1.3.6.tar.gz", hash = "sha256:9ec3a1583713479fae654f83ed9fa8c9a4c16b7bb0daba0e6bbebff50c0d983d"}, ] [package.dependencies] @@ -4492,91 +4993,92 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] [[package]] name = "markupsafe" -version = "2.1.5" +version = "3.0.2" description = "Safely add untrusted strings to HTML/XML markup." optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" files = [ - {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"}, - {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"}, - {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"}, - {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = 
"sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-win32.whl", hash = "sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371"}, - {file = "MarkupSafe-2.1.5-cp37-cp37m-win_amd64.whl", hash = "sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-win32.whl", hash = "sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff"}, - {file = "MarkupSafe-2.1.5-cp38-cp38-win_amd64.whl", hash = "sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"}, - {file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = 
"sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"}, - {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50"}, + {file = "MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d"}, + {file = "MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-macosx_11_0_arm64.whl", 
hash = "sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30"}, + {file = "MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1"}, + {file = "MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6"}, + {file = "MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-win32.whl", hash = "sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f"}, + {file = "MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a"}, + {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] [[package]] name = "marshmallow" -version = "3.21.3" +version = "3.23.0" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." 
optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "marshmallow-3.21.3-py3-none-any.whl", hash = "sha256:86ce7fb914aa865001a4b2092c4c2872d13bc347f3d42673272cabfdbad386f1"}, - {file = "marshmallow-3.21.3.tar.gz", hash = "sha256:4f57c5e050a54d66361e826f94fba213eb10b67b2fdb02c3e0343ce207ba1662"}, + {file = "marshmallow-3.23.0-py3-none-any.whl", hash = "sha256:82f20a2397834fe6d9611b241f2f7e7b680ed89c49f84728a1ad937be6b4bdf4"}, + {file = "marshmallow-3.23.0.tar.gz", hash = "sha256:98d8827a9f10c03d44ead298d2e99c6aea8197df18ccfad360dae7f89a50da2e"}, ] [package.dependencies] packaging = ">=17.0" [package.extras] -dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"] -docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"] -tests = ["pytest", "pytz", "simplejson"] +dev = ["marshmallow[tests]", "pre-commit (>=3.5,<5.0)", "tox"] +docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.13)", "sphinx (==8.1.3)", "sphinx-issues (==5.0.0)", "sphinx-version-warning (==1.1.2)"] +tests = ["pytest", "simplejson"] [[package]] name = "matplotlib" @@ -4639,111 +5141,159 @@ files = [ [[package]] name = "milvus-lite" -version = "2.4.8" +version = "2.4.10" description = "A lightweight version of Milvus wrapped with Python." optional = false python-versions = ">=3.7" files = [ - {file = "milvus_lite-2.4.8-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:b7e90b34b214884cd44cdc112ab243d4cb197b775498355e2437b6cafea025fe"}, - {file = "milvus_lite-2.4.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:519dfc62709d8f642d98a1c5b1dcde7080d107e6e312d677fef5a3412a40ac08"}, - {file = "milvus_lite-2.4.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b21f36d24cbb0e920b4faad607019bb28c1b2c88b4d04680ac8c7697a4ae8a4d"}, - {file = "milvus_lite-2.4.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:08332a2b9abfe7c4e1d7926068937e46f8fb81f2707928b7bc02c9dc99cebe41"}, + {file = "milvus_lite-2.4.10-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:fc4246d3ed7d1910847afce0c9ba18212e93a6e9b8406048436940578dfad5cb"}, + {file = "milvus_lite-2.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:74a8e07c5e3b057df17fbb46913388e84df1dc403a200f4e423799a58184c800"}, + {file = "milvus_lite-2.4.10-py3-none-manylinux2014_aarch64.whl", hash = "sha256:240c7386b747bad696ecb5bd1f58d491e86b9d4b92dccee3315ed7256256eddc"}, + {file = "milvus_lite-2.4.10-py3-none-manylinux2014_x86_64.whl", hash = "sha256:211d2e334a043f9282bdd9755f76b9b2d93b23bffa7af240919ffce6a8dfe325"}, ] [package.dependencies] tqdm = "*" +[[package]] +name = "mistune" +version = "3.0.2" +description = "A sane and fast Markdown parser with useful plugins and renderers" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mistune-3.0.2-py3-none-any.whl", hash = "sha256:71481854c30fdbc938963d3605b72501f5c10a9320ecd412c121c163a1c7d205"}, + {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, +] + [[package]] name = "mmh3" -version = "4.1.0" +version = "5.0.1" description = "Python extension for MurmurHash (MurmurHash3), a set of fast and robust hash functions." 
optional = false -python-versions = "*" +python-versions = ">=3.8" +files = [ + {file = "mmh3-5.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f0a4b4bf05778ed77d820d6e7d0e9bd6beb0c01af10e1ce9233f5d2f814fcafa"}, + {file = "mmh3-5.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac7a391039aeab95810c2d020b69a94eb6b4b37d4e2374831e92db3a0cdf71c6"}, + {file = "mmh3-5.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3a2583b5521ca49756d8d8bceba80627a9cc295f255dcab4e3df7ccc2f09679a"}, + {file = "mmh3-5.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:081a8423fe53c1ac94f87165f3e4c500125d343410c1a0c5f1703e898a3ef038"}, + {file = "mmh3-5.0.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8b4d72713799755dc8954a7d36d5c20a6c8de7b233c82404d122c7c7c1707cc"}, + {file = "mmh3-5.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:389a6fd51efc76d3182d36ec306448559c1244f11227d2bb771bdd0e6cc91321"}, + {file = "mmh3-5.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39f4128edaa074bff721b1d31a72508cba4d2887ee7867f22082e1fe9d4edea0"}, + {file = "mmh3-5.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d5d23a94d91aabba3386b3769048d5f4210fdfef80393fece2f34ba5a7b466c"}, + {file = "mmh3-5.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:16347d038361f8b8f24fd2b7ef378c9b68ddee9f7706e46269b6e0d322814713"}, + {file = "mmh3-5.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:6e299408565af7d61f2d20a5ffdd77cf2ed902460fe4e6726839d59ba4b72316"}, + {file = "mmh3-5.0.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:42050af21ddfc5445ee5a66e73a8fc758c71790305e3ee9e4a85a8e69e810f94"}, + {file = "mmh3-5.0.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2ae9b1f5ef27ec54659920f0404b7ceb39966e28867c461bfe83a05e8d18ddb0"}, + {file = "mmh3-5.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:50c2495a02045f3047d71d4ae9cdd7a15efc0bcbb7ff17a18346834a8e2d1d19"}, + {file = "mmh3-5.0.1-cp310-cp310-win32.whl", hash = "sha256:c028fa77cddf351ca13b4a56d43c1775652cde0764cadb39120b68f02a23ecf6"}, + {file = "mmh3-5.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:c5e741e421ec14400c4aae30890515c201f518403bdef29ae1e00d375bb4bbb5"}, + {file = "mmh3-5.0.1-cp310-cp310-win_arm64.whl", hash = "sha256:b17156d56fabc73dbf41bca677ceb6faed435cc8544f6566d72ea77d8a17e9d0"}, + {file = "mmh3-5.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9a6d5a9b1b923f1643559ba1fc0bf7a5076c90cbb558878d3bf3641ce458f25d"}, + {file = "mmh3-5.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3349b968be555f7334bbcce839da98f50e1e80b1c615d8e2aa847ea4a964a012"}, + {file = "mmh3-5.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1bd3c94b110e55db02ab9b605029f48a2f7f677c6e58c09d44e42402d438b7e1"}, + {file = "mmh3-5.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d47ba84d48608f79adbb10bb09986b6dc33eeda5c2d1bd75d00820081b73bde9"}, + {file = "mmh3-5.0.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c0217987a8b8525c8d9170f66d036dec4ab45cfbd53d47e8d76125791ceb155e"}, + {file = "mmh3-5.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2797063a34e78d1b61639a98b0edec1c856fa86ab80c7ec859f1796d10ba429"}, + {file = 
"mmh3-5.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8bba16340adcbd47853a2fbe5afdb397549e8f2e79324ff1dced69a3f8afe7c3"}, + {file = "mmh3-5.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:282797957c9f60b51b9d768a602c25f579420cc9af46feb77d457a27823d270a"}, + {file = "mmh3-5.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e4fb670c29e63f954f9e7a2cdcd57b36a854c2538f579ef62681ccbaa1de2b69"}, + {file = "mmh3-5.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8ee7d85438dc6aff328e19ab052086a3c29e8a9b632998a49e5c4b0034e9e8d6"}, + {file = "mmh3-5.0.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b7fb5db231f3092444bc13901e6a8d299667126b00636ffbad4a7b45e1051e2f"}, + {file = "mmh3-5.0.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:c100dd441703da5ec136b1d9003ed4a041d8a1136234c9acd887499796df6ad8"}, + {file = "mmh3-5.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:71f3b765138260fd7a7a2dba0ea5727dabcd18c1f80323c9cfef97a7e86e01d0"}, + {file = "mmh3-5.0.1-cp311-cp311-win32.whl", hash = "sha256:9a76518336247fd17689ce3ae5b16883fd86a490947d46a0193d47fb913e26e3"}, + {file = "mmh3-5.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:336bc4df2e44271f1c302d289cc3d78bd52d3eed8d306c7e4bff8361a12bf148"}, + {file = "mmh3-5.0.1-cp311-cp311-win_arm64.whl", hash = "sha256:af6522722fbbc5999aa66f7244d0986767a46f1fb05accc5200f75b72428a508"}, + {file = "mmh3-5.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:f2730bb263ed9c388e8860438b057a53e3cc701134a6ea140f90443c4c11aa40"}, + {file = "mmh3-5.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6246927bc293f6d56724536400b85fb85f5be26101fa77d5f97dd5e2a4c69bf2"}, + {file = "mmh3-5.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fbca322519a6e6e25b6abf43e940e1667cf8ea12510e07fb4919b48a0cd1c411"}, + {file = "mmh3-5.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eae8c19903ed8a1724ad9e67e86f15d198a7a1271a4f9be83d47e38f312ed672"}, + {file = "mmh3-5.0.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a09fd6cc72c07c0c07c3357714234b646d78052487c4a3bd5f7f6e08408cff60"}, + {file = "mmh3-5.0.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2ff8551fee7ae3b11c5d986b6347ade0dccaadd4670ffdb2b944dee120ffcc84"}, + {file = "mmh3-5.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e39694c73a5a20c8bf36dfd8676ed351e5234d55751ba4f7562d85449b21ef3f"}, + {file = "mmh3-5.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eba6001989a92f72a89c7cf382fda831678bd780707a66b4f8ca90239fdf2123"}, + {file = "mmh3-5.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0771f90c9911811cc606a5c7b7b58f33501c9ee896ed68a6ac22c7d55878ecc0"}, + {file = "mmh3-5.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:09b31ed0c0c0920363e96641fac4efde65b1ab62b8df86293142f35a254e72b4"}, + {file = "mmh3-5.0.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5cf4a8deda0235312db12075331cb417c4ba163770edfe789bde71d08a24b692"}, + {file = "mmh3-5.0.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:41f7090a95185ef20ac018581a99337f0cbc84a2135171ee3290a9c0d9519585"}, + {file = "mmh3-5.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:b97b5b368fb7ff22194ec5854f5b12d8de9ab67a0f304728c7f16e5d12135b76"}, + {file = "mmh3-5.0.1-cp312-cp312-win32.whl", hash = "sha256:842516acf04da546f94fad52db125ee619ccbdcada179da51c326a22c4578cb9"}, + {file = "mmh3-5.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:d963be0dbfd9fca209c17172f6110787ebf78934af25e3694fe2ba40e55c1e2b"}, + {file = "mmh3-5.0.1-cp312-cp312-win_arm64.whl", hash = "sha256:a5da292ceeed8ce8e32b68847261a462d30fd7b478c3f55daae841404f433c15"}, + {file = "mmh3-5.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:673e3f1c8d4231d6fb0271484ee34cb7146a6499fc0df80788adb56fd76842da"}, + {file = "mmh3-5.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f795a306bd16a52ad578b663462cc8e95500b3925d64118ae63453485d67282b"}, + {file = "mmh3-5.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5ed57a5e28e502a1d60436cc25c76c3a5ba57545f250f2969af231dc1221e0a5"}, + {file = "mmh3-5.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:632c28e7612e909dbb6cbe2fe496201ada4695b7715584005689c5dc038e59ad"}, + {file = "mmh3-5.0.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53fd6bd525a5985e391c43384672d9d6b317fcb36726447347c7fc75bfed34ec"}, + {file = "mmh3-5.0.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dceacf6b0b961a0e499836af3aa62d60633265607aef551b2a3e3c48cdaa5edd"}, + {file = "mmh3-5.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f0738d478fdfb5d920f6aff5452c78f2c35b0eff72caa2a97dfe38e82f93da2"}, + {file = "mmh3-5.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e70285e7391ab88b872e5bef632bad16b9d99a6d3ca0590656a4753d55988af"}, + {file = "mmh3-5.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:27e5fc6360aa6b828546a4318da1a7da6bf6e5474ccb053c3a6aa8ef19ff97bd"}, + {file = "mmh3-5.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7989530c3c1e2c17bf5a0ec2bba09fd19819078ba90beedabb1c3885f5040b0d"}, + {file = "mmh3-5.0.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:cdad7bee649950da7ecd3cbbbd12fb81f1161072ecbdb5acfa0018338c5cb9cf"}, + {file = "mmh3-5.0.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:e143b8f184c1bb58cecd85ab4a4fd6dc65a2d71aee74157392c3fddac2a4a331"}, + {file = "mmh3-5.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e5eb12e886f3646dd636f16b76eb23fc0c27e8ff3c1ae73d4391e50ef60b40f6"}, + {file = "mmh3-5.0.1-cp313-cp313-win32.whl", hash = "sha256:16e6dddfa98e1c2d021268e72c78951234186deb4df6630e984ac82df63d0a5d"}, + {file = "mmh3-5.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:d3ffb792d70b8c4a2382af3598dad6ae0c5bd9cee5b7ffcc99aa2f5fd2c1bf70"}, + {file = "mmh3-5.0.1-cp313-cp313-win_arm64.whl", hash = "sha256:122fa9ec148383f9124292962bda745f192b47bfd470b2af5fe7bb3982b17896"}, + {file = "mmh3-5.0.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b12bad8c75e6ff5d67319794fb6a5e8c713826c818d47f850ad08b4aa06960c6"}, + {file = "mmh3-5.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e5bbb066538c1048d542246fc347bb7994bdda29a3aea61c22f9f8b57111ce69"}, + {file = "mmh3-5.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:eee6134273f64e2a106827cc8fd77e70cc7239a285006fc6ab4977d59b015af2"}, + {file = "mmh3-5.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d04d9aa19d48e4c7bbec9cabc2c4dccc6ff3b2402f856d5bf0de03e10f167b5b"}, + {file = 
"mmh3-5.0.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79f37da1eed034d06567a69a7988456345c7f29e49192831c3975b464493b16e"}, + {file = "mmh3-5.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:242f77666743337aa828a2bf2da71b6ba79623ee7f93edb11e009f69237c8561"}, + {file = "mmh3-5.0.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffd943fff690463945f6441a2465555b3146deaadf6a5e88f2590d14c655d71b"}, + {file = "mmh3-5.0.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:565b15f8d7df43acb791ff5a360795c20bfa68bca8b352509e0fbabd06cc48cd"}, + {file = "mmh3-5.0.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:fc6aafb867c2030df98ac7760ff76b500359252867985f357bd387739f3d5287"}, + {file = "mmh3-5.0.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:32898170644d45aa27c974ab0d067809c066205110f5c6d09f47d9ece6978bfe"}, + {file = "mmh3-5.0.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:42865567838d2193eb64e0ef571f678bf361a254fcdef0c5c8e73243217829bd"}, + {file = "mmh3-5.0.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:5ff5c1f301c4a8b6916498969c0fcc7e3dbc56b4bfce5cfe3fe31f3f4609e5ae"}, + {file = "mmh3-5.0.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:be74c2dda8a6f44a504450aa2c3507f8067a159201586fc01dd41ab80efc350f"}, + {file = "mmh3-5.0.1-cp38-cp38-win32.whl", hash = "sha256:5610a842621ff76c04b20b29cf5f809b131f241a19d4937971ba77dc99a7f330"}, + {file = "mmh3-5.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:de15739ac50776fe8aa1ef13f1be46a6ee1fbd45f6d0651084097eb2be0a5aa4"}, + {file = "mmh3-5.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:48e84cf3cc7e8c41bc07de72299a73b92d9e3cde51d97851420055b1484995f7"}, + {file = "mmh3-5.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6dd9dc28c2d168c49928195c2e29b96f9582a5d07bd690a28aede4cc07b0e696"}, + {file = "mmh3-5.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2771a1c56a3d4bdad990309cff5d0a8051f29c8ec752d001f97d6392194ae880"}, + {file = "mmh3-5.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5ff2a8322ba40951a84411550352fba1073ce1c1d1213bb7530f09aed7f8caf"}, + {file = "mmh3-5.0.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a16bd3ec90682c9e0a343e6bd4c778c09947c8c5395cdb9e5d9b82b2559efbca"}, + {file = "mmh3-5.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d45733a78d68b5b05ff4a823aea51fa664df1d3bf4929b152ff4fd6dea2dd69b"}, + {file = "mmh3-5.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:904285e83cedebc8873b0838ed54c20f7344120be26e2ca5a907ab007a18a7a0"}, + {file = "mmh3-5.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac4aeb1784e43df728034d0ed72e4b2648db1a69fef48fa58e810e13230ae5ff"}, + {file = "mmh3-5.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:cb3d4f751a0b8b4c8d06ef1c085216c8fddcc8b8c8d72445976b5167a40c6d1e"}, + {file = "mmh3-5.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:8021851935600e60c42122ed1176399d7692df338d606195cd599d228a04c1c6"}, + {file = "mmh3-5.0.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6182d5924a5efc451900f864cbb021d7e8ad5d524816ca17304a0f663bc09bb5"}, + {file = "mmh3-5.0.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = 
"sha256:5f30b834552a4f79c92e3d266336fb87fd92ce1d36dc6813d3e151035890abbd"}, + {file = "mmh3-5.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:cd4383f35e915e06d077df27e04ffd3be7513ec6a9de2d31f430393f67e192a7"}, + {file = "mmh3-5.0.1-cp39-cp39-win32.whl", hash = "sha256:1455fb6b42665a97db8fc66e89a861e52b567bce27ed054c47877183f86ea6e3"}, + {file = "mmh3-5.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:9e26a0f4eb9855a143f5938a53592fa14c2d3b25801c2106886ab6c173982780"}, + {file = "mmh3-5.0.1-cp39-cp39-win_arm64.whl", hash = "sha256:0d0a35a69abdad7549c4030a714bb4ad07902edb3bbe61e1bbc403ded5d678be"}, + {file = "mmh3-5.0.1.tar.gz", hash = "sha256:7dab080061aeb31a6069a181f27c473a1f67933854e36a3464931f2716508896"}, +] + +[package.extras] +benchmark = ["pymmh3 (==0.0.5)", "pyperf (==2.7.0)", "xxhash (==3.5.0)"] +docs = ["myst-parser (==4.0.0)", "shibuya (==2024.8.30)", "sphinx (==8.0.2)", "sphinx-copybutton (==0.5.2)"] +lint = ["black (==24.8.0)", "clang-format (==18.1.8)", "isort (==5.13.2)", "pylint (==3.2.7)"] +plot = ["matplotlib (==3.9.2)", "pandas (==2.2.2)"] +test = ["pytest (==8.3.3)", "pytest-sugar (==1.0.0)"] +type = ["mypy (==1.11.2)"] + +[[package]] +name = "mock" +version = "4.0.3" +description = "Rolling backport of unittest.mock for all Pythons" +optional = false +python-versions = ">=3.6" files = [ - {file = "mmh3-4.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:be5ac76a8b0cd8095784e51e4c1c9c318c19edcd1709a06eb14979c8d850c31a"}, - {file = "mmh3-4.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:98a49121afdfab67cd80e912b36404139d7deceb6773a83620137aaa0da5714c"}, - {file = "mmh3-4.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5259ac0535874366e7d1a5423ef746e0d36a9e3c14509ce6511614bdc5a7ef5b"}, - {file = "mmh3-4.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5950827ca0453a2be357696da509ab39646044e3fa15cad364eb65d78797437"}, - {file = "mmh3-4.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1dd0f652ae99585b9dd26de458e5f08571522f0402155809fd1dc8852a613a39"}, - {file = "mmh3-4.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99d25548070942fab1e4a6f04d1626d67e66d0b81ed6571ecfca511f3edf07e6"}, - {file = "mmh3-4.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:53db8d9bad3cb66c8f35cbc894f336273f63489ce4ac416634932e3cbe79eb5b"}, - {file = "mmh3-4.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75da0f615eb55295a437264cc0b736753f830b09d102aa4c2a7d719bc445ec05"}, - {file = "mmh3-4.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b926b07fd678ea84b3a2afc1fa22ce50aeb627839c44382f3d0291e945621e1a"}, - {file = "mmh3-4.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c5b053334f9b0af8559d6da9dc72cef0a65b325ebb3e630c680012323c950bb6"}, - {file = "mmh3-4.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:5bf33dc43cd6de2cb86e0aa73a1cc6530f557854bbbe5d59f41ef6de2e353d7b"}, - {file = "mmh3-4.1.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fa7eacd2b830727ba3dd65a365bed8a5c992ecd0c8348cf39a05cc77d22f4970"}, - {file = "mmh3-4.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:42dfd6742b9e3eec599f85270617debfa0bbb913c545bb980c8a4fa7b2d047da"}, - {file = "mmh3-4.1.0-cp310-cp310-win32.whl", hash = "sha256:2974ad343f0d39dcc88e93ee6afa96cedc35a9883bc067febd7ff736e207fa47"}, - {file = 
"mmh3-4.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:74699a8984ded645c1a24d6078351a056f5a5f1fe5838870412a68ac5e28d865"}, - {file = "mmh3-4.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:f0dc874cedc23d46fc488a987faa6ad08ffa79e44fb08e3cd4d4cf2877c00a00"}, - {file = "mmh3-4.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3280a463855b0eae64b681cd5b9ddd9464b73f81151e87bb7c91a811d25619e6"}, - {file = "mmh3-4.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:97ac57c6c3301769e757d444fa7c973ceb002cb66534b39cbab5e38de61cd896"}, - {file = "mmh3-4.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a7b6502cdb4dbd880244818ab363c8770a48cdccecf6d729ade0241b736b5ec0"}, - {file = "mmh3-4.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52ba2da04671a9621580ddabf72f06f0e72c1c9c3b7b608849b58b11080d8f14"}, - {file = "mmh3-4.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a5fef4c4ecc782e6e43fbeab09cff1bac82c998a1773d3a5ee6a3605cde343e"}, - {file = "mmh3-4.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5135358a7e00991f73b88cdc8eda5203bf9de22120d10a834c5761dbeb07dd13"}, - {file = "mmh3-4.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cff9ae76a54f7c6fe0167c9c4028c12c1f6de52d68a31d11b6790bb2ae685560"}, - {file = "mmh3-4.1.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6f02576a4d106d7830ca90278868bf0983554dd69183b7bbe09f2fcd51cf54f"}, - {file = "mmh3-4.1.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:073d57425a23721730d3ff5485e2da489dd3c90b04e86243dd7211f889898106"}, - {file = "mmh3-4.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:71e32ddec7f573a1a0feb8d2cf2af474c50ec21e7a8263026e8d3b4b629805db"}, - {file = "mmh3-4.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7cbb20b29d57e76a58b40fd8b13a9130db495a12d678d651b459bf61c0714cea"}, - {file = "mmh3-4.1.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:a42ad267e131d7847076bb7e31050f6c4378cd38e8f1bf7a0edd32f30224d5c9"}, - {file = "mmh3-4.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4a013979fc9390abadc445ea2527426a0e7a4495c19b74589204f9b71bcaafeb"}, - {file = "mmh3-4.1.0-cp311-cp311-win32.whl", hash = "sha256:1d3b1cdad7c71b7b88966301789a478af142bddcb3a2bee563f7a7d40519a00f"}, - {file = "mmh3-4.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:0dc6dc32eb03727467da8e17deffe004fbb65e8b5ee2b502d36250d7a3f4e2ec"}, - {file = "mmh3-4.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:9ae3a5c1b32dda121c7dc26f9597ef7b01b4c56a98319a7fe86c35b8bc459ae6"}, - {file = "mmh3-4.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0033d60c7939168ef65ddc396611077a7268bde024f2c23bdc283a19123f9e9c"}, - {file = "mmh3-4.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d6af3e2287644b2b08b5924ed3a88c97b87b44ad08e79ca9f93d3470a54a41c5"}, - {file = "mmh3-4.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d82eb4defa245e02bb0b0dc4f1e7ee284f8d212633389c91f7fba99ba993f0a2"}, - {file = "mmh3-4.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba245e94b8d54765e14c2d7b6214e832557e7856d5183bc522e17884cab2f45d"}, - {file = "mmh3-4.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb04e2feeabaad6231e89cd43b3d01a4403579aa792c9ab6fdeef45cc58d4ec0"}, - {file = "mmh3-4.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:1e3b1a27def545ce11e36158ba5d5390cdbc300cfe456a942cc89d649cf7e3b2"}, - {file = "mmh3-4.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce0ab79ff736d7044e5e9b3bfe73958a55f79a4ae672e6213e92492ad5e734d5"}, - {file = "mmh3-4.1.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b02268be6e0a8eeb8a924d7db85f28e47344f35c438c1e149878bb1c47b1cd3"}, - {file = "mmh3-4.1.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:deb887f5fcdaf57cf646b1e062d56b06ef2f23421c80885fce18b37143cba828"}, - {file = "mmh3-4.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:99dd564e9e2b512eb117bd0cbf0f79a50c45d961c2a02402787d581cec5448d5"}, - {file = "mmh3-4.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:08373082dfaa38fe97aa78753d1efd21a1969e51079056ff552e687764eafdfe"}, - {file = "mmh3-4.1.0-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:54b9c6a2ea571b714e4fe28d3e4e2db37abfd03c787a58074ea21ee9a8fd1740"}, - {file = "mmh3-4.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a7b1edf24c69e3513f879722b97ca85e52f9032f24a52284746877f6a7304086"}, - {file = "mmh3-4.1.0-cp312-cp312-win32.whl", hash = "sha256:411da64b951f635e1e2284b71d81a5a83580cea24994b328f8910d40bed67276"}, - {file = "mmh3-4.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:bebc3ecb6ba18292e3d40c8712482b4477abd6981c2ebf0e60869bd90f8ac3a9"}, - {file = "mmh3-4.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:168473dd608ade6a8d2ba069600b35199a9af837d96177d3088ca91f2b3798e3"}, - {file = "mmh3-4.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:372f4b7e1dcde175507640679a2a8790185bb71f3640fc28a4690f73da986a3b"}, - {file = "mmh3-4.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:438584b97f6fe13e944faf590c90fc127682b57ae969f73334040d9fa1c7ffa5"}, - {file = "mmh3-4.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6e27931b232fc676675fac8641c6ec6b596daa64d82170e8597f5a5b8bdcd3b6"}, - {file = "mmh3-4.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:571a92bad859d7b0330e47cfd1850b76c39b615a8d8e7aa5853c1f971fd0c4b1"}, - {file = "mmh3-4.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a69d6afe3190fa08f9e3a58e5145549f71f1f3fff27bd0800313426929c7068"}, - {file = "mmh3-4.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afb127be0be946b7630220908dbea0cee0d9d3c583fa9114a07156f98566dc28"}, - {file = "mmh3-4.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:940d86522f36348ef1a494cbf7248ab3f4a1638b84b59e6c9e90408bd11ad729"}, - {file = "mmh3-4.1.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3dcccc4935686619a8e3d1f7b6e97e3bd89a4a796247930ee97d35ea1a39341"}, - {file = "mmh3-4.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:01bb9b90d61854dfc2407c5e5192bfb47222d74f29d140cb2dd2a69f2353f7cc"}, - {file = "mmh3-4.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:bcb1b8b951a2c0b0fb8a5426c62a22557e2ffc52539e0a7cc46eb667b5d606a9"}, - {file = "mmh3-4.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6477a05d5e5ab3168e82e8b106e316210ac954134f46ec529356607900aea82a"}, - {file = "mmh3-4.1.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:da5892287e5bea6977364b15712a2573c16d134bc5fdcdd4cf460006cf849278"}, - {file = "mmh3-4.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = 
"sha256:99180d7fd2327a6fffbaff270f760576839dc6ee66d045fa3a450f3490fda7f5"}, - {file = "mmh3-4.1.0-cp38-cp38-win32.whl", hash = "sha256:9b0d4f3949913a9f9a8fb1bb4cc6ecd52879730aab5ff8c5a3d8f5b593594b73"}, - {file = "mmh3-4.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:598c352da1d945108aee0c3c3cfdd0e9b3edef74108f53b49d481d3990402169"}, - {file = "mmh3-4.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:475d6d1445dd080f18f0f766277e1237fa2914e5fe3307a3b2a3044f30892103"}, - {file = "mmh3-4.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5ca07c41e6a2880991431ac717c2a049056fff497651a76e26fc22224e8b5732"}, - {file = "mmh3-4.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0ebe052fef4bbe30c0548d12ee46d09f1b69035ca5208a7075e55adfe091be44"}, - {file = "mmh3-4.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eaefd42e85afb70f2b855a011f7b4d8a3c7e19c3f2681fa13118e4d8627378c5"}, - {file = "mmh3-4.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac0ae43caae5a47afe1b63a1ae3f0986dde54b5fb2d6c29786adbfb8edc9edfb"}, - {file = "mmh3-4.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6218666f74c8c013c221e7f5f8a693ac9cf68e5ac9a03f2373b32d77c48904de"}, - {file = "mmh3-4.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ac59294a536ba447b5037f62d8367d7d93b696f80671c2c45645fa9f1109413c"}, - {file = "mmh3-4.1.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:086844830fcd1e5c84fec7017ea1ee8491487cfc877847d96f86f68881569d2e"}, - {file = "mmh3-4.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e42b38fad664f56f77f6fbca22d08450f2464baa68acdbf24841bf900eb98e87"}, - {file = "mmh3-4.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d08b790a63a9a1cde3b5d7d733ed97d4eb884bfbc92f075a091652d6bfd7709a"}, - {file = "mmh3-4.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:73ea4cc55e8aea28c86799ecacebca09e5f86500414870a8abaedfcbaf74d288"}, - {file = "mmh3-4.1.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:f90938ff137130e47bcec8dc1f4ceb02f10178c766e2ef58a9f657ff1f62d124"}, - {file = "mmh3-4.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:aa1f13e94b8631c8cd53259250556edcf1de71738936b60febba95750d9632bd"}, - {file = "mmh3-4.1.0-cp39-cp39-win32.whl", hash = "sha256:a3b680b471c181490cf82da2142029edb4298e1bdfcb67c76922dedef789868d"}, - {file = "mmh3-4.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:fefef92e9c544a8dbc08f77a8d1b6d48006a750c4375bbcd5ff8199d761e263b"}, - {file = "mmh3-4.1.0-cp39-cp39-win_arm64.whl", hash = "sha256:8e2c1f6a2b41723a4f82bd5a762a777836d29d664fc0095f17910bea0adfd4a6"}, - {file = "mmh3-4.1.0.tar.gz", hash = "sha256:a1cf25348b9acd229dda464a094d6170f47d2850a1fcb762a3b6172d2ce6ca4a"}, -] - -[package.extras] -test = ["mypy (>=1.0)", "pytest (>=7.0.0)"] + {file = "mock-4.0.3-py3-none-any.whl", hash = "sha256:122fcb64ee37cfad5b3f48d7a7d51875d7031aaf3d8be7c42e2bee25044eee62"}, + {file = "mock-4.0.3.tar.gz", hash = "sha256:7d3fbbde18228f4ff2f1f119a45cdffa458b4c0dee32eb4d2bb2f82554bac7bc"}, +] + +[package.extras] +build = ["blurb", "twine", "wheel"] +docs = ["sphinx"] +test = ["pytest (<5.4)", "pytest-cov"] [[package]] name = "monotonic" @@ -4756,6 +5306,22 @@ files = [ {file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"}, ] +[[package]] +name = "mplfonts" +version = "0.0.8" +description = "Fonts manager for 
matplotlib" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mplfonts-0.0.8-py3-none-any.whl", hash = "sha256:b2182e5b0baa216cf016dec19942740e5b48956415708ad2d465e03952112ec1"}, + {file = "mplfonts-0.0.8.tar.gz", hash = "sha256:0abcb2fc0605645e1e7561c6923014d856f11676899b33b4d89757843f5e7c22"}, +] + +[package.dependencies] +fire = ">=0.4.0" +fontmeta = ">=1.6.1" +matplotlib = ">=3.4" + [[package]] name = "mpmath" version = "1.3.0" @@ -4775,22 +5341,22 @@ tests = ["pytest (>=4.6)"] [[package]] name = "msal" -version = "1.30.0" +version = "1.31.0" description = "The Microsoft Authentication Library (MSAL) for Python library enables your app to access the Microsoft Cloud by supporting authentication of users with Microsoft Azure Active Directory accounts (AAD) and Microsoft Accounts (MSA) using industry standard OAuth2 and OpenID Connect." optional = false python-versions = ">=3.7" files = [ - {file = "msal-1.30.0-py3-none-any.whl", hash = "sha256:423872177410cb61683566dc3932db7a76f661a5d2f6f52f02a047f101e1c1de"}, - {file = "msal-1.30.0.tar.gz", hash = "sha256:b4bf00850092e465157d814efa24a18f788284c9a479491024d62903085ea2fb"}, + {file = "msal-1.31.0-py3-none-any.whl", hash = "sha256:96bc37cff82ebe4b160d5fc0f1196f6ca8b50e274ecd0ec5bf69c438514086e7"}, + {file = "msal-1.31.0.tar.gz", hash = "sha256:2c4f189cf9cc8f00c80045f66d39b7c0f3ed45873fd3d1f2af9f22db2e12ff4b"}, ] [package.dependencies] -cryptography = ">=2.5,<45" +cryptography = ">=2.5,<46" PyJWT = {version = ">=1.0.0,<3", extras = ["crypto"]} requests = ">=2.0.0,<3" [package.extras] -broker = ["pymsalruntime (>=0.13.2,<0.17)"] +broker = ["pymsalruntime (>=0.14,<0.18)", "pymsalruntime (>=0.17,<0.18)"] [[package]] name = "msal-extensions" @@ -4807,23 +5373,6 @@ files = [ msal = ">=1.29,<2" portalocker = ">=1.4,<3" -[[package]] -name = "msg-parser" -version = "1.2.0" -description = "This module enables reading, parsing and converting Microsoft Outlook MSG E-Mail files." 
-optional = false -python-versions = ">=3.4" -files = [ - {file = "msg_parser-1.2.0-py2.py3-none-any.whl", hash = "sha256:d47a2f0b2a359cb189fad83cc991b63ea781ecc70d91410324273fbf93e95375"}, - {file = "msg_parser-1.2.0.tar.gz", hash = "sha256:0de858d4fcebb6c8f6f028da83a17a20fe01cdce67c490779cf43b3b0162aa66"}, -] - -[package.dependencies] -olefile = ">=0.46" - -[package.extras] -rtf = ["compressed-rtf (>=1.0.5)"] - [[package]] name = "msrest" version = "0.7.1" @@ -4847,103 +5396,136 @@ async = ["aiodns", "aiohttp (>=3.0)"] [[package]] name = "multidict" -version = "6.0.5" +version = "6.1.0" description = "multidict implementation" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +files = [ + {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1"}, + {file = "multidict-6.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a114d03b938376557927ab23f1e950827c3b893ccb94b62fd95d430fd0e5cf53"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1c416351ee6271b2f49b56ad7f308072f6f44b37118d69c2cad94f3fa8a40d5"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6b5d83030255983181005e6cfbac1617ce9746b219bc2aad52201ad121226581"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e97b5e938051226dc025ec80980c285b053ffb1e25a3db2a3aa3bc046bf7f56"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d618649d4e70ac6efcbba75be98b26ef5078faad23592f9b51ca492953012429"}, + {file = "multidict-6.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10524ebd769727ac77ef2278390fb0068d83f3acb7773792a5080f2b0abf7748"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ff3827aef427c89a25cc96ded1759271a93603aba9fb977a6d264648ebf989db"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:06809f4f0f7ab7ea2cabf9caca7d79c22c0758b58a71f9d32943ae13c7ace056"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f179dee3b863ab1c59580ff60f9d99f632f34ccb38bf67a33ec6b3ecadd0fd76"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:aaed8b0562be4a0876ee3b6946f6869b7bcdb571a5d1496683505944e268b160"}, + {file = "multidict-6.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3c8b88a2ccf5493b6c8da9076fb151ba106960a2df90c2633f342f120751a9e7"}, + {file = "multidict-6.1.0-cp310-cp310-win32.whl", hash = "sha256:4a9cb68166a34117d6646c0023c7b759bf197bee5ad4272f420a0141d7eb03a0"}, + {file = "multidict-6.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:20b9b5fbe0b88d0bdef2012ef7dee867f874b72528cf1d08f1d59b0e3850129d"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3efe2c2cb5763f2f1b275ad2bf7a287d3f7ebbef35648a9726e3b69284a4f3d6"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7053d3b0353a8b9de430a4f4b4268ac9a4fb3481af37dfe49825bf45ca24156"}, + {file = "multidict-6.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:27e5fc84ccef8dfaabb09d82b7d179c7cf1a3fbc8a966f8274fcb4ab2eb4cadb"}, + {file = 
"multidict-6.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e2b90b43e696f25c62656389d32236e049568b39320e2735d51f08fd362761b"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d83a047959d38a7ff552ff94be767b7fd79b831ad1cd9920662db05fec24fe72"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d1a9dd711d0877a1ece3d2e4fea11a8e75741ca21954c919406b44e7cf971304"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec2abea24d98246b94913b76a125e855eb5c434f7c46546046372fe60f666351"}, + {file = "multidict-6.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4867cafcbc6585e4b678876c489b9273b13e9fff9f6d6d66add5e15d11d926cb"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5b48204e8d955c47c55b72779802b219a39acc3ee3d0116d5080c388970b76e3"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d8fff389528cad1618fb4b26b95550327495462cd745d879a8c7c2115248e399"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:a7a9541cd308eed5e30318430a9c74d2132e9a8cb46b901326272d780bf2d423"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:da1758c76f50c39a2efd5e9859ce7d776317eb1dd34317c8152ac9251fc574a3"}, + {file = "multidict-6.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c943a53e9186688b45b323602298ab727d8865d8c9ee0b17f8d62d14b56f0753"}, + {file = "multidict-6.1.0-cp311-cp311-win32.whl", hash = "sha256:90f8717cb649eea3504091e640a1b8568faad18bd4b9fcd692853a04475a4b80"}, + {file = "multidict-6.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:82176036e65644a6cc5bd619f65f6f19781e8ec2e5330f51aa9ada7504cc1926"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b04772ed465fa3cc947db808fa306d79b43e896beb677a56fb2347ca1a49c1fa"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6180c0ae073bddeb5a97a38c03f30c233e0a4d39cd86166251617d1bbd0af436"}, + {file = "multidict-6.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:071120490b47aa997cca00666923a83f02c7fbb44f71cf7f136df753f7fa8761"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50b3a2710631848991d0bf7de077502e8994c804bb805aeb2925a981de58ec2e"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b58c621844d55e71c1b7f7c498ce5aa6985d743a1a59034c57a905b3f153c1ef"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b6d90641869892caa9ca42ff913f7ff1c5ece06474fbd32fb2cf6834726c95"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b820514bfc0b98a30e3d85462084779900347e4d49267f747ff54060cc33925"}, + {file = "multidict-6.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10a9b09aba0c5b48c53761b7c720aaaf7cf236d5fe394cd399c7ba662d5f9966"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e16bf3e5fc9f44632affb159d30a437bfe286ce9e02754759be5536b169b305"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:76f364861c3bfc98cbbcbd402d83454ed9e01a5224bb3a28bf70002a230f73e2"}, + {file = 
"multidict-6.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:820c661588bd01a0aa62a1283f20d2be4281b086f80dad9e955e690c75fb54a2"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:0e5f362e895bc5b9e67fe6e4ded2492d8124bdf817827f33c5b46c2fe3ffaca6"}, + {file = "multidict-6.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ec660d19bbc671e3a6443325f07263be452c453ac9e512f5eb935e7d4ac28b3"}, + {file = "multidict-6.1.0-cp312-cp312-win32.whl", hash = "sha256:58130ecf8f7b8112cdb841486404f1282b9c86ccb30d3519faf301b2e5659133"}, + {file = "multidict-6.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:188215fc0aafb8e03341995e7c4797860181562380f81ed0a87ff455b70bf1f1"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d569388c381b24671589335a3be6e1d45546c2988c2ebe30fdcada8457a31008"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:052e10d2d37810b99cc170b785945421141bf7bb7d2f8799d431e7db229c385f"}, + {file = "multidict-6.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f90c822a402cb865e396a504f9fc8173ef34212a342d92e362ca498cad308e28"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b225d95519a5bf73860323e633a664b0d85ad3d5bede6d30d95b35d4dfe8805b"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:23bfd518810af7de1116313ebd9092cb9aa629beb12f6ed631ad53356ed6b86c"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c09fcfdccdd0b57867577b719c69e347a436b86cd83747f179dbf0cc0d4c1f3"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf6bea52ec97e95560af5ae576bdac3aa3aae0b6758c6efa115236d9e07dae44"}, + {file = "multidict-6.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57feec87371dbb3520da6192213c7d6fc892d5589a93db548331954de8248fd2"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0c3f390dc53279cbc8ba976e5f8035eab997829066756d811616b652b00a23a3"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:59bfeae4b25ec05b34f1956eaa1cb38032282cd4dfabc5056d0a1ec4d696d3aa"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b2f59caeaf7632cc633b5cf6fc449372b83bbdf0da4ae04d5be36118e46cc0aa"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:37bb93b2178e02b7b618893990941900fd25b6b9ac0fa49931a40aecdf083fe4"}, + {file = "multidict-6.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4e9f48f58c2c523d5a06faea47866cd35b32655c46b443f163d08c6d0ddb17d6"}, + {file = "multidict-6.1.0-cp313-cp313-win32.whl", hash = "sha256:3a37ffb35399029b45c6cc33640a92bef403c9fd388acce75cdc88f58bd19a81"}, + {file = "multidict-6.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:e9aa71e15d9d9beaad2c6b9319edcdc0a49a43ef5c0a4c8265ca9ee7d6c67774"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:db7457bac39421addd0c8449933ac32d8042aae84a14911a757ae6ca3eef1392"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d094ddec350a2fb899fec68d8353c78233debde9b7d8b4beeafa70825f1c281a"}, + {file = "multidict-6.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5845c1fd4866bb5dd3125d89b90e57ed3138241540897de748cdf19de8a2fca2"}, + {file = 
"multidict-6.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9079dfc6a70abe341f521f78405b8949f96db48da98aeb43f9907f342f627cdc"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3914f5aaa0f36d5d60e8ece6a308ee1c9784cd75ec8151062614657a114c4478"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c08be4f460903e5a9d0f76818db3250f12e9c344e79314d1d570fc69d7f4eae4"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d093be959277cb7dee84b801eb1af388b6ad3ca6a6b6bf1ed7585895789d027d"}, + {file = "multidict-6.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3702ea6872c5a2a4eeefa6ffd36b042e9773f05b1f37ae3ef7264b1163c2dcf6"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:2090f6a85cafc5b2db085124d752757c9d251548cedabe9bd31afe6363e0aff2"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:f67f217af4b1ff66c68a87318012de788dd95fcfeb24cc889011f4e1c7454dfd"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:189f652a87e876098bbc67b4da1049afb5f5dfbaa310dd67c594b01c10388db6"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:6bb5992037f7a9eff7991ebe4273ea7f51f1c1c511e6a2ce511d0e7bdb754492"}, + {file = "multidict-6.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f4c2b9e770c4e393876e35a7046879d195cd123b4f116d299d442b335bcd"}, + {file = "multidict-6.1.0-cp38-cp38-win32.whl", hash = "sha256:e27bbb6d14416713a8bd7aaa1313c0fc8d44ee48d74497a0ff4c3a1b6ccb5167"}, + {file = "multidict-6.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:22f3105d4fb15c8f57ff3959a58fcab6ce36814486500cd7485651230ad4d4ef"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:4e18b656c5e844539d506a0a06432274d7bd52a7487e6828c63a63d69185626c"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a185f876e69897a6f3325c3f19f26a297fa058c5e456bfcff8015e9a27e83ae1"}, + {file = "multidict-6.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ab7c4ceb38d91570a650dba194e1ca87c2b543488fe9309b4212694174fd539c"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e617fb6b0b6953fffd762669610c1c4ffd05632c138d61ac7e14ad187870669c"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:16e5f4bf4e603eb1fdd5d8180f1a25f30056f22e55ce51fb3d6ad4ab29f7d96f"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4c035da3f544b1882bac24115f3e2e8760f10a0107614fc9839fd232200b875"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:957cf8e4b6e123a9eea554fa7ebc85674674b713551de587eb318a2df3e00255"}, + {file = "multidict-6.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:483a6aea59cb89904e1ceabd2b47368b5600fb7de78a6e4a2c2987b2d256cf30"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:87701f25a2352e5bf7454caa64757642734da9f6b11384c1f9d1a8e699758057"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:682b987361e5fd7a139ed565e30d81fd81e9629acc7d925a205366877d8c8657"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = 
"sha256:ce2186a7df133a9c895dea3331ddc5ddad42cdd0d1ea2f0a51e5d161e4762f28"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9f636b730f7e8cb19feb87094949ba54ee5357440b9658b2a32a5ce4bce53972"}, + {file = "multidict-6.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:73eae06aa53af2ea5270cc066dcaf02cc60d2994bbb2c4ef5764949257d10f43"}, + {file = "multidict-6.1.0-cp39-cp39-win32.whl", hash = "sha256:1ca0083e80e791cffc6efce7660ad24af66c8d4079d2a750b29001b53ff59ada"}, + {file = "multidict-6.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:aa466da5b15ccea564bdab9c89175c762bc12825f4659c11227f515cee76fa4a"}, + {file = "multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506"}, + {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} + +[[package]] +name = "multiprocess" +version = "0.70.17" +description = "better multiprocessing and multithreading in Python" +optional = false +python-versions = ">=3.8" files = [ - {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, - {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, - {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, - {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, - {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, - {file = 
"multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, - {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, - {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, - {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, - {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, - {file = 
"multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, - {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, - {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, - {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, - {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, - {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, - {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, - {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, - {file = 
"multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, - {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, - {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, - {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, - {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, - {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, - {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, + {file = "multiprocess-0.70.17-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7ddb24e5bcdb64e90ec5543a1f05a39463068b6d3b804aa3f2a4e16ec28562d6"}, + {file = "multiprocess-0.70.17-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d729f55198a3579f6879766a6d9b72b42d4b320c0dcb7844afb774d75b573c62"}, + {file = "multiprocess-0.70.17-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c2c82d0375baed8d8dd0d8c38eb87c5ae9c471f8e384ad203a36f095ee860f67"}, + {file = "multiprocess-0.70.17-pp38-pypy38_pp73-macosx_10_9_arm64.whl", hash = "sha256:a22a6b1a482b80eab53078418bb0f7025e4f7d93cc8e1f36481477a023884861"}, + {file = "multiprocess-0.70.17-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:349525099a0c9ac5936f0488b5ee73199098dac3ac899d81d326d238f9fd3ccd"}, + {file = "multiprocess-0.70.17-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:27b8409c02b5dd89d336107c101dfbd1530a2cd4fd425fc27dcb7adb6e0b47bf"}, + {file = "multiprocess-0.70.17-pp39-pypy39_pp73-macosx_10_13_arm64.whl", hash = "sha256:2ea0939b0f4760a16a548942c65c76ff5afd81fbf1083c56ae75e21faf92e426"}, + {file = "multiprocess-0.70.17-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:2b12e081df87ab755190e227341b2c3b17ee6587e9c82fecddcbe6aa812cd7f7"}, + {file = "multiprocess-0.70.17-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:a0f01cd9d079af7a8296f521dc03859d1a414d14c1e2b6e676ef789333421c95"}, + {file = "multiprocess-0.70.17-py310-none-any.whl", hash = "sha256:38357ca266b51a2e22841b755d9a91e4bb7b937979a54d411677111716c32744"}, + {file = "multiprocess-0.70.17-py311-none-any.whl", hash = "sha256:2884701445d0177aec5bd5f6ee0df296773e4fb65b11903b94c613fb46cfb7d1"}, + {file = "multiprocess-0.70.17-py312-none-any.whl", hash = "sha256:2818af14c52446b9617d1b0755fa70ca2f77c28b25ed97bdaa2c69a22c47b46c"}, + {file = "multiprocess-0.70.17-py313-none-any.whl", hash = "sha256:20c28ca19079a6c879258103a6d60b94d4ffe2d9da07dda93fb1c8bc6243f522"}, + {file = "multiprocess-0.70.17-py38-none-any.whl", hash = "sha256:1d52f068357acd1e5bbc670b273ef8f81d57863235d9fbf9314751886e141968"}, + {file = "multiprocess-0.70.17-py39-none-any.whl", hash = "sha256:c3feb874ba574fbccfb335980020c1ac631fbf2a3f7bee4e2042ede62558a021"}, + {file = "multiprocess-0.70.17.tar.gz", hash = "sha256:4ae2f11a3416809ebc9a48abfc8b14ecce0652a0944731a1493a3c1ba44ff57a"}, ] +[package.dependencies] +dill = ">=0.3.9" + [[package]] name = "multitasking" version = "0.0.11" @@ -4966,6 +5548,17 @@ files = [ {file = 
"mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "nest-asyncio" +version = "1.6.0" +description = "Patch asyncio to allow nested event loops" +optional = false +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, +] + [[package]] name = "newspaper3k" version = "0.2.8" @@ -4994,13 +5587,13 @@ tldextract = ">=2.0.1" [[package]] name = "nltk" -version = "3.8.1" +version = "3.9.1" description = "Natural Language Toolkit" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, - {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"}, + {file = "nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1"}, + {file = "nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868"}, ] [package.dependencies] @@ -5017,15 +5610,45 @@ plot = ["matplotlib"] tgrep = ["pyparsing"] twitter = ["twython"] +[[package]] +name = "nomic" +version = "3.1.3" +description = "The official Nomic python client." +optional = false +python-versions = "*" +files = [ + {file = "nomic-3.1.3.tar.gz", hash = "sha256:b06744b79fbe47451874ca7b272cafa1bb272cfb82acc79c64abfc943a98e035"}, +] + +[package.dependencies] +click = "*" +jsonlines = "*" +loguru = "*" +numpy = "*" +pandas = "*" +pillow = "*" +pyarrow = "*" +pydantic = "*" +pyjwt = "*" +requests = "*" +rich = "*" +tqdm = "*" + +[package.extras] +all = ["nomic[aws,local]"] +aws = ["boto3", "sagemaker"] +dev = ["black (==24.3.0)", "cairosvg", "coverage", "isort", "mkautodoc", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "myst-parser", "nomic[all]", "pandas", "pillow", "pylint", "pyright (<=1.1.377)", "pytest", "pytorch-lightning", "twine"] +local = ["gpt4all (>=2.5.0,<3)"] + [[package]] name = "novita-client" -version = "0.5.6" +version = "0.5.7" description = "novita SDK for Python" optional = false python-versions = ">=3.6" files = [ - {file = "novita_client-0.5.6-py3-none-any.whl", hash = "sha256:9fa6cfd12f13a75c7da42b27f811a560b0320da24cf256480f517bde479bc57c"}, - {file = "novita_client-0.5.6.tar.gz", hash = "sha256:2e4d956903d5da39d43127a41dcb020ae40322d2a6196413071b94b3d6988b98"}, + {file = "novita_client-0.5.7-py3-none-any.whl", hash = "sha256:844a4c09c98328c8d4f72e1d3f63f76285c2963dcc37ccb2de41cbfdbe7fa51d"}, + {file = "novita_client-0.5.7.tar.gz", hash = "sha256:65baf748757aafd8ab080a64f9ab069a40c0810fc1fa9be9c26596988a0aa4b4"}, ] [package.dependencies] @@ -5169,6 +5792,25 @@ rsa = ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] +[[package]] +name = "oci" +version = "2.135.2" +description = "Oracle Cloud Infrastructure Python SDK" +optional = false +python-versions = "*" +files = [ + {file = "oci-2.135.2-py3-none-any.whl", hash = "sha256:5213319244e1c7f108bcb417322f33f01f043fd9636d4063574039f5fdf4e4f7"}, + {file = "oci-2.135.2.tar.gz", hash = "sha256:520f78983c5246eae80dd5ecfd05e3a565c8b98d02ef0c1b11ba1f61bcccb61d"}, +] + +[package.dependencies] +certifi = "*" 
+circuitbreaker = {version = ">=1.3.1,<3.0.0", markers = "python_version >= \"3.7\""} +cryptography = ">=3.2.1,<46.0.0" +pyOpenSSL = ">=17.5.0,<25.0.0" +python-dateutil = ">=2.5.3,<3.0.0" +pytz = ">=2016.10" + [[package]] name = "odfpy" version = "1.4.1" @@ -5198,69 +5840,129 @@ tests = ["pytest", "pytest-cov"] [[package]] name = "onnxruntime" -version = "1.18.1" +version = "1.19.2" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" optional = false python-versions = "*" files = [ - {file = "onnxruntime-1.18.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:29ef7683312393d4ba04252f1b287d964bd67d5e6048b94d2da3643986c74d80"}, - {file = "onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fc706eb1df06ddf55776e15a30519fb15dda7697f987a2bbda4962845e3cec05"}, - {file = "onnxruntime-1.18.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b7de69f5ced2a263531923fa68bbec52a56e793b802fcd81a03487b5e292bc3a"}, - {file = "onnxruntime-1.18.1-cp310-cp310-win32.whl", hash = "sha256:221e5b16173926e6c7de2cd437764492aa12b6811f45abd37024e7cf2ae5d7e3"}, - {file = "onnxruntime-1.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:75211b619275199c861ee94d317243b8a0fcde6032e5a80e1aa9ded8ab4c6060"}, - {file = "onnxruntime-1.18.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:f26582882f2dc581b809cfa41a125ba71ad9e715738ec6402418df356969774a"}, - {file = "onnxruntime-1.18.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef36f3a8b768506d02be349ac303fd95d92813ba3ba70304d40c3cd5c25d6a4c"}, - {file = "onnxruntime-1.18.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:170e711393e0618efa8ed27b59b9de0ee2383bd2a1f93622a97006a5ad48e434"}, - {file = "onnxruntime-1.18.1-cp311-cp311-win32.whl", hash = "sha256:9b6a33419b6949ea34e0dc009bc4470e550155b6da644571ecace4b198b0d88f"}, - {file = "onnxruntime-1.18.1-cp311-cp311-win_amd64.whl", hash = "sha256:5c1380a9f1b7788da742c759b6a02ba771fe1ce620519b2b07309decbd1a2fe1"}, - {file = "onnxruntime-1.18.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:31bd57a55e3f983b598675dfc7e5d6f0877b70ec9864b3cc3c3e1923d0a01919"}, - {file = "onnxruntime-1.18.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b9e03c4ba9f734500691a4d7d5b381cd71ee2f3ce80a1154ac8f7aed99d1ecaa"}, - {file = "onnxruntime-1.18.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:781aa9873640f5df24524f96f6070b8c550c66cb6af35710fd9f92a20b4bfbf6"}, - {file = "onnxruntime-1.18.1-cp312-cp312-win32.whl", hash = "sha256:3a2d9ab6254ca62adbb448222e630dc6883210f718065063518c8f93a32432be"}, - {file = "onnxruntime-1.18.1-cp312-cp312-win_amd64.whl", hash = "sha256:ad93c560b1c38c27c0275ffd15cd7f45b3ad3fc96653c09ce2931179982ff204"}, - {file = "onnxruntime-1.18.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:3b55dc9d3c67626388958a3eb7ad87eb7c70f75cb0f7ff4908d27b8b42f2475c"}, - {file = "onnxruntime-1.18.1-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f80dbcfb6763cc0177a31168b29b4bd7662545b99a19e211de8c734b657e0669"}, - {file = "onnxruntime-1.18.1-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f1ff2c61a16d6c8631796c54139bafea41ee7736077a0fc64ee8ae59432f5c58"}, - {file = "onnxruntime-1.18.1-cp38-cp38-win32.whl", hash = "sha256:219855bd272fe0c667b850bf1a1a5a02499269a70d59c48e6f27f9c8bcb25d02"}, - {file = 
"onnxruntime-1.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:afdf16aa607eb9a2c60d5ca2d5abf9f448e90c345b6b94c3ed14f4fb7e6a2d07"}, - {file = "onnxruntime-1.18.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:128df253ade673e60cea0955ec9d0e89617443a6d9ce47c2d79eb3f72a3be3de"}, - {file = "onnxruntime-1.18.1-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9839491e77e5c5a175cab3621e184d5a88925ee297ff4c311b68897197f4cde9"}, - {file = "onnxruntime-1.18.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad3187c1faff3ac15f7f0e7373ef4788c582cafa655a80fdbb33eaec88976c66"}, - {file = "onnxruntime-1.18.1-cp39-cp39-win32.whl", hash = "sha256:34657c78aa4e0b5145f9188b550ded3af626651b15017bf43d280d7e23dbf195"}, - {file = "onnxruntime-1.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:9c14fd97c3ddfa97da5feef595e2c73f14c2d0ec1d4ecbea99c8d96603c89589"}, + {file = "onnxruntime-1.19.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:84fa57369c06cadd3c2a538ae2a26d76d583e7c34bdecd5769d71ca5c0fc750e"}, + {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdc471a66df0c1cdef774accef69e9f2ca168c851ab5e4f2f3341512c7ef4666"}, + {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e3a4ce906105d99ebbe817f536d50a91ed8a4d1592553f49b3c23c4be2560ae6"}, + {file = "onnxruntime-1.19.2-cp310-cp310-win32.whl", hash = "sha256:4b3d723cc154c8ddeb9f6d0a8c0d6243774c6b5930847cc83170bfe4678fafb3"}, + {file = "onnxruntime-1.19.2-cp310-cp310-win_amd64.whl", hash = "sha256:17ed7382d2c58d4b7354fb2b301ff30b9bf308a1c7eac9546449cd122d21cae5"}, + {file = "onnxruntime-1.19.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d863e8acdc7232d705d49e41087e10b274c42f09e259016a46f32c34e06dc4fd"}, + {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c1dfe4f660a71b31caa81fc298a25f9612815215a47b286236e61d540350d7b6"}, + {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36511dc07c5c964b916697e42e366fa43c48cdb3d3503578d78cef30417cb84"}, + {file = "onnxruntime-1.19.2-cp311-cp311-win32.whl", hash = "sha256:50cbb8dc69d6befad4746a69760e5b00cc3ff0a59c6c3fb27f8afa20e2cab7e7"}, + {file = "onnxruntime-1.19.2-cp311-cp311-win_amd64.whl", hash = "sha256:1c3e5d415b78337fa0b1b75291e9ea9fb2a4c1f148eb5811e7212fed02cfffa8"}, + {file = "onnxruntime-1.19.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:68e7051bef9cfefcbb858d2d2646536829894d72a4130c24019219442b1dd2ed"}, + {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d2d366fbcc205ce68a8a3bde2185fd15c604d9645888703785b61ef174265168"}, + {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:477b93df4db467e9cbf34051662a4b27c18e131fa1836e05974eae0d6e4cf29b"}, + {file = "onnxruntime-1.19.2-cp312-cp312-win32.whl", hash = "sha256:9a174073dc5608fad05f7cf7f320b52e8035e73d80b0a23c80f840e5a97c0147"}, + {file = "onnxruntime-1.19.2-cp312-cp312-win_amd64.whl", hash = "sha256:190103273ea4507638ffc31d66a980594b237874b65379e273125150eb044857"}, + {file = "onnxruntime-1.19.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:636bc1d4cc051d40bc52e1f9da87fbb9c57d9d47164695dfb1c41646ea51ea66"}, + {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:5bd8b875757ea941cbcfe01582970cc299893d1b65bd56731e326a8333f638a3"}, + {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b2046fc9560f97947bbc1acbe4c6d48585ef0f12742744307d3364b131ac5778"}, + {file = "onnxruntime-1.19.2-cp38-cp38-win32.whl", hash = "sha256:31c12840b1cde4ac1f7d27d540c44e13e34f2345cf3642762d2a3333621abb6a"}, + {file = "onnxruntime-1.19.2-cp38-cp38-win_amd64.whl", hash = "sha256:016229660adea180e9a32ce218b95f8f84860a200f0f13b50070d7d90e92956c"}, + {file = "onnxruntime-1.19.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:006c8d326835c017a9e9f74c9c77ebb570a71174a1e89fe078b29a557d9c3848"}, + {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df2a94179a42d530b936f154615b54748239c2908ee44f0d722cb4df10670f68"}, + {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fae4b4de45894b9ce7ae418c5484cbf0341db6813effec01bb2216091c52f7fb"}, + {file = "onnxruntime-1.19.2-cp39-cp39-win32.whl", hash = "sha256:dc5430f473e8706fff837ae01323be9dcfddd3ea471c900a91fa7c9b807ec5d3"}, + {file = "onnxruntime-1.19.2-cp39-cp39-win_amd64.whl", hash = "sha256:38475e29a95c5f6c62c2c603d69fc7d4c6ccbf4df602bd567b86ae1138881c49"}, ] [package.dependencies] coloredlogs = "*" flatbuffers = "*" -numpy = ">=1.21.6,<2.0" +numpy = ">=1.21.6" packaging = "*" protobuf = "*" sympy = "*" [[package]] name = "openai" -version = "1.29.0" +version = "1.52.2" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.29.0-py3-none-any.whl", hash = "sha256:c61cd12376c84362d406341f9e2f9a9d6b81c082b133b44484dc0f43954496b1"}, - {file = "openai-1.29.0.tar.gz", hash = "sha256:d5a769f485610cff8bae14343fa45a8b1d346be3d541fa5b28ccd040dbc8baf8"}, + {file = "openai-1.52.2-py3-none-any.whl", hash = "sha256:57e9e37bc407f39bb6ec3a27d7e8fb9728b2779936daa1fcf95df17d3edfaccc"}, + {file = "openai-1.52.2.tar.gz", hash = "sha256:87b7d0f69d85f5641678d414b7ee3082363647a5c66a462ed7f3ccb59582da0d"}, ] [package.dependencies] anyio = ">=3.5.0,<5" distro = ">=1.7.0,<2" httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" pydantic = ">=1.9.0,<3" sniffio = "*" tqdm = ">4" -typing-extensions = ">=4.7,<5" +typing-extensions = ">=4.11,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +[[package]] +name = "opencensus" +version = "0.11.4" +description = "A stats collection and distributed tracing framework" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-0.11.4-py2.py3-none-any.whl", hash = "sha256:a18487ce68bc19900336e0ff4655c5a116daf10c1b3685ece8d971bddad6a864"}, + {file = "opencensus-0.11.4.tar.gz", hash = "sha256:cbef87d8b8773064ab60e5c2a1ced58bbaa38a6d052c41aec224958ce544eff2"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.0.0,<3.0.0", markers = "python_version >= \"3.6\""} +opencensus-context = ">=0.1.3" +six = ">=1.16,<2.0" + +[[package]] +name = "opencensus-context" +version = "0.1.3" +description = "OpenCensus Runtime Context" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-context-0.1.3.tar.gz", hash = "sha256:a03108c3c10d8c80bb5ddf5c8a1f033161fa61972a9917f9b9b3a18517f0088c"}, + {file = "opencensus_context-0.1.3-py2.py3-none-any.whl", hash = "sha256:073bb0590007af276853009fac7e4bab1d523c3f03baf4cb4511ca38967c6039"}, +] + +[[package]] +name = "opencensus-ext-azure" +version = "1.1.13" +description = 
"OpenCensus Azure Monitor Exporter" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-ext-azure-1.1.13.tar.gz", hash = "sha256:aec30472177005379ba56a702a097d618c5f57558e1bb6676ec75f948130692a"}, + {file = "opencensus_ext_azure-1.1.13-py2.py3-none-any.whl", hash = "sha256:06001fac6f8588ba00726a3a7c6c7f2fc88bc8ad12a65afdca657923085393dd"}, +] + +[package.dependencies] +azure-core = ">=1.12.0,<2.0.0" +azure-identity = ">=1.5.0,<2.0.0" +opencensus = ">=0.11.4,<1.0.0" +psutil = ">=5.6.3" +requests = ">=2.19.0" + +[[package]] +name = "opencensus-ext-logging" +version = "0.1.1" +description = "OpenCensus logging Integration" +optional = false +python-versions = "*" +files = [ + {file = "opencensus-ext-logging-0.1.1.tar.gz", hash = "sha256:c203b70f034151dada529f543af330ba17aaffec27d8a5267d03c713eb1de334"}, + {file = "opencensus_ext_logging-0.1.1-py2.py3-none-any.whl", hash = "sha256:cfdaf5da5d8b195ff3d1af87a4066a6621a28046173f6be4b0b6caec4a3ca89f"}, +] + +[package.dependencies] +opencensus = ">=0.8.0,<1.0.0" + [[package]] name = "openpyxl" version = "3.1.5" @@ -5301,42 +6003,42 @@ kerberos = ["requests-kerberos"] [[package]] name = "opentelemetry-api" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Python API" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_api-1.26.0-py3-none-any.whl", hash = "sha256:7d7ea33adf2ceda2dd680b18b1677e4152000b37ca76e679da71ff103b943064"}, - {file = "opentelemetry_api-1.26.0.tar.gz", hash = "sha256:2bd639e4bed5b18486fef0b5a520aaffde5a18fc225e808a1ac4df363f43a1ce"}, + {file = "opentelemetry_api-1.27.0-py3-none-any.whl", hash = "sha256:953d5871815e7c30c81b56d910c707588000fff7a3ca1c73e6531911d53065e7"}, + {file = "opentelemetry_api-1.27.0.tar.gz", hash = "sha256:ed673583eaa5f81b5ce5e86ef7cdaf622f88ef65f0b9aab40b843dcae5bef342"}, ] [package.dependencies] deprecated = ">=1.2.6" -importlib-metadata = ">=6.0,<=8.0.0" +importlib-metadata = ">=6.0,<=8.4.0" [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Protobuf encoding" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_common-1.26.0-py3-none-any.whl", hash = "sha256:ee4d8f8891a1b9c372abf8d109409e5b81947cf66423fd998e56880057afbc71"}, - {file = "opentelemetry_exporter_otlp_proto_common-1.26.0.tar.gz", hash = "sha256:bdbe50e2e22a1c71acaa0c8ba6efaadd58882e5a5978737a44a4c4b10d304c92"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.27.0-py3-none-any.whl", hash = "sha256:675db7fffcb60946f3a5c43e17d1168a3307a94a930ecf8d2ea1f286f3d4f79a"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.27.0.tar.gz", hash = "sha256:159d27cf49f359e3798c4c3eb8da6ef4020e292571bd8c5604a2a573231dd5c8"}, ] [package.dependencies] -opentelemetry-proto = "1.26.0" +opentelemetry-proto = "1.27.0" [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_exporter_otlp_proto_grpc-1.26.0-py3-none-any.whl", hash = "sha256:e2be5eff72ebcb010675b818e8d7c2e7d61ec451755b8de67a140bc49b9b0280"}, - {file = "opentelemetry_exporter_otlp_proto_grpc-1.26.0.tar.gz", hash = "sha256:a65b67a9a6b06ba1ec406114568e21afe88c1cdb29c464f2507d529eb906d8ae"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.27.0-py3-none-any.whl", hash = 
"sha256:56b5bbd5d61aab05e300d9d62a6b3c134827bbd28d0b12f2649c2da368006c9e"}, + {file = "opentelemetry_exporter_otlp_proto_grpc-1.27.0.tar.gz", hash = "sha256:af6f72f76bcf425dfb5ad11c1a6d6eca2863b91e63575f89bb7b4b55099d968f"}, ] [package.dependencies] @@ -5344,19 +6046,19 @@ deprecated = ">=1.2.6" googleapis-common-protos = ">=1.52,<2.0" grpcio = ">=1.0.0,<2.0.0" opentelemetry-api = ">=1.15,<2.0" -opentelemetry-exporter-otlp-proto-common = "1.26.0" -opentelemetry-proto = "1.26.0" -opentelemetry-sdk = ">=1.26.0,<1.27.0" +opentelemetry-exporter-otlp-proto-common = "1.27.0" +opentelemetry-proto = "1.27.0" +opentelemetry-sdk = ">=1.27.0,<1.28.0" [[package]] name = "opentelemetry-instrumentation" -version = "0.47b0" +version = "0.48b0" description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation-0.47b0-py3-none-any.whl", hash = "sha256:88974ee52b1db08fc298334b51c19d47e53099c33740e48c4f084bd1afd052d5"}, - {file = "opentelemetry_instrumentation-0.47b0.tar.gz", hash = "sha256:96f9885e450c35e3f16a4f33145f2ebf620aea910c9fd74a392bbc0f807a350f"}, + {file = "opentelemetry_instrumentation-0.48b0-py3-none-any.whl", hash = "sha256:a69750dc4ba6a5c3eb67986a337185a25b739966d80479befe37b546fc870b44"}, + {file = "opentelemetry_instrumentation-0.48b0.tar.gz", hash = "sha256:94929685d906380743a71c3970f76b5f07476eea1834abd5dd9d17abfe23cc35"}, ] [package.dependencies] @@ -5366,55 +6068,55 @@ wrapt = ">=1.0.0,<2.0.0" [[package]] name = "opentelemetry-instrumentation-asgi" -version = "0.47b0" +version = "0.48b0" description = "ASGI instrumentation for OpenTelemetry" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation_asgi-0.47b0-py3-none-any.whl", hash = "sha256:b798dc4957b3edc9dfecb47a4c05809036a4b762234c5071212fda39ead80ade"}, - {file = "opentelemetry_instrumentation_asgi-0.47b0.tar.gz", hash = "sha256:e78b7822c1bca0511e5e9610ec484b8994a81670375e570c76f06f69af7c506a"}, + {file = "opentelemetry_instrumentation_asgi-0.48b0-py3-none-any.whl", hash = "sha256:ddb1b5fc800ae66e85a4e2eca4d9ecd66367a8c7b556169d9e7b57e10676e44d"}, + {file = "opentelemetry_instrumentation_asgi-0.48b0.tar.gz", hash = "sha256:04c32174b23c7fa72ddfe192dad874954968a6a924608079af9952964ecdf785"}, ] [package.dependencies] asgiref = ">=3.0,<4.0" opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = "0.47b0" -opentelemetry-semantic-conventions = "0.47b0" -opentelemetry-util-http = "0.47b0" +opentelemetry-instrumentation = "0.48b0" +opentelemetry-semantic-conventions = "0.48b0" +opentelemetry-util-http = "0.48b0" [package.extras] instruments = ["asgiref (>=3.0,<4.0)"] [[package]] name = "opentelemetry-instrumentation-fastapi" -version = "0.47b0" +version = "0.48b0" description = "OpenTelemetry FastAPI Instrumentation" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_instrumentation_fastapi-0.47b0-py3-none-any.whl", hash = "sha256:5ac28dd401160b02e4f544a85a9e4f61a8cbe5b077ea0379d411615376a2bd21"}, - {file = "opentelemetry_instrumentation_fastapi-0.47b0.tar.gz", hash = "sha256:0c7c10b5d971e99a420678ffd16c5b1ea4f0db3b31b62faf305fbb03b4ebee36"}, + {file = "opentelemetry_instrumentation_fastapi-0.48b0-py3-none-any.whl", hash = "sha256:afeb820a59e139d3e5d96619600f11ce0187658b8ae9e3480857dd790bc024f2"}, + {file = "opentelemetry_instrumentation_fastapi-0.48b0.tar.gz", hash = "sha256:21a72563ea412c0b535815aeed75fc580240f1f02ebc72381cfab672648637a2"}, ] 
[package.dependencies] opentelemetry-api = ">=1.12,<2.0" -opentelemetry-instrumentation = "0.47b0" -opentelemetry-instrumentation-asgi = "0.47b0" -opentelemetry-semantic-conventions = "0.47b0" -opentelemetry-util-http = "0.47b0" +opentelemetry-instrumentation = "0.48b0" +opentelemetry-instrumentation-asgi = "0.48b0" +opentelemetry-semantic-conventions = "0.48b0" +opentelemetry-util-http = "0.48b0" [package.extras] -instruments = ["fastapi (>=0.58,<1.0)", "fastapi-slim (>=0.111.0,<0.112.0)"] +instruments = ["fastapi (>=0.58,<1.0)"] [[package]] name = "opentelemetry-proto" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Python Proto" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_proto-1.26.0-py3-none-any.whl", hash = "sha256:6c4d7b4d4d9c88543bcf8c28ae3f8f0448a753dc291c18c5390444c90b76a725"}, - {file = "opentelemetry_proto-1.26.0.tar.gz", hash = "sha256:c5c18796c0cab3751fc3b98dee53855835e90c0422924b484432ac852d93dc1e"}, + {file = "opentelemetry_proto-1.27.0-py3-none-any.whl", hash = "sha256:b133873de5581a50063e1e4b29cdcf0c5e253a8c2d8dc1229add20a4c3830ace"}, + {file = "opentelemetry_proto-1.27.0.tar.gz", hash = "sha256:33c9345d91dafd8a74fc3d7576c5a38f18b7fdf8d02983ac67485386132aedd6"}, ] [package.dependencies] @@ -5422,44 +6124,44 @@ protobuf = ">=3.19,<5.0" [[package]] name = "opentelemetry-sdk" -version = "1.26.0" +version = "1.27.0" description = "OpenTelemetry Python SDK" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_sdk-1.26.0-py3-none-any.whl", hash = "sha256:feb5056a84a88670c041ea0ded9921fca559efec03905dddeb3885525e0af897"}, - {file = "opentelemetry_sdk-1.26.0.tar.gz", hash = "sha256:c90d2868f8805619535c05562d699e2f4fb1f00dbd55a86dcefca4da6fa02f85"}, + {file = "opentelemetry_sdk-1.27.0-py3-none-any.whl", hash = "sha256:365f5e32f920faf0fd9e14fdfd92c086e317eaa5f860edba9cdc17a380d9197d"}, + {file = "opentelemetry_sdk-1.27.0.tar.gz", hash = "sha256:d525017dea0ccce9ba4e0245100ec46ecdc043f2d7b8315d56b19aff0904fa6f"}, ] [package.dependencies] -opentelemetry-api = "1.26.0" -opentelemetry-semantic-conventions = "0.47b0" +opentelemetry-api = "1.27.0" +opentelemetry-semantic-conventions = "0.48b0" typing-extensions = ">=3.7.4" [[package]] name = "opentelemetry-semantic-conventions" -version = "0.47b0" +version = "0.48b0" description = "OpenTelemetry Semantic Conventions" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_semantic_conventions-0.47b0-py3-none-any.whl", hash = "sha256:4ff9d595b85a59c1c1413f02bba320ce7ea6bf9e2ead2b0913c4395c7bbc1063"}, - {file = "opentelemetry_semantic_conventions-0.47b0.tar.gz", hash = "sha256:a8d57999bbe3495ffd4d510de26a97dadc1dace53e0275001b2c1b2f67992a7e"}, + {file = "opentelemetry_semantic_conventions-0.48b0-py3-none-any.whl", hash = "sha256:a0de9f45c413a8669788a38569c7e0a11ce6ce97861a628cca785deecdc32a1f"}, + {file = "opentelemetry_semantic_conventions-0.48b0.tar.gz", hash = "sha256:12d74983783b6878162208be57c9effcb89dc88691c64992d70bb89dc00daa1a"}, ] [package.dependencies] deprecated = ">=1.2.6" -opentelemetry-api = "1.26.0" +opentelemetry-api = "1.27.0" [[package]] name = "opentelemetry-util-http" -version = "0.47b0" +version = "0.48b0" description = "Web util for OpenTelemetry" optional = false python-versions = ">=3.8" files = [ - {file = "opentelemetry_util_http-0.47b0-py3-none-any.whl", hash = "sha256:3d3215e09c4a723b12da6d0233a31395aeb2bb33a64d7b15a1500690ba250f19"}, - {file = "opentelemetry_util_http-0.47b0.tar.gz", hash = 
"sha256:352a07664c18eef827eb8ddcbd64c64a7284a39dd1655e2f16f577eb046ccb32"}, + {file = "opentelemetry_util_http-0.48b0-py3-none-any.whl", hash = "sha256:76f598af93aab50328d2a69c786beaedc8b6a7770f7a818cc307eb353debfffb"}, + {file = "opentelemetry_util_http-0.48b0.tar.gz", hash = "sha256:60312015153580cc20f322e5cdc3d3ecad80a71743235bdb77716e742814623c"}, ] [[package]] @@ -5507,62 +6209,69 @@ cryptography = ">=3.2.1" [[package]] name = "orjson" -version = "3.10.6" +version = "3.10.10" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false python-versions = ">=3.8" files = [ - {file = "orjson-3.10.6-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:fb0ee33124db6eaa517d00890fc1a55c3bfe1cf78ba4a8899d71a06f2d6ff5c7"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c1c4b53b24a4c06547ce43e5fee6ec4e0d8fe2d597f4647fc033fd205707365"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eadc8fd310edb4bdbd333374f2c8fec6794bbbae99b592f448d8214a5e4050c0"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61272a5aec2b2661f4fa2b37c907ce9701e821b2c1285d5c3ab0207ebd358d38"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57985ee7e91d6214c837936dc1608f40f330a6b88bb13f5a57ce5257807da143"}, - {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:633a3b31d9d7c9f02d49c4ab4d0a86065c4a6f6adc297d63d272e043472acab5"}, - {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1c680b269d33ec444afe2bdc647c9eb73166fa47a16d9a75ee56a374f4a45f43"}, - {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f759503a97a6ace19e55461395ab0d618b5a117e8d0fbb20e70cfd68a47327f2"}, - {file = "orjson-3.10.6-cp310-none-win32.whl", hash = "sha256:95a0cce17f969fb5391762e5719575217bd10ac5a189d1979442ee54456393f3"}, - {file = "orjson-3.10.6-cp310-none-win_amd64.whl", hash = "sha256:df25d9271270ba2133cc88ee83c318372bdc0f2cd6f32e7a450809a111efc45c"}, - {file = "orjson-3.10.6-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:b1ec490e10d2a77c345def52599311849fc063ae0e67cf4f84528073152bb2ba"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55d43d3feb8f19d07e9f01e5b9be4f28801cf7c60d0fa0d279951b18fae1932b"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3045267e98fe749408eee1593a142e02357c5c99be0802185ef2170086a863"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c27bc6a28ae95923350ab382c57113abd38f3928af3c80be6f2ba7eb8d8db0b0"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d27456491ca79532d11e507cadca37fb8c9324a3976294f68fb1eff2dc6ced5a"}, - {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05ac3d3916023745aa3b3b388e91b9166be1ca02b7c7e41045da6d12985685f0"}, - {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1335d4ef59ab85cab66fe73fd7a4e881c298ee7f63ede918b7faa1b27cbe5212"}, - {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:4bbc6d0af24c1575edc79994c20e1b29e6fb3c6a570371306db0993ecf144dc5"}, - {file = "orjson-3.10.6-cp311-none-win32.whl", hash = "sha256:450e39ab1f7694465060a0550b3f6d328d20297bf2e06aa947b97c21e5241fbd"}, - {file = "orjson-3.10.6-cp311-none-win_amd64.whl", hash = "sha256:227df19441372610b20e05bdb906e1742ec2ad7a66ac8350dcfd29a63014a83b"}, - {file = "orjson-3.10.6-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ea2977b21f8d5d9b758bb3f344a75e55ca78e3ff85595d248eee813ae23ecdfb"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6f3d167d13a16ed263b52dbfedff52c962bfd3d270b46b7518365bcc2121eed"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f710f346e4c44a4e8bdf23daa974faede58f83334289df80bc9cd12fe82573c7"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7275664f84e027dcb1ad5200b8b18373e9c669b2a9ec33d410c40f5ccf4b257e"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0943e4c701196b23c240b3d10ed8ecd674f03089198cf503105b474a4f77f21f"}, - {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:446dee5a491b5bc7d8f825d80d9637e7af43f86a331207b9c9610e2f93fee22a"}, - {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:64c81456d2a050d380786413786b057983892db105516639cb5d3ee3c7fd5148"}, - {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"}, - {file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"}, - {file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"}, - {file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2c116072a8533f2fec435fde4d134610f806bdac20188c7bd2081f3e9e0133f"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6eeb13218c8cf34c61912e9df2de2853f1d009de0e46ea09ccdf3d757896af0a"}, - {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:965a916373382674e323c957d560b953d81d7a8603fbeee26f7b8248638bd48b"}, - {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:03c95484d53ed8e479cade8628c9cea00fd9d67f5554764a1110e0d5aa2de96e"}, - {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e060748a04cccf1e0a6f2358dffea9c080b849a4a68c28b1b907f272b5127e9b"}, - {file = "orjson-3.10.6-cp38-none-win32.whl", hash = "sha256:738dbe3ef909c4b019d69afc19caf6b5ed0e2f1c786b5d6215fbb7539246e4c6"}, - {file = "orjson-3.10.6-cp38-none-win_amd64.whl", hash = "sha256:d40f839dddf6a7d77114fe6b8a70218556408c71d4d6e29413bb5f150a692ff7"}, - {file = 
"orjson-3.10.6-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:697a35a083c4f834807a6232b3e62c8b280f7a44ad0b759fd4dce748951e70db"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd502f96bf5ea9a61cbc0b2b5900d0dd68aa0da197179042bdd2be67e51a1e4b"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f215789fb1667cdc874c1b8af6a84dc939fd802bf293a8334fce185c79cd359b"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2debd8ddce948a8c0938c8c93ade191d2f4ba4649a54302a7da905a81f00b56"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5410111d7b6681d4b0d65e0f58a13be588d01b473822483f77f513c7f93bd3b2"}, - {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb1f28a137337fdc18384079fa5726810681055b32b92253fa15ae5656e1dddb"}, - {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bf2fbbce5fe7cd1aa177ea3eab2b8e6a6bc6e8592e4279ed3db2d62e57c0e1b2"}, - {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:79b9b9e33bd4c517445a62b90ca0cc279b0f1f3970655c3df9e608bc3f91741a"}, - {file = "orjson-3.10.6-cp39-none-win32.whl", hash = "sha256:30b0a09a2014e621b1adf66a4f705f0809358350a757508ee80209b2d8dae219"}, - {file = "orjson-3.10.6-cp39-none-win_amd64.whl", hash = "sha256:49e3bc615652617d463069f91b867a4458114c5b104e13b7ae6872e5f79d0844"}, - {file = "orjson-3.10.6.tar.gz", hash = "sha256:e54b63d0a7c6c54a5f5f726bc93a2078111ef060fec4ecbf34c5db800ca3b3a7"}, + {file = "orjson-3.10.10-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:b788a579b113acf1c57e0a68e558be71d5d09aa67f62ca1f68e01117e550a998"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:804b18e2b88022c8905bb79bd2cbe59c0cd014b9328f43da8d3b28441995cda4"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9972572a1d042ec9ee421b6da69f7cc823da5962237563fa548ab17f152f0b9b"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc6993ab1c2ae7dd0711161e303f1db69062955ac2668181bfdf2dd410e65258"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d78e4cacced5781b01d9bc0f0cd8b70b906a0e109825cb41c1b03f9c41e4ce86"}, + {file = "orjson-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6eb2598df518281ba0cbc30d24c5b06124ccf7e19169e883c14e0831217a0bc"}, + {file = "orjson-3.10.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:23776265c5215ec532de6238a52707048401a568f0fa0d938008e92a147fe2c7"}, + {file = "orjson-3.10.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8cc2a654c08755cef90b468ff17c102e2def0edd62898b2486767204a7f5cc9c"}, + {file = "orjson-3.10.10-cp310-none-win32.whl", hash = "sha256:081b3fc6a86d72efeb67c13d0ea7c030017bd95f9868b1e329a376edc456153b"}, + {file = "orjson-3.10.10-cp310-none-win_amd64.whl", hash = "sha256:ff38c5fb749347768a603be1fb8a31856458af839f31f064c5aa74aca5be9efe"}, + {file = "orjson-3.10.10-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:879e99486c0fbb256266c7c6a67ff84f46035e4f8749ac6317cc83dacd7f993a"}, + {file = 
"orjson-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019481fa9ea5ff13b5d5d95e6fd5ab25ded0810c80b150c2c7b1cc8660b662a7"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0dd57eff09894938b4c86d4b871a479260f9e156fa7f12f8cad4b39ea8028bb5"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dbde6d70cd95ab4d11ea8ac5e738e30764e510fc54d777336eec09bb93b8576c"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2625cb37b8fb42e2147404e5ff7ef08712099197a9cd38895006d7053e69d6"}, + {file = "orjson-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbf3c20c6a7db69df58672a0d5815647ecf78c8e62a4d9bd284e8621c1fe5ccb"}, + {file = "orjson-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:75c38f5647e02d423807d252ce4528bf6a95bd776af999cb1fb48867ed01d1f6"}, + {file = "orjson-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:23458d31fa50ec18e0ec4b0b4343730928296b11111df5f547c75913714116b2"}, + {file = "orjson-3.10.10-cp311-none-win32.whl", hash = "sha256:2787cd9dedc591c989f3facd7e3e86508eafdc9536a26ec277699c0aa63c685b"}, + {file = "orjson-3.10.10-cp311-none-win_amd64.whl", hash = "sha256:6514449d2c202a75183f807bc755167713297c69f1db57a89a1ef4a0170ee269"}, + {file = "orjson-3.10.10-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8564f48f3620861f5ef1e080ce7cd122ee89d7d6dacf25fcae675ff63b4d6e05"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5bf161a32b479034098c5b81f2608f09167ad2fa1c06abd4e527ea6bf4837a9"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:68b65c93617bcafa7f04b74ae8bc2cc214bd5cb45168a953256ff83015c6747d"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8e28406f97fc2ea0c6150f4c1b6e8261453318930b334abc419214c82314f85"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4d0d9fe174cc7a5bdce2e6c378bcdb4c49b2bf522a8f996aa586020e1b96cee"}, + {file = "orjson-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3be81c42f1242cbed03cbb3973501fcaa2675a0af638f8be494eaf37143d999"}, + {file = "orjson-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:65f9886d3bae65be026219c0a5f32dbbe91a9e6272f56d092ab22561ad0ea33b"}, + {file = "orjson-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:730ed5350147db7beb23ddaf072f490329e90a1d059711d364b49fe352ec987b"}, + {file = "orjson-3.10.10-cp312-none-win32.whl", hash = "sha256:a8f4bf5f1c85bea2170800020d53a8877812892697f9c2de73d576c9307a8a5f"}, + {file = "orjson-3.10.10-cp312-none-win_amd64.whl", hash = "sha256:384cd13579a1b4cd689d218e329f459eb9ddc504fa48c5a83ef4889db7fd7a4f"}, + {file = "orjson-3.10.10-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:44bffae68c291f94ff5a9b4149fe9d1bdd4cd0ff0fb575bcea8351d48db629a1"}, + {file = "orjson-3.10.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e27b4c6437315df3024f0835887127dac2a0a3ff643500ec27088d2588fa5ae1"}, + {file = "orjson-3.10.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca84df16d6b49325a4084fd8b2fe2229cb415e15c46c529f868c3387bb1339d"}, + {file = 
"orjson-3.10.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c14ce70e8f39bd71f9f80423801b5d10bf93d1dceffdecd04df0f64d2c69bc01"}, + {file = "orjson-3.10.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:24ac62336da9bda1bd93c0491eff0613003b48d3cb5d01470842e7b52a40d5b4"}, + {file = "orjson-3.10.10-cp313-none-win32.whl", hash = "sha256:eb0a42831372ec2b05acc9ee45af77bcaccbd91257345f93780a8e654efc75db"}, + {file = "orjson-3.10.10-cp313-none-win_amd64.whl", hash = "sha256:f0c4f37f8bf3f1075c6cc8dd8a9f843689a4b618628f8812d0a71e6968b95ffd"}, + {file = "orjson-3.10.10-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:829700cc18503efc0cf502d630f612884258020d98a317679cd2054af0259568"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0ceb5e0e8c4f010ac787d29ae6299846935044686509e2f0f06ed441c1ca949"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0c25908eb86968613216f3db4d3003f1c45d78eb9046b71056ca327ff92bdbd4"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:218cb0bc03340144b6328a9ff78f0932e642199ac184dd74b01ad691f42f93ff"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2277ec2cea3775640dc81ab5195bb5b2ada2fe0ea6eee4677474edc75ea6785"}, + {file = "orjson-3.10.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:848ea3b55ab5ccc9d7bbd420d69432628b691fba3ca8ae3148c35156cbd282aa"}, + {file = "orjson-3.10.10-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:e3e67b537ac0c835b25b5f7d40d83816abd2d3f4c0b0866ee981a045287a54f3"}, + {file = "orjson-3.10.10-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:7948cfb909353fce2135dcdbe4521a5e7e1159484e0bb024c1722f272488f2b8"}, + {file = "orjson-3.10.10-cp38-none-win32.whl", hash = "sha256:78bee66a988f1a333dc0b6257503d63553b1957889c17b2c4ed72385cd1b96ae"}, + {file = "orjson-3.10.10-cp38-none-win_amd64.whl", hash = "sha256:f1d647ca8d62afeb774340a343c7fc023efacfd3a39f70c798991063f0c681dd"}, + {file = "orjson-3.10.10-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:5a059afddbaa6dd733b5a2d76a90dbc8af790b993b1b5cb97a1176ca713b5df8"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f9b5c59f7e2a1a410f971c5ebc68f1995822837cd10905ee255f96074537ee6"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d5ef198bafdef4aa9d49a4165ba53ffdc0a9e1c7b6f76178572ab33118afea25"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aaf29ce0bb5d3320824ec3d1508652421000ba466abd63bdd52c64bcce9eb1fa"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dddd5516bcc93e723d029c1633ae79c4417477b4f57dad9bfeeb6bc0315e654a"}, + {file = "orjson-3.10.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a12f2003695b10817f0fa8b8fca982ed7f5761dcb0d93cff4f2f9f6709903fd7"}, + {file = "orjson-3.10.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:672f9874a8a8fb9bb1b771331d31ba27f57702c8106cdbadad8bda5d10bc1019"}, + {file = "orjson-3.10.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1dcbb0ca5fafb2b378b2c74419480ab2486326974826bbf6588f4dc62137570a"}, + {file = "orjson-3.10.10-cp39-none-win32.whl", hash = 
"sha256:d9bbd3a4b92256875cb058c3381b782649b9a3c68a4aa9a2fff020c2f9cfc1be"}, + {file = "orjson-3.10.10-cp39-none-win_amd64.whl", hash = "sha256:766f21487a53aee8524b97ca9582d5c6541b03ab6210fbaf10142ae2f3ced2aa"}, + {file = "orjson-3.10.10.tar.gz", hash = "sha256:37949383c4df7b4337ce82ee35b6d7471e55195efa7dcb45ab8226ceadb0fe3b"}, ] [[package]] @@ -5596,51 +6305,64 @@ files = [ [[package]] name = "packaging" -version = "23.2" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] name = "pandas" -version = "2.2.2" +version = "2.2.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" files = [ - {file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"}, - {file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"}, - {file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"}, - {file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"}, - {file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"}, - {file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"}, - {file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"}, - {file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"}, - {file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"}, - {file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"}, - {file = 
"pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"}, - {file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"}, - {file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"}, - {file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"}, - {file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"}, - {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"}, - {file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"}, - {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"}, - {file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"}, - {file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"}, - {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, + {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d9c45366def9a3dd85a6454c0e7908f2b3b8e9c138f5dc38fed7ce720d8453ed"}, + {file = "pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86976a1c5b25ae3f8ccae3a5306e443569ee3c3faf444dfd0f41cda24667ad57"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b8661b0238a69d7aafe156b7fa86c44b881387509653fdf857bebc5e4008ad42"}, + {file = "pandas-2.2.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37e0aced3e8f539eccf2e099f65cdb9c8aa85109b0be6e93e2baff94264bdc6f"}, + {file = "pandas-2.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:56534ce0746a58afaf7942ba4863e0ef81c9c50d3f0ae93e9497d6a41a057645"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039"}, + {file = "pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698"}, + {file = "pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3"}, + {file = "pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32"}, + {file = "pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9"}, + {file = "pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3"}, + {file = "pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8"}, + {file = "pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a"}, + {file = "pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f00d1345d84d8c86a63e476bb4955e46458b304b9575dcf71102b5c705320015"}, + {file = "pandas-2.2.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3508d914817e153ad359d7e069d752cdd736a247c322d932eb89e6bc84217f28"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22a9d949bfc9a502d320aa04e5d02feab689d61da4e7764b62c30b991c42c5f0"}, + {file = "pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3a255b2c19987fbbe62a9dfd6cff7ff2aa9ccab3fc75218fd4b7530f01efa24"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:800250ecdadb6d9c78eae4990da62743b857b470883fa27f652db8bdde7f6659"}, + {file = "pandas-2.2.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6374c452ff3ec675a8f46fd9ab25c4ad0ba590b71cf0656f8b6daa5202bca3fb"}, + {file = "pandas-2.2.3-cp313-cp313-win_amd64.whl", hash = "sha256:61c5ad4043f791b61dd4752191d9f07f0ae412515d59ba8f005832a532f8736d"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3b71f27954685ee685317063bf13c7709a7ba74fc996b84fc6821c59b0f06468"}, + {file = "pandas-2.2.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:38cf8125c40dae9d5acc10fa66af8ea6fdf760b2714ee482ca691fc66e6fcb18"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ba96630bc17c875161df3818780af30e43be9b166ce51c9a18c1feae342906c2"}, + {file = "pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db71525a1538b30142094edb9adc10be3f3e176748cd7acc2240c2f2e5aa3a4"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = 
"sha256:15c0e1e02e93116177d29ff83e8b1619c93ddc9c49083f237d4312337a61165d"}, + {file = "pandas-2.2.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ad5b65698ab28ed8d7f18790a0dc58005c7629f227be9ecc1072aa74c0c1d43a"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc6b93f9b966093cb0fd62ff1a7e4c09e6d546ad7c1de191767baffc57628f39"}, + {file = "pandas-2.2.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5dbca4c1acd72e8eeef4753eeca07de9b1db4f398669d5994086f788a5d7cc30"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8cd6d7cc958a3910f934ea8dbdf17b2364827bb4dafc38ce6eef6bb3d65ff09c"}, + {file = "pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99df71520d25fade9db7c1076ac94eb994f4d2673ef2aa2e86ee039b6746d20c"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31d0ced62d4ea3e231a9f228366919a5ea0b07440d9d4dac345376fd8e1477ea"}, + {file = "pandas-2.2.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7eee9e7cea6adf3e3d24e304ac6b8300646e2a5d1cd3a3c2abed9101b0846761"}, + {file = "pandas-2.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:4850ba03528b6dd51d6c5d273c46f183f39a9baf3f0143e566b89450965b105e"}, + {file = "pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667"}, ] [package.dependencies] @@ -5687,25 +6409,42 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pathos" +version = "0.3.3" +description = "parallel graph management and execution in heterogeneous computing" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pathos-0.3.3-py3-none-any.whl", hash = "sha256:e04616c6448608ad1f809360be22e3f2078d949a36a81e6991da6c2dd1f82513"}, + {file = "pathos-0.3.3.tar.gz", hash = "sha256:dcb2a5f321aa34ca541c1c1861011ea49df357bb908379c21dd5741f666e0a58"}, +] + +[package.dependencies] +dill = ">=0.3.9" +multiprocess = ">=0.70.17" +pox = ">=0.3.5" +ppft = ">=1.7.6.9" + [[package]] name = "peewee" -version = "3.17.6" +version = "3.17.7" description = "a little orm" optional = false python-versions = "*" files = [ - {file = "peewee-3.17.6.tar.gz", hash = "sha256:cea5592c6f4da1592b7cff8eaf655be6648a1f5857469e30037bf920c03fb8fb"}, + {file = "peewee-3.17.7.tar.gz", hash = "sha256:6aefc700bd530fc6ac23fa19c9c5b47041751d92985b799169c8e318e97eabaa"}, ] [[package]] name = "pgvecto-rs" -version = "0.2.1" +version = "0.2.2" description = "Python binding for pgvecto.rs" optional = false python-versions = "<3.13,>=3.8" files = [ - {file = "pgvecto_rs-0.2.1-py3-none-any.whl", hash = "sha256:b3ee2c465219469ad537b3efea2916477c6c576b3d6fd4298980d0733d12bb27"}, - {file = "pgvecto_rs-0.2.1.tar.gz", hash = "sha256:07046eaad2c4f75745f76de9ba483541909f1c595aced8d3434224a4f933daca"}, + {file = "pgvecto_rs-0.2.2-py3-none-any.whl", hash = "sha256:5f3f7f806813de408c45dc10a9eb418b986c4d7b7723e8fce9298f2f7d8fbbd5"}, + {file = "pgvecto_rs-0.2.2.tar.gz", hash = "sha256:edaa913d1747152b1407cbdf6337d51ac852547b54953ef38997433be3a75a3b"}, ] [package.dependencies] @@ -5734,95 +6473,90 @@ numpy = "*" [[package]] name = "pillow" -version = "10.4.0" +version = "11.0.0" description = "Python Imaging Library (Fork)" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "pillow-10.4.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = 
"sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e"}, - {file = "pillow-10.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b"}, - {file = "pillow-10.4.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc"}, - {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e"}, - {file = "pillow-10.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46"}, - {file = "pillow-10.4.0-cp310-cp310-win32.whl", hash = "sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984"}, - {file = "pillow-10.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141"}, - {file = "pillow-10.4.0-cp310-cp310-win_arm64.whl", hash = "sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1"}, - {file = "pillow-10.4.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c"}, - {file = "pillow-10.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe"}, - {file = "pillow-10.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319"}, - {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d"}, - {file = "pillow-10.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696"}, - {file = "pillow-10.4.0-cp311-cp311-win32.whl", hash = "sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496"}, - {file = "pillow-10.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91"}, - {file = "pillow-10.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22"}, - {file = "pillow-10.4.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94"}, - {file = "pillow-10.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597"}, - {file = 
"pillow-10.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef"}, - {file = "pillow-10.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a"}, - {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b"}, - {file = "pillow-10.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9"}, - {file = "pillow-10.4.0-cp312-cp312-win32.whl", hash = "sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42"}, - {file = "pillow-10.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a"}, - {file = "pillow-10.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9"}, - {file = "pillow-10.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3"}, - {file = "pillow-10.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0"}, - {file = "pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc"}, - {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a"}, - {file = "pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309"}, - {file = "pillow-10.4.0-cp313-cp313-win32.whl", hash = "sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060"}, - {file = "pillow-10.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea"}, - {file = "pillow-10.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d"}, - {file = "pillow-10.4.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736"}, - {file = "pillow-10.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b"}, - {file = "pillow-10.4.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd"}, - {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84"}, - {file = "pillow-10.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0"}, - {file = "pillow-10.4.0-cp38-cp38-win32.whl", hash = "sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e"}, - {file = "pillow-10.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab"}, - {file = "pillow-10.4.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d"}, - {file = "pillow-10.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b"}, - {file = "pillow-10.4.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c"}, - {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1"}, - {file = "pillow-10.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df"}, - {file = "pillow-10.4.0-cp39-cp39-win32.whl", hash = "sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef"}, - {file = "pillow-10.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5"}, - {file = "pillow-10.4.0-cp39-cp39-win_arm64.whl", hash = "sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5"}, - {file = "pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885"}, - {file = 
"pillow-10.4.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27"}, - {file = "pillow-10.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3"}, - {file = "pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06"}, -] - -[package.extras] -docs = ["furo", "olefile", "sphinx (>=7.3)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] + {file = "pillow-11.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947"}, + {file = "pillow-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97"}, + {file = "pillow-11.0.0-cp310-cp310-win32.whl", hash = "sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50"}, + {file = "pillow-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c"}, + {file = "pillow-11.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9"}, + {file = "pillow-11.0.0-cp311-cp311-win32.whl", hash = "sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5"}, + {file = "pillow-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291"}, + {file = "pillow-11.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc"}, + {file = "pillow-11.0.0-cp312-cp312-win32.whl", hash = "sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6"}, + {file = "pillow-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47"}, + {file = "pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2"}, + {file = 
"pillow-11.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb"}, + {file = "pillow-11.0.0-cp313-cp313-win32.whl", hash = "sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798"}, + {file = "pillow-11.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de"}, + {file = "pillow-11.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a"}, + {file = "pillow-11.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8"}, + {file = "pillow-11.0.0-cp313-cp313t-win32.whl", hash = "sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8"}, + {file = "pillow-11.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904"}, + {file = "pillow-11.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae"}, + {file = "pillow-11.0.0-cp39-cp39-win32.whl", hash = "sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4"}, + {file = "pillow-11.0.0-cp39-cp39-win_amd64.whl", hash = 
"sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd"}, + {file = "pillow-11.0.0-cp39-cp39-win_arm64.whl", hash = "sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944"}, + {file = "pillow-11.0.0.tar.gz", hash = "sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] fpx = ["olefile"] mic = ["olefile"] tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] @@ -5831,29 +6565,29 @@ xmp = ["defusedxml"] [[package]] name = "platformdirs" -version = "4.2.2" +version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
optional = false python-versions = ">=3.8" files = [ - {file = "platformdirs-4.2.2-py3-none-any.whl", hash = "sha256:2d7a1657e36a80ea911db832a8a6ece5ee53d8de21edd5cc5879af6530b1bfee"}, - {file = "platformdirs-4.2.2.tar.gz", hash = "sha256:38b7b51f512eed9e84a22788b4bce1de17c0adb134d6becb09836e37d8654cd3"}, + {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, + {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "proselint (>=0.13)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)"] -type = ["mypy (>=1.8)"] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.11.2)"] [[package]] name = "plotly" -version = "5.23.0" +version = "5.24.1" description = "An open-source, interactive data visualization library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "plotly-5.23.0-py3-none-any.whl", hash = "sha256:76cbe78f75eddc10c56f5a4ee3e7ccaade7c0a57465546f02098c0caed6c2d1a"}, - {file = "plotly-5.23.0.tar.gz", hash = "sha256:89e57d003a116303a34de6700862391367dd564222ab71f8531df70279fc0193"}, + {file = "plotly-5.24.1-py3-none-any.whl", hash = "sha256:f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089"}, + {file = "plotly-5.24.1.tar.gz", hash = "sha256:dbc8ac8339d248a4bcc36e08a5659bacfe1b079390b8953533f4eb22169b4bae"}, ] [package.dependencies] @@ -5905,15 +6639,32 @@ docs = ["sphinx (>=1.7.1)"] redis = ["redis"] tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"] +[[package]] +name = "postgrest" +version = "0.17.2" +description = "PostgREST client for Python. This library provides an ORM interface to PostgREST." +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "postgrest-0.17.2-py3-none-any.whl", hash = "sha256:f7c4f448e5a5e2d4c1dcf192edae9d1007c4261e9a6fb5116783a0046846ece2"}, + {file = "postgrest-0.17.2.tar.gz", hash = "sha256:445cd4e4a191e279492549df0c4e827d32f9d01d0852599bb8a6efb0f07fcf78"}, +] + +[package.dependencies] +deprecation = ">=2.1.0,<3.0.0" +httpx = {version = ">=0.26,<0.28", extras = ["http2"]} +pydantic = ">=1.9,<3.0" +strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""} + [[package]] name = "posthog" -version = "3.5.0" +version = "3.7.0" description = "Integrate PostHog into any python application." 
optional = false python-versions = "*" files = [ - {file = "posthog-3.5.0-py2.py3-none-any.whl", hash = "sha256:3c672be7ba6f95d555ea207d4486c171d06657eb34b3ce25eb043bfe7b6b5b76"}, - {file = "posthog-3.5.0.tar.gz", hash = "sha256:8f7e3b2c6e8714d0c0c542a2109b83a7549f63b7113a133ab2763a89245ef2ef"}, + {file = "posthog-3.7.0-py2.py3-none-any.whl", hash = "sha256:3555161c3a9557b5666f96d8e1f17f410ea0f07db56e399e336a1656d4e5c722"}, + {file = "posthog-3.7.0.tar.gz", hash = "sha256:b095d4354ba23f8b346ab5daed8ecfc5108772f922006982dfe8b2d29ebc6e0e"}, ] [package.dependencies] @@ -5926,37 +6677,62 @@ six = ">=1.5" [package.extras] dev = ["black", "flake8", "flake8-print", "isort", "pre-commit"] sentry = ["django", "sentry-sdk"] -test = ["coverage", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"] +test = ["coverage", "django", "flake8", "freezegun (==0.3.15)", "mock (>=2.0.0)", "pylint", "pytest", "pytest-timeout"] + +[[package]] +name = "pox" +version = "0.3.5" +description = "utilities for filesystem exploration and automated builds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pox-0.3.5-py3-none-any.whl", hash = "sha256:9e82bcc9e578b43e80a99cad80f0d8f44f4d424f0ee4ee8d4db27260a6aa365a"}, + {file = "pox-0.3.5.tar.gz", hash = "sha256:8120ee4c94e950e6e0483e050a4f0e56076e590ba0a9add19524c254bd23c2d1"}, +] + +[[package]] +name = "ppft" +version = "1.7.6.9" +description = "distributed and parallel Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ppft-1.7.6.9-py3-none-any.whl", hash = "sha256:dab36548db5ca3055067fbe6b1a17db5fee29f3c366c579a9a27cebb52ed96f0"}, + {file = "ppft-1.7.6.9.tar.gz", hash = "sha256:73161c67474ea9d81d04bcdad166d399cff3f084d5d2dc21ebdd46c075bbc265"}, +] + +[package.extras] +dill = ["dill (>=0.3.9)"] [[package]] name = "primp" -version = "0.5.5" +version = "0.6.5" description = "HTTP client that can impersonate web browsers, mimicking their headers and `TLS/JA3/JA4/HTTP2` fingerprints" optional = false python-versions = ">=3.8" files = [ - {file = "primp-0.5.5-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:cff9792e8422424528c23574b5364882d68134ee2743f4a2ae6a765746fb3028"}, - {file = "primp-0.5.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:78e13fc5d4d90d44a005dbd5dda116981828c803c86cf85816b3bb5363b045c8"}, - {file = "primp-0.5.5-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3714abfda79d3f5c90a5363db58994afbdbacc4b94fe14e9e5f8ab97e7b82577"}, - {file = "primp-0.5.5-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e54765900ee40eceb6bde43676d7e0b2e16ca1f77c0753981fe5e40afc0c2010"}, - {file = "primp-0.5.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:66c7eecc5a55225c42cfb99af857df04f994f3dd0d327c016d3af5414c1a2242"}, - {file = "primp-0.5.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:df262271cc1a41f4bf80d68396e967a27d7d3d3de355a3d016f953130e7a20be"}, - {file = "primp-0.5.5-cp38-abi3-win_amd64.whl", hash = "sha256:8b424118d6bab6f9d4980d0f35d5ccc1213ab9f1042497c6ee11730f2f94a876"}, - {file = "primp-0.5.5.tar.gz", hash = "sha256:8623e8a25fd686785296b12175f4173250a08db1de9ee4063282e262b94bf3f2"}, + {file = "primp-0.6.5-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b2bab0250d38c02a437c75ed94b99e3a8c03a281ba9a4c33780ccd04999c741b"}, + {file = "primp-0.6.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:0aedb33515d86df4c1f91b9d5772e1b74d1593dfe8978c258b136c171f8ab94c"}, + {file = 
"primp-0.6.5-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0e8850be30fbfefeb76c1eb5859a55c5f11c8c285a4a03ebf99c73fea964b2a"}, + {file = "primp-0.6.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9b71ac07a79cbb401390e2ee5a5767d0bf202a956a533fd084957020fcb2a64"}, + {file = "primp-0.6.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:79c65fcb07b36bd0f8c3966a4a18c4f6a6d624a33a0b0133b0f0cc8d0050c351"}, + {file = "primp-0.6.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5a55e450bb52a88f4a2891db50577c8f20b134d17d37e93361ee51de1a6fe8c8"}, + {file = "primp-0.6.5-cp38-abi3-win_amd64.whl", hash = "sha256:cbe584de5c177b9f0656b77e88721296ae6151b6c4565e2e0a342b6473990f27"}, + {file = "primp-0.6.5.tar.gz", hash = "sha256:abb46c579ae682f34c1f339faac38709c85ab76c056ec3711a26823334ab8124"}, ] [package.extras] -dev = ["pytest (>=8.1.1)"] +dev = ["certifi", "pytest (>=8.1.1)"] [[package]] name = "prompt-toolkit" -version = "3.0.47" +version = "3.0.48" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" files = [ - {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"}, - {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"}, + {file = "prompt_toolkit-3.0.48-py3-none-any.whl", hash = "sha256:f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e"}, + {file = "prompt_toolkit-3.0.48.tar.gz", hash = "sha256:d6623ab0477a80df74e646bdbc93621143f5caf104206aa29294d53de1a03d90"}, ] [package.dependencies] @@ -5964,13 +6740,13 @@ wcwidth = "*" [[package]] name = "proto-plus" -version = "1.24.0" +version = "1.25.0" description = "Beautiful, Pythonic protocol buffers." 
optional = false python-versions = ">=3.7" files = [ - {file = "proto-plus-1.24.0.tar.gz", hash = "sha256:30b72a5ecafe4406b0d339db35b56c4059064e69227b8c3bda7462397f966445"}, - {file = "proto_plus-1.24.0-py3-none-any.whl", hash = "sha256:402576830425e5f6ce4c2a6702400ac79897dab0b4343821aa5188b0fab81a12"}, + {file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"}, + {file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"}, ] [package.dependencies] @@ -5981,103 +6757,139 @@ testing = ["google-api-core (>=1.31.5)"] [[package]] name = "protobuf" -version = "4.25.4" +version = "4.25.5" description = "" optional = false python-versions = ">=3.8" files = [ - {file = "protobuf-4.25.4-cp310-abi3-win32.whl", hash = "sha256:db9fd45183e1a67722cafa5c1da3e85c6492a5383f127c86c4c4aa4845867dc4"}, - {file = "protobuf-4.25.4-cp310-abi3-win_amd64.whl", hash = "sha256:ba3d8504116a921af46499471c63a85260c1a5fc23333154a427a310e015d26d"}, - {file = "protobuf-4.25.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:eecd41bfc0e4b1bd3fa7909ed93dd14dd5567b98c941d6c1ad08fdcab3d6884b"}, - {file = "protobuf-4.25.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:4c8a70fdcb995dcf6c8966cfa3a29101916f7225e9afe3ced4395359955d3835"}, - {file = "protobuf-4.25.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:3319e073562e2515c6ddc643eb92ce20809f5d8f10fead3332f71c63be6a7040"}, - {file = "protobuf-4.25.4-cp38-cp38-win32.whl", hash = "sha256:7e372cbbda66a63ebca18f8ffaa6948455dfecc4e9c1029312f6c2edcd86c4e1"}, - {file = "protobuf-4.25.4-cp38-cp38-win_amd64.whl", hash = "sha256:051e97ce9fa6067a4546e75cb14f90cf0232dcb3e3d508c448b8d0e4265b61c1"}, - {file = "protobuf-4.25.4-cp39-cp39-win32.whl", hash = "sha256:90bf6fd378494eb698805bbbe7afe6c5d12c8e17fca817a646cd6a1818c696ca"}, - {file = "protobuf-4.25.4-cp39-cp39-win_amd64.whl", hash = "sha256:ac79a48d6b99dfed2729ccccee547b34a1d3d63289c71cef056653a846a2240f"}, - {file = "protobuf-4.25.4-py3-none-any.whl", hash = "sha256:bfbebc1c8e4793cfd58589acfb8a1026be0003e852b9da7db5a4285bde996978"}, - {file = "protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d"}, + {file = "protobuf-4.25.5-cp310-abi3-win32.whl", hash = "sha256:5e61fd921603f58d2f5acb2806a929b4675f8874ff5f330b7d6f7e2e784bbcd8"}, + {file = "protobuf-4.25.5-cp310-abi3-win_amd64.whl", hash = "sha256:4be0571adcbe712b282a330c6e89eae24281344429ae95c6d85e79e84780f5ea"}, + {file = "protobuf-4.25.5-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:b2fde3d805354df675ea4c7c6338c1aecd254dfc9925e88c6d31a2bcb97eb173"}, + {file = "protobuf-4.25.5-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:919ad92d9b0310070f8356c24b855c98df2b8bd207ebc1c0c6fcc9ab1e007f3d"}, + {file = "protobuf-4.25.5-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:fe14e16c22be926d3abfcb500e60cab068baf10b542b8c858fa27e098123e331"}, + {file = "protobuf-4.25.5-cp38-cp38-win32.whl", hash = "sha256:98d8d8aa50de6a2747efd9cceba361c9034050ecce3e09136f90de37ddba66e1"}, + {file = "protobuf-4.25.5-cp38-cp38-win_amd64.whl", hash = "sha256:b0234dd5a03049e4ddd94b93400b67803c823cfc405689688f59b34e0742381a"}, + {file = "protobuf-4.25.5-cp39-cp39-win32.whl", hash = "sha256:abe32aad8561aa7cc94fc7ba4fdef646e576983edb94a73381b03c53728a626f"}, + {file = "protobuf-4.25.5-cp39-cp39-win_amd64.whl", hash = "sha256:7a183f592dc80aa7c8da7ad9e55091c4ffc9497b3054452d629bb85fa27c2a45"}, + {file = 
"protobuf-4.25.5-py3-none-any.whl", hash = "sha256:0aebecb809cae990f8129ada5ca273d9d670b76d9bfc9b1809f0a9c02b7dbf41"}, + {file = "protobuf-4.25.5.tar.gz", hash = "sha256:7f8249476b4a9473645db7f8ab42b02fe1488cbe5fb72fddd445e0665afd8584"}, +] + +[[package]] +name = "psutil" +version = "6.1.0" +description = "Cross-platform lib for process and system monitoring in Python." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.1.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942"}, + {file = "psutil-6.1.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:000d1d1ebd634b4efb383f4034437384e44a6d455260aaee2eca1e9c1b55f047"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:5cd2bcdc75b452ba2e10f0e8ecc0b57b827dd5d7aaffbc6821b2a9a242823a76"}, + {file = "psutil-6.1.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:045f00a43c737f960d273a83973b2511430d61f283a44c96bf13a6e829ba8fdc"}, + {file = "psutil-6.1.0-cp27-none-win32.whl", hash = "sha256:9118f27452b70bb1d9ab3198c1f626c2499384935aaf55388211ad982611407e"}, + {file = "psutil-6.1.0-cp27-none-win_amd64.whl", hash = "sha256:a8506f6119cff7015678e2bce904a4da21025cc70ad283a53b099e7620061d85"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6e2dcd475ce8b80522e51d923d10c7871e45f20918e027ab682f94f1c6351688"}, + {file = "psutil-6.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0895b8414afafc526712c498bd9de2b063deaac4021a3b3c34566283464aff8e"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dcbfce5d89f1d1f2546a2090f4fcf87c7f669d1d90aacb7d7582addece9fb38"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:498c6979f9c6637ebc3a73b3f87f9eb1ec24e1ce53a7c5173b8508981614a90b"}, + {file = "psutil-6.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d905186d647b16755a800e7263d43df08b790d709d575105d419f8b6ef65423a"}, + {file = "psutil-6.1.0-cp36-cp36m-win32.whl", hash = "sha256:6d3fbbc8d23fcdcb500d2c9f94e07b1342df8ed71b948a2649b5cb060a7c94ca"}, + {file = "psutil-6.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:1209036fbd0421afde505a4879dee3b2fd7b1e14fee81c0069807adcbbcca747"}, + {file = "psutil-6.1.0-cp37-abi3-win32.whl", hash = "sha256:1ad45a1f5d0b608253b11508f80940985d1d0c8f6111b5cb637533a0e6ddc13e"}, + {file = "psutil-6.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:a8fb3752b491d246034fa4d279ff076501588ce8cbcdbb62c32fd7a377d996be"}, + {file = "psutil-6.1.0.tar.gz", hash = "sha256:353815f59a7f64cdaca1c0307ee13558a0512f6db064e92fe833784f08539c7a"}, ] +[package.extras] +dev = ["black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "wheel"] +test = ["pytest", "pytest-xdist", "setuptools"] + [[package]] name = "psycopg2-binary" -version = "2.9.9" +version = "2.9.10" description = "psycopg2 - Python-PostgreSQL Database Adapter" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +files = [ + {file = "psycopg2-binary-2.9.10.tar.gz", hash = 
"sha256:4b3df0e6990aa98acda57d983942eff13d824135fe2250e6522edaa782a06de2"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:0ea8e3d0ae83564f2fc554955d327fa081d065c8ca5cc6d2abb643e2c9c1200f"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:3e9c76f0ac6f92ecfc79516a8034a544926430f7b080ec5a0537bca389ee0906"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ad26b467a405c798aaa1458ba09d7e2b6e5f96b1ce0ac15d82fd9f95dc38a92"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:270934a475a0e4b6925b5f804e3809dd5f90f8613621d062848dd82f9cd62007"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:48b338f08d93e7be4ab2b5f1dbe69dc5e9ef07170fe1f86514422076d9c010d0"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f4152f8f76d2023aac16285576a9ecd2b11a9895373a1f10fd9db54b3ff06b4"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:32581b3020c72d7a421009ee1c6bf4a131ef5f0a968fab2e2de0c9d2bb4577f1"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:2ce3e21dc3437b1d960521eca599d57408a695a0d3c26797ea0f72e834c7ffe5"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e984839e75e0b60cfe75e351db53d6db750b00de45644c5d1f7ee5d1f34a1ce5"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3c4745a90b78e51d9ba06e2088a2fe0c693ae19cc8cb051ccda44e8df8a6eb53"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-win32.whl", hash = "sha256:e5720a5d25e3b99cd0dc5c8a440570469ff82659bb09431c1439b92caf184d3b"}, + {file = "psycopg2_binary-2.9.10-cp310-cp310-win_amd64.whl", hash = "sha256:3c18f74eb4386bf35e92ab2354a12c17e5eb4d9798e4c0ad3a00783eae7cd9f1"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:04392983d0bb89a8717772a193cfaac58871321e3ec69514e1c4e0d4957b5aff"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:1a6784f0ce3fec4edc64e985865c17778514325074adf5ad8f80636cd029ef7c"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5f86c56eeb91dc3135b3fd8a95dc7ae14c538a2f3ad77a19645cf55bab1799c"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b3d2491d4d78b6b14f76881905c7a8a8abcf974aad4a8a0b065273a0ed7a2cb"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2286791ececda3a723d1910441c793be44625d86d1a4e79942751197f4d30341"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:512d29bb12608891e349af6a0cccedce51677725a921c07dba6342beaf576f9a"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:5a507320c58903967ef7384355a4da7ff3f28132d679aeb23572753cbf2ec10b"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:6d4fa1079cab9018f4d0bd2db307beaa612b0d13ba73b5c6304b9fe2fb441ff7"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:851485a42dbb0bdc1edcdabdb8557c09c9655dfa2ca0460ff210522e073e319e"}, + {file = 
"psycopg2_binary-2.9.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:35958ec9e46432d9076286dda67942ed6d968b9c3a6a2fd62b48939d1d78bf68"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-win32.whl", hash = "sha256:ecced182e935529727401b24d76634a357c71c9275b356efafd8a2a91ec07392"}, + {file = "psycopg2_binary-2.9.10-cp311-cp311-win_amd64.whl", hash = "sha256:ee0e8c683a7ff25d23b55b11161c2663d4b099770f6085ff0a20d4505778d6b4"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:880845dfe1f85d9d5f7c412efea7a08946a46894537e4e5d091732eb1d34d9a0"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:9440fa522a79356aaa482aa4ba500b65f28e5d0e63b801abf6aa152a29bd842a"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3923c1d9870c49a2d44f795df0c889a22380d36ef92440ff618ec315757e539"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b2c956c028ea5de47ff3a8d6b3cc3330ab45cf0b7c3da35a2d6ff8420896526"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f758ed67cab30b9a8d2833609513ce4d3bd027641673d4ebc9c067e4d208eec1"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cd9b4f2cfab88ed4a9106192de509464b75a906462fb846b936eabe45c2063e"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dc08420625b5a20b53551c50deae6e231e6371194fa0651dbe0fb206452ae1f"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:d7cd730dfa7c36dbe8724426bf5612798734bff2d3c3857f36f2733f5bfc7c00"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:155e69561d54d02b3c3209545fb08938e27889ff5a10c19de8d23eb5a41be8a5"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3cc28a6fd5a4a26224007712e79b81dbaee2ffb90ff406256158ec4d7b52b47"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-win32.whl", hash = "sha256:ec8a77f521a17506a24a5f626cb2aee7850f9b69a0afe704586f63a464f3cd64"}, + {file = "psycopg2_binary-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:26540d4a9a4e2b096f1ff9cce51253d0504dca5a85872c7f7be23be5a53eb18d"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e217ce4d37667df0bc1c397fdcd8de5e81018ef305aed9415c3b093faaeb10fb"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:245159e7ab20a71d989da00f280ca57da7641fa2cdcf71749c193cea540a74f7"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c4ded1a24b20021ebe677b7b08ad10bf09aac197d6943bfe6fec70ac4e4690d"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3abb691ff9e57d4a93355f60d4f4c1dd2d68326c968e7db17ea96df3c023ef73"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8608c078134f0b3cbd9f89b34bd60a943b23fd33cc5f065e8d5f840061bd0673"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:230eeae2d71594103cd5b93fd29d1ace6420d0b86f4778739cb1a5a32f607d1f"}, + {file = 
"psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:056470c3dc57904bbf63d6f534988bafc4e970ffd50f6271fc4ee7daad9498a5"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:73aa0e31fa4bb82578f3a6c74a73c273367727de397a7a0f07bd83cbea696baa"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8de718c0e1c4b982a54b41779667242bc630b2197948405b7bd8ce16bcecac92"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:5c370b1e4975df846b0277b4deba86419ca77dbc25047f535b0bb03d1a544d44"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:ffe8ed017e4ed70f68b7b371d84b7d4a790368db9203dfc2d222febd3a9c8863"}, + {file = "psycopg2_binary-2.9.10-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:8aecc5e80c63f7459a1a2ab2c64df952051df196294d9f739933a9f6687e86b3"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:7a813c8bdbaaaab1f078014b9b0b13f5de757e2b5d9be6403639b298a04d218b"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d00924255d7fc916ef66e4bf22f354a940c67179ad3fd7067d7a0a9c84d2fbfc"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7559bce4b505762d737172556a4e6ea8a9998ecac1e39b5233465093e8cee697"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8b58f0a96e7a1e341fc894f62c1177a7c83febebb5ff9123b579418fdc8a481"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b269105e59ac96aba877c1707c600ae55711d9dcd3fc4b5012e4af68e30c648"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:79625966e176dc97ddabc142351e0409e28acf4660b88d1cf6adb876d20c490d"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:8aabf1c1a04584c168984ac678a668094d831f152859d06e055288fa515e4d30"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:19721ac03892001ee8fdd11507e6a2e01f4e37014def96379411ca99d78aeb2c"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:7f5d859928e635fa3ce3477704acee0f667b3a3d3e4bb109f2b18d4005f38287"}, + {file = "psycopg2_binary-2.9.10-cp39-cp39-win32.whl", hash = "sha256:3216ccf953b3f267691c90c6fe742e45d890d8272326b4a8b20850a03d05b7b8"}, + {file = 
"psycopg2_binary-2.9.10-cp39-cp39-win_amd64.whl", hash = "sha256:30e34c4e97964805f715206c7b789d54a78b70f3ff19fbe590104b71c45600e5"}, +] + +[[package]] +name = "py" +version = "1.11.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" files = [ - {file = "psycopg2-binary-2.9.9.tar.gz", hash = "sha256:7f01846810177d829c7692f1f5ada8096762d9172af1b1a28d4ab5b77c923c1c"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c2470da5418b76232f02a2fcd2229537bb2d5a7096674ce61859c3229f2eb202"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c6af2a6d4b7ee9615cbb162b0738f6e1fd1f5c3eda7e5da17861eacf4c717ea7"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75723c3c0fbbf34350b46a3199eb50638ab22a0228f93fb472ef4d9becc2382b"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83791a65b51ad6ee6cf0845634859d69a038ea9b03d7b26e703f94c7e93dbcf9"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0ef4854e82c09e84cc63084a9e4ccd6d9b154f1dbdd283efb92ecd0b5e2b8c84"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ed1184ab8f113e8d660ce49a56390ca181f2981066acc27cf637d5c1e10ce46e"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d2997c458c690ec2bc6b0b7ecbafd02b029b7b4283078d3b32a852a7ce3ddd98"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:b58b4710c7f4161b5e9dcbe73bb7c62d65670a87df7bcce9e1faaad43e715245"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:0c009475ee389757e6e34611d75f6e4f05f0cf5ebb76c6037508318e1a1e0d7e"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8dbf6d1bc73f1d04ec1734bae3b4fb0ee3cb2a493d35ede9badbeb901fb40f6f"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-win32.whl", hash = "sha256:3f78fd71c4f43a13d342be74ebbc0666fe1f555b8837eb113cb7416856c79682"}, - {file = "psycopg2_binary-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:876801744b0dee379e4e3c38b76fc89f88834bb15bf92ee07d94acd06ec890a0"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ee825e70b1a209475622f7f7b776785bd68f34af6e7a46e2e42f27b659b5bc26"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1ea665f8ce695bcc37a90ee52de7a7980be5161375d42a0b6c6abedbf0d81f0f"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:143072318f793f53819048fdfe30c321890af0c3ec7cb1dfc9cc87aa88241de2"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c332c8d69fb64979ebf76613c66b985414927a40f8defa16cf1bc028b7b0a7b0"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7fc5a5acafb7d6ccca13bfa8c90f8c51f13d8fb87d95656d3950f0158d3ce53"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:977646e05232579d2e7b9c59e21dbe5261f403a88417f6a6512e70d3f8a046be"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = 
"sha256:b6356793b84728d9d50ead16ab43c187673831e9d4019013f1402c41b1db9b27"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bc7bb56d04601d443f24094e9e31ae6deec9ccb23581f75343feebaf30423359"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:77853062a2c45be16fd6b8d6de2a99278ee1d985a7bd8b103e97e41c034006d2"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:78151aa3ec21dccd5cdef6c74c3e73386dcdfaf19bced944169697d7ac7482fc"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, - {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e6f98446430fdf41bd36d4faa6cb409f5140c1c2cf58ce0bbdaf16af7d3f119"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c77e3d1862452565875eb31bdb45ac62502feabbd53429fdc39a1cc341d681ba"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, - {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8359bf4791968c5a78c56103702000105501adb557f3cf772b2c207284273984"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:275ff571376626195ab95a746e6a04c7df8ea34638b99fc11160de91f2fef503"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f9b5571d33660d5009a8b3c25dc1db560206e2d2f89d3df1cb32d72c0d117d52"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:420f9bbf47a02616e8554e825208cb947969451978dceb77f95ad09c37791dae"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:4154ad09dac630a0f13f37b583eae260c6aa885d67dfbccb5b02c33f31a6d420"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:a148c5d507bb9b4f2030a2025c545fccb0e1ef317393eaba42e7eabd28eb6041"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:68fc1f1ba168724771e38bee37d940d2865cb0f562380a1fb1ffb428b75cb692"}, - {file = "psycopg2_binary-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:281309265596e388ef483250db3640e5f414168c5a67e9c665cafce9492eda2f"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:60989127da422b74a04345096c10d416c2b41bd7bf2a380eb541059e4e999980"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:246b123cc54bb5361588acc54218c8c9fb73068bf227a4a531d8ed56fa3ca7d6"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34eccd14566f8fe14b2b95bb13b11572f7c7d5c36da61caf414d23b91fcc5d94"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18d0ef97766055fec15b5de2c06dd8e7654705ce3e5e5eed3b6651a1d2a9a152"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d3f82c171b4ccd83bbaf35aa05e44e690113bd4f3b7b6cc54d2219b132f3ae55"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ead20f7913a9c1e894aebe47cccf9dc834e1618b7aa96155d2091a626e59c972"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ca49a8119c6cbd77375ae303b0cfd8c11f011abbbd64601167ecca18a87e7cdd"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:323ba25b92454adb36fa425dc5cf6f8f19f78948cbad2e7bc6cdf7b0d7982e59"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:1236ed0952fbd919c100bc839eaa4a39ebc397ed1c08a97fc45fee2a595aa1b3"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:729177eaf0aefca0994ce4cffe96ad3c75e377c7b6f4efa59ebf003b6d398716"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-win32.whl", hash = "sha256:804d99b24ad523a1fe18cc707bf741670332f7c7412e9d49cb5eab67e886b9b5"}, - {file = "psycopg2_binary-2.9.9-cp38-cp38-win_amd64.whl", hash = "sha256:a6cdcc3ede532f4a4b96000b6362099591ab4a3e913d70bcbac2b56c872446f7"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:72dffbd8b4194858d0941062a9766f8297e8868e1dd07a7b36212aaa90f49472"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:30dcc86377618a4c8f3b72418df92e77be4254d8f89f14b8e8f57d6d43603c0f"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31a34c508c003a4347d389a9e6fcc2307cc2150eb516462a7a17512130de109e"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:15208be1c50b99203fe88d15695f22a5bed95ab3f84354c494bcb1d08557df67"}, - {file = 
"psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1873aade94b74715be2246321c8650cabf5a0d098a95bab81145ffffa4c13876"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a58c98a7e9c021f357348867f537017057c2ed7f77337fd914d0bedb35dace7"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4686818798f9194d03c9129a4d9a702d9e113a89cb03bffe08c6cf799e053291"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ebdc36bea43063116f0486869652cb2ed7032dbc59fbcb4445c4862b5c1ecf7f"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:ca08decd2697fdea0aea364b370b1249d47336aec935f87b8bbfd7da5b2ee9c1"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ac05fb791acf5e1a3e39402641827780fe44d27e72567a000412c648a85ba860"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-win32.whl", hash = "sha256:9dba73be7305b399924709b91682299794887cbbd88e38226ed9f6712eabee90"}, - {file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"}, + {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, + {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, ] [[package]] @@ -6093,75 +6905,78 @@ files = [ [[package]] name = "pyarrow" -version = "17.0.0" +version = "18.0.0" description = "Python library for Apache Arrow" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, - {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da1e060b3876faa11cee287839f9cc7cdc00649f475714b8680a05fd9071d545"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75c06d4624c0ad6674364bb46ef38c3132768139ddec1c56582dbac54f2663e2"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:fa3c246cc58cb5a4a5cb407a18f193354ea47dd0648194e6265bd24177982fe8"}, - {file = "pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f7ae2de664e0b158d1607699a16a488de3d008ba99b3a7aa5de1cbc13574d047"}, - {file = "pyarrow-17.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5984f416552eea15fd9cee03da53542bf4cddaef5afecefb9aa8d1010c335087"}, - {file = "pyarrow-17.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:1c8856e2ef09eb87ecf937104aacfa0708f22dfeb039c363ec99735190ffb977"}, - {file = "pyarrow-17.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e19f569567efcbbd42084e87f948778eb371d308e137a0f97afe19bb860ccb3"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b244dc8e08a23b3e352899a006a26ae7b4d0da7bb636872fa8f5884e70acf15"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b72e87fe3e1db343995562f7fff8aee354b55ee83d13afba65400c178ab2597"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = 
"sha256:dc5c31c37409dfbc5d014047817cb4ccd8c1ea25d19576acf1a001fe07f5b420"}, - {file = "pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e3343cb1e88bc2ea605986d4b94948716edc7a8d14afd4e2c097232f729758b4"}, - {file = "pyarrow-17.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a27532c38f3de9eb3e90ecab63dfda948a8ca859a66e3a47f5f42d1e403c4d03"}, - {file = "pyarrow-17.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9b8a823cea605221e61f34859dcc03207e52e409ccf6354634143e23af7c8d22"}, - {file = "pyarrow-17.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f1e70de6cb5790a50b01d2b686d54aaf73da01266850b05e3af2a1bc89e16053"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0071ce35788c6f9077ff9ecba4858108eebe2ea5a3f7cf2cf55ebc1dbc6ee24a"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:757074882f844411fcca735e39aae74248a1531367a7c80799b4266390ae51cc"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ba11c4f16976e89146781a83833df7f82077cdab7dc6232c897789343f7891a"}, - {file = "pyarrow-17.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b0c6ac301093b42d34410b187bba560b17c0330f64907bfa4f7f7f2444b0cf9b"}, - {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"}, - {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"}, - {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"}, - {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"}, - {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"}, - {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"}, - {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"}, - {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"}, - {file = 
"pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"}, + {file = "pyarrow-18.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2333f93260674e185cfbf208d2da3007132572e56871f451ba1a556b45dae6e2"}, + {file = "pyarrow-18.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:4c381857754da44326f3a49b8b199f7f87a51c2faacd5114352fc78de30d3aba"}, + {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:603cd8ad4976568954598ef0a6d4ed3dfb78aff3d57fa8d6271f470f0ce7d34f"}, + {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58a62549a3e0bc9e03df32f350e10e1efb94ec6cf63e3920c3385b26663948ce"}, + {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bc97316840a349485fbb137eb8d0f4d7057e1b2c1272b1a20eebbbe1848f5122"}, + {file = "pyarrow-18.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:2e549a748fa8b8715e734919923f69318c953e077e9c02140ada13e59d043310"}, + {file = "pyarrow-18.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:606e9a3dcb0f52307c5040698ea962685fb1c852d72379ee9412be7de9c5f9e2"}, + {file = "pyarrow-18.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d5795e37c0a33baa618c5e054cd61f586cf76850a251e2b21355e4085def6280"}, + {file = "pyarrow-18.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:5f0510608ccd6e7f02ca8596962afb8c6cc84c453e7be0da4d85f5f4f7b0328a"}, + {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:616ea2826c03c16e87f517c46296621a7c51e30400f6d0a61be645f203aa2b93"}, + {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1824f5b029ddd289919f354bc285992cb4e32da518758c136271cf66046ef22"}, + {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd1b52d0d58dd8f685ced9971eb49f697d753aa7912f0a8f50833c7a7426319"}, + {file = "pyarrow-18.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:320ae9bd45ad7ecc12ec858b3e8e462578de060832b98fc4d671dee9f10d9954"}, + {file = "pyarrow-18.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2c992716cffb1088414f2b478f7af0175fd0a76fea80841b1706baa8fb0ebaad"}, + {file = "pyarrow-18.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:e7ab04f272f98ebffd2a0661e4e126036f6936391ba2889ed2d44c5006237802"}, + {file = "pyarrow-18.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:03f40b65a43be159d2f97fd64dc998f769d0995a50c00f07aab58b0b3da87e1f"}, + {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be08af84808dff63a76860847c48ec0416928a7b3a17c2f49a072cac7c45efbd"}, + {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c70c1965cde991b711a98448ccda3486f2a336457cf4ec4dca257a926e149c9"}, + {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:00178509f379415a3fcf855af020e3340254f990a8534294ec3cf674d6e255fd"}, + {file = "pyarrow-18.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:a71ab0589a63a3e987beb2bc172e05f000a5c5be2636b4b263c44034e215b5d7"}, + {file = "pyarrow-18.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe92efcdbfa0bcf2fa602e466d7f2905500f33f09eb90bf0bcf2e6ca41b574c8"}, + {file = "pyarrow-18.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:907ee0aa8ca576f5e0cdc20b5aeb2ad4d3953a3b4769fc4b499e00ef0266f02f"}, + {file = "pyarrow-18.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = 
"sha256:66dcc216ebae2eb4c37b223feaf82f15b69d502821dde2da138ec5a3716e7463"}, + {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc1daf7c425f58527900876354390ee41b0ae962a73ad0959b9d829def583bb1"}, + {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:871b292d4b696b09120ed5bde894f79ee2a5f109cb84470546471df264cae136"}, + {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:082ba62bdcb939824ba1ce10b8acef5ab621da1f4c4805e07bfd153617ac19d4"}, + {file = "pyarrow-18.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2c664ab88b9766413197733c1720d3dcd4190e8fa3bbdc3710384630a0a7207b"}, + {file = "pyarrow-18.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:dc892be34dbd058e8d189b47db1e33a227d965ea8805a235c8a7286f7fd17d3a"}, + {file = "pyarrow-18.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:28f9c39a56d2c78bf6b87dcc699d520ab850919d4a8c7418cd20eda49874a2ea"}, + {file = "pyarrow-18.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:f1a198a50c409ab2d009fbf20956ace84567d67f2c5701511d4dd561fae6f32e"}, + {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5bd7fd32e3ace012d43925ea4fc8bd1b02cc6cc1e9813b518302950e89b5a22"}, + {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:336addb8b6f5208be1b2398442c703a710b6b937b1a046065ee4db65e782ff5a"}, + {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:45476490dd4adec5472c92b4d253e245258745d0ccaabe706f8d03288ed60a79"}, + {file = "pyarrow-18.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b46591222c864e7da7faa3b19455196416cd8355ff6c2cc2e65726a760a3c420"}, + {file = "pyarrow-18.0.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb7e3abcda7e1e6b83c2dc2909c8d045881017270a119cc6ee7fdcfe71d02df8"}, + {file = "pyarrow-18.0.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:09f30690b99ce34e0da64d20dab372ee54431745e4efb78ac938234a282d15f9"}, + {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d5ca5d707e158540312e09fd907f9f49bacbe779ab5236d9699ced14d2293b8"}, + {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6331f280c6e4521c69b201a42dd978f60f7e129511a55da9e0bfe426b4ebb8d"}, + {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3ac24b2be732e78a5a3ac0b3aa870d73766dd00beba6e015ea2ea7394f8b4e55"}, + {file = "pyarrow-18.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b30a927c6dff89ee702686596f27c25160dd6c99be5bcc1513a763ae5b1bfc03"}, + {file = "pyarrow-18.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:8f40ec677e942374e3d7f2fad6a67a4c2811a8b975e8703c6fd26d3b168a90e2"}, + {file = "pyarrow-18.0.0.tar.gz", hash = "sha256:a6aa027b1a9d2970cf328ccd6dbe4a996bc13c39fd427f502782f5bdb9ca20f5"}, ] -[package.dependencies] -numpy = ">=1.16.6" - [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] [[package]] name = "pyasn1" -version = "0.6.0" +version = "0.6.1" description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" optional = false python-versions = ">=3.8" files = [ - {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"}, - {file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"}, + {file = 
"pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, ] [[package]] name = "pyasn1-modules" -version = "0.4.0" +version = "0.4.1" description = "A collection of ASN.1-based protocols modules" optional = false python-versions = ">=3.8" files = [ - {file = "pyasn1_modules-0.4.0-py3-none-any.whl", hash = "sha256:be04f15b66c206eed667e0bb5ab27e2b1855ea54a842e5037738099e8ca4ae0b"}, - {file = "pyasn1_modules-0.4.0.tar.gz", hash = "sha256:831dbcea1b177b28c9baddf4c6d1013c24c3accd14a1873fffaa6a2e905f17b6"}, + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, + {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, ] [package.dependencies] @@ -6221,119 +7036,120 @@ files = [ [[package]] name = "pydantic" -version = "2.8.2" +version = "2.9.2" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.8.2-py3-none-any.whl", hash = "sha256:73ee9fddd406dc318b885c7a2eab8a6472b68b8fb5ba8150949fc3db939f23c8"}, - {file = "pydantic-2.8.2.tar.gz", hash = "sha256:6f62c13d067b0755ad1c21a34bdd06c0c12625a22b0fc09c6b149816604f7c2a"}, + {file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"}, + {file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"}, ] [package.dependencies] -annotated-types = ">=0.4.0" -pydantic-core = "2.20.1" +annotated-types = ">=0.6.0" +pydantic-core = "2.23.4" typing-extensions = {version = ">=4.6.1", markers = "python_version < \"3.13\""} [package.extras] email = ["email-validator (>=2.0.0)"] +timezone = ["tzdata"] [[package]] name = "pydantic-core" -version = "2.20.1" +version = "2.23.4" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_core-2.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3acae97ffd19bf091c72df4d726d552c473f3576409b2a7ca36b2f535ffff4a3"}, - {file = "pydantic_core-2.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41f4c96227a67a013e7de5ff8f20fb496ce573893b7f4f2707d065907bffdbd6"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f239eb799a2081495ea659d8d4a43a8f42cd1fe9ff2e7e436295c38a10c286a"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53e431da3fc53360db73eedf6f7124d1076e1b4ee4276b36fb25514544ceb4a3"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1f62b2413c3a0e846c3b838b2ecd6c7a19ec6793b2a522745b0869e37ab5bc1"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d41e6daee2813ecceea8eda38062d69e280b39df793f5a942fa515b8ed67953"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d482efec8b7dc6bfaedc0f166b2ce349df0011f5d2f1f25537ced4cfc34fd98"}, - {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e93e1a4b4b33daed65d781a57a522ff153dcf748dee70b40c7258c5861e1768a"}, - {file = 
"pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7c4ea22b6739b162c9ecaaa41d718dfad48a244909fe7ef4b54c0b530effc5a"}, - {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4f2790949cf385d985a31984907fecb3896999329103df4e4983a4a41e13e840"}, - {file = "pydantic_core-2.20.1-cp310-none-win32.whl", hash = "sha256:5e999ba8dd90e93d57410c5e67ebb67ffcaadcea0ad973240fdfd3a135506250"}, - {file = "pydantic_core-2.20.1-cp310-none-win_amd64.whl", hash = "sha256:512ecfbefef6dac7bc5eaaf46177b2de58cdf7acac8793fe033b24ece0b9566c"}, - {file = "pydantic_core-2.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d2a8fa9d6d6f891f3deec72f5cc668e6f66b188ab14bb1ab52422fe8e644f312"}, - {file = "pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:175873691124f3d0da55aeea1d90660a6ea7a3cfea137c38afa0a5ffabe37b88"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37eee5b638f0e0dcd18d21f59b679686bbd18917b87db0193ae36f9c23c355fc"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25e9185e2d06c16ee438ed39bf62935ec436474a6ac4f9358524220f1b236e43"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:150906b40ff188a3260cbee25380e7494ee85048584998c1e66df0c7a11c17a6"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ad4aeb3e9a97286573c03df758fc7627aecdd02f1da04516a86dc159bf70121"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3f3ed29cd9f978c604708511a1f9c2fdcb6c38b9aae36a51905b8811ee5cbf1"}, - {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b0dae11d8f5ded51699c74d9548dcc5938e0804cc8298ec0aa0da95c21fff57b"}, - {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faa6b09ee09433b87992fb5a2859efd1c264ddc37280d2dd5db502126d0e7f27"}, - {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9dc1b507c12eb0481d071f3c1808f0529ad41dc415d0ca11f7ebfc666e66a18b"}, - {file = "pydantic_core-2.20.1-cp311-none-win32.whl", hash = "sha256:fa2fddcb7107e0d1808086ca306dcade7df60a13a6c347a7acf1ec139aa6789a"}, - {file = "pydantic_core-2.20.1-cp311-none-win_amd64.whl", hash = "sha256:40a783fb7ee353c50bd3853e626f15677ea527ae556429453685ae32280c19c2"}, - {file = "pydantic_core-2.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:595ba5be69b35777474fa07f80fc260ea71255656191adb22a8c53aba4479231"}, - {file = "pydantic_core-2.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a4f55095ad087474999ee28d3398bae183a66be4823f753cd7d67dd0153427c9"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9aa05d09ecf4c75157197f27cdc9cfaeb7c5f15021c6373932bf3e124af029f"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e97fdf088d4b31ff4ba35db26d9cc472ac7ef4a2ff2badeabf8d727b3377fc52"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc633a9fe1eb87e250b5c57d389cf28998e4292336926b0b6cdaee353f89a237"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d573faf8eb7e6b1cbbcb4f5b247c60ca8be39fe2c674495df0eb4318303137fe"}, - {file = 
"pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26dc97754b57d2fd00ac2b24dfa341abffc380b823211994c4efac7f13b9e90e"}, - {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:33499e85e739a4b60c9dac710c20a08dc73cb3240c9a0e22325e671b27b70d24"}, - {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bebb4d6715c814597f85297c332297c6ce81e29436125ca59d1159b07f423eb1"}, - {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:516d9227919612425c8ef1c9b869bbbee249bc91912c8aaffb66116c0b447ebd"}, - {file = "pydantic_core-2.20.1-cp312-none-win32.whl", hash = "sha256:469f29f9093c9d834432034d33f5fe45699e664f12a13bf38c04967ce233d688"}, - {file = "pydantic_core-2.20.1-cp312-none-win_amd64.whl", hash = "sha256:035ede2e16da7281041f0e626459bcae33ed998cca6a0a007a5ebb73414ac72d"}, - {file = "pydantic_core-2.20.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:0827505a5c87e8aa285dc31e9ec7f4a17c81a813d45f70b1d9164e03a813a686"}, - {file = "pydantic_core-2.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:19c0fa39fa154e7e0b7f82f88ef85faa2a4c23cc65aae2f5aea625e3c13c735a"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa223cd1e36b642092c326d694d8bf59b71ddddc94cdb752bbbb1c5c91d833b"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c336a6d235522a62fef872c6295a42ecb0c4e1d0f1a3e500fe949415761b8a19"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7eb6a0587eded33aeefea9f916899d42b1799b7b14b8f8ff2753c0ac1741edac"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70c8daf4faca8da5a6d655f9af86faf6ec2e1768f4b8b9d0226c02f3d6209703"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9fa4c9bf273ca41f940bceb86922a7667cd5bf90e95dbb157cbb8441008482c"}, - {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:11b71d67b4725e7e2a9f6e9c0ac1239bbc0c48cce3dc59f98635efc57d6dac83"}, - {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:270755f15174fb983890c49881e93f8f1b80f0b5e3a3cc1394a255706cabd203"}, - {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c81131869240e3e568916ef4c307f8b99583efaa60a8112ef27a366eefba8ef0"}, - {file = "pydantic_core-2.20.1-cp313-none-win32.whl", hash = "sha256:b91ced227c41aa29c672814f50dbb05ec93536abf8f43cd14ec9521ea09afe4e"}, - {file = "pydantic_core-2.20.1-cp313-none-win_amd64.whl", hash = "sha256:65db0f2eefcaad1a3950f498aabb4875c8890438bc80b19362cf633b87a8ab20"}, - {file = "pydantic_core-2.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4745f4ac52cc6686390c40eaa01d48b18997cb130833154801a442323cc78f91"}, - {file = "pydantic_core-2.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8ad4c766d3f33ba8fd692f9aa297c9058970530a32c728a2c4bfd2616d3358b"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41e81317dd6a0127cabce83c0c9c3fbecceae981c8391e6f1dec88a77c8a569a"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04024d270cf63f586ad41fff13fde4311c4fc13ea74676962c876d9577bcc78f"}, - {file = 
"pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eaad4ff2de1c3823fddf82f41121bdf453d922e9a238642b1dedb33c4e4f98ad"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26ab812fa0c845df815e506be30337e2df27e88399b985d0bb4e3ecfe72df31c"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ebac750d9d5f2706654c638c041635c385596caf68f81342011ddfa1e5598"}, - {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2aafc5a503855ea5885559eae883978c9b6d8c8993d67766ee73d82e841300dd"}, - {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4868f6bd7c9d98904b748a2653031fc9c2f85b6237009d475b1008bfaeb0a5aa"}, - {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa2f457b4af386254372dfa78a2eda2563680d982422641a85f271c859df1987"}, - {file = "pydantic_core-2.20.1-cp38-none-win32.whl", hash = "sha256:225b67a1f6d602de0ce7f6c1c3ae89a4aa25d3de9be857999e9124f15dab486a"}, - {file = "pydantic_core-2.20.1-cp38-none-win_amd64.whl", hash = "sha256:6b507132dcfc0dea440cce23ee2182c0ce7aba7054576efc65634f080dbe9434"}, - {file = "pydantic_core-2.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b03f7941783b4c4a26051846dea594628b38f6940a2fdc0df00b221aed39314c"}, - {file = "pydantic_core-2.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1eedfeb6089ed3fad42e81a67755846ad4dcc14d73698c120a82e4ccf0f1f9f6"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:635fee4e041ab9c479e31edda27fcf966ea9614fff1317e280d99eb3e5ab6fe2"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:77bf3ac639c1ff567ae3b47f8d4cc3dc20f9966a2a6dd2311dcc055d3d04fb8a"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ed1b0132f24beeec5a78b67d9388656d03e6a7c837394f99257e2d55b461611"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6514f963b023aeee506678a1cf821fe31159b925c4b76fe2afa94cc70b3222b"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d4204d8ca33146e761c79f83cc861df20e7ae9f6487ca290a97702daf56006"}, - {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d036c7187b9422ae5b262badb87a20a49eb6c5238b2004e96d4da1231badef1"}, - {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9ebfef07dbe1d93efb94b4700f2d278494e9162565a54f124c404a5656d7ff09"}, - {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6b9d9bb600328a1ce523ab4f454859e9d439150abb0906c5a1983c146580ebab"}, - {file = "pydantic_core-2.20.1-cp39-none-win32.whl", hash = "sha256:784c1214cb6dd1e3b15dd8b91b9a53852aed16671cc3fbe4786f4f1db07089e2"}, - {file = "pydantic_core-2.20.1-cp39-none-win_amd64.whl", hash = "sha256:d2fe69c5434391727efa54b47a1e7986bb0186e72a41b203df8f5b0a19a4f669"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a45f84b09ac9c3d35dfcf6a27fd0634d30d183205230a0ebe8373a0e8cfa0906"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d02a72df14dfdbaf228424573a07af10637bd490f0901cee872c4f434a735b94"}, - {file = 
"pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2b27e6af28f07e2f195552b37d7d66b150adbaa39a6d327766ffd695799780f"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084659fac3c83fd674596612aeff6041a18402f1e1bc19ca39e417d554468482"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:242b8feb3c493ab78be289c034a1f659e8826e2233786e36f2893a950a719bb6"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:38cf1c40a921d05c5edc61a785c0ddb4bed67827069f535d794ce6bcded919fc"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e0bbdd76ce9aa5d4209d65f2b27fc6e5ef1312ae6c5333c26db3f5ade53a1e99"}, - {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:254ec27fdb5b1ee60684f91683be95e5133c994cc54e86a0b0963afa25c8f8a6"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:407653af5617f0757261ae249d3fba09504d7a71ab36ac057c938572d1bc9331"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:c693e916709c2465b02ca0ad7b387c4f8423d1db7b4649c551f27a529181c5ad"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b5ff4911aea936a47d9376fd3ab17e970cc543d1b68921886e7f64bd28308d1"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177f55a886d74f1808763976ac4efd29b7ed15c69f4d838bbd74d9d09cf6fa86"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:964faa8a861d2664f0c7ab0c181af0bea66098b1919439815ca8803ef136fc4e"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4dd484681c15e6b9a977c785a345d3e378d72678fd5f1f3c0509608da24f2ac0"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f6d6cff3538391e8486a431569b77921adfcdef14eb18fbf19b7c0a5294d4e6a"}, - {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a6d511cc297ff0883bc3708b465ff82d7560193169a8b93260f74ecb0a5e08a7"}, - {file = "pydantic_core-2.20.1.tar.gz", hash = "sha256:26ca695eeee5f9f1aeeb211ffc12f10bcb6f71e2989988fda61dabd65db878d4"}, + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"}, + {file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", 
hash = "sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2"}, + {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3"}, + {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071"}, + {file = "pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119"}, + {file = "pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8"}, + {file = "pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e"}, + {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0"}, + {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64"}, + {file = "pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f"}, + {file = "pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231"}, + {file = "pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36"}, + {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e"}, + {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24"}, + {file = "pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84"}, + {file = "pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc"}, + {file = "pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b"}, + {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6"}, + {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f"}, + {file = "pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769"}, + {file = "pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5"}, + {file = "pydantic_core-2.23.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555"}, + {file = "pydantic_core-2.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad"}, + {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12"}, + {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2"}, + {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb"}, + {file = "pydantic_core-2.23.4-cp38-none-win32.whl", hash = "sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6"}, + {file = "pydantic_core-2.23.4-cp38-none-win_amd64.whl", hash = "sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556"}, + {file = "pydantic_core-2.23.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a"}, + {file = "pydantic_core-2.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c"}, + {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55"}, + {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040"}, + {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605"}, + {file = "pydantic_core-2.23.4-cp39-none-win32.whl", hash = "sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6"}, + {file = "pydantic_core-2.23.4-cp39-none-win_amd64.whl", hash = "sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = 
"sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433"}, + {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8"}, + {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e"}, + {file = "pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863"}, ] [package.dependencies] @@ -6363,13 +7179,13 @@ semver = ["semver (>=3.0.2)"] [[package]] name = "pydantic-settings" -version = "2.3.4" +version = "2.6.0" description = "Settings management using Pydantic" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic_settings-2.3.4-py3-none-any.whl", hash = "sha256:11ad8bacb68a045f00e4f862c7a718c8a9ec766aa8fd4c32e39a0594b207b53a"}, - {file = "pydantic_settings-2.3.4.tar.gz", hash = "sha256:c5802e3d62b78e82522319bbc9b8f8ffb28ad1c988a99311d04f2a6051fca0a7"}, + {file = "pydantic_settings-2.6.0-py3-none-any.whl", hash = "sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0"}, + {file = "pydantic_settings-2.6.0.tar.gz", hash = "sha256:44a1804abffac9e6a30372bb45f6cafab945ef5af25e66b1c634c01dd39e0188"}, ] [package.dependencies] @@ -6377,9 +7193,38 @@ pydantic = ">=2.7.0" 
python-dotenv = ">=0.21.0" [package.extras] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] toml = ["tomli (>=2.0.1)"] yaml = ["pyyaml (>=6.0.1)"] +[[package]] +name = "pydash" +version = "8.0.3" +description = "The kitchen sink of Python utility libraries for doing \"stuff\" in a functional way. Based on the Lo-Dash Javascript library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydash-8.0.3-py3-none-any.whl", hash = "sha256:c16871476822ee6b59b87e206dd27888240eff50a7b4cd72a4b80b43b6b994d7"}, + {file = "pydash-8.0.3.tar.gz", hash = "sha256:1b27cd3da05b72f0e5ff786c523afd82af796936462e631ffd1b228d91f8b9aa"}, +] + +[package.dependencies] +typing-extensions = ">3.10,<4.6.0 || >4.6.0" + +[package.extras] +dev = ["build", "coverage", "furo", "invoke", "mypy", "pytest", "pytest-cov", "pytest-mypy-testing", "ruff", "sphinx", "sphinx-autodoc-typehints", "tox", "twine", "wheel"] + +[[package]] +name = "pydub" +version = "0.25.1" +description = "Manipulate audio with an simple and easy high level interface" +optional = false +python-versions = "*" +files = [ + {file = "pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6"}, + {file = "pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f"}, +] + [[package]] name = "pygments" version = "2.18.0" @@ -6416,22 +7261,22 @@ tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] [[package]] name = "pymilvus" -version = "2.4.4" +version = "2.4.9" description = "Python Sdk for Milvus" optional = false python-versions = ">=3.8" files = [ - {file = "pymilvus-2.4.4-py3-none-any.whl", hash = "sha256:073b76bc36f6f4e70f0f0a0023a53324f0ba8ef9a60883f87cd30a44b6c6f2b5"}, - {file = "pymilvus-2.4.4.tar.gz", hash = "sha256:50c53eb103e034fbffe936fe942751ea3dbd2452e18cf79acc52360ed4987fb7"}, + {file = "pymilvus-2.4.9-py3-none-any.whl", hash = "sha256:45313607d2c164064bdc44e0f933cb6d6afa92e9efcc7f357c5240c57db58fbe"}, + {file = "pymilvus-2.4.9.tar.gz", hash = "sha256:0937663700007c23a84cfc0656160b301f6ff9247aaec4c96d599a6b43572136"}, ] [package.dependencies] environs = "<=9.5.0" -grpcio = ">=1.49.1,<=1.63.0" +grpcio = ">=1.49.1" milvus-lite = {version = ">=2.4.0,<2.5.0", markers = "sys_platform != \"win32\""} pandas = ">=1.2.4" protobuf = ">=3.20.0" -setuptools = ">=67" +setuptools = ">69" ujson = ">=2.0.0" [package.extras] @@ -6439,6 +7284,22 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"] model = ["milvus-model (>=0.1.0)"] +[[package]] +name = "pymochow" +version = "1.3.1" +description = "Python SDK for mochow" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327"}, + {file = "pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba"}, +] + +[package.dependencies] +future = "*" +orjson = "*" +requests = "*" + [[package]] name = "pymysql" version = "1.1.1" @@ -6454,30 +7315,86 @@ files = [ ed25519 = ["PyNaCl (>=1.4.0)"] rsa = ["cryptography"] +[[package]] +name = "pyobvector" +version = "0.1.6" +description = "A python SDK for OceanBase Vector Store, based on SQLAlchemy, compatible with Milvus API." 
+optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "pyobvector-0.1.6-py3-none-any.whl", hash = "sha256:0d700e865a85b4716b9a03384189e49288cd9d5f3cef88aed4740bc82d5fd136"}, + {file = "pyobvector-0.1.6.tar.gz", hash = "sha256:05551addcac8c596992d5e38b480c83ca3481c6cfc6f56a1a1bddfb2e6ae037e"}, +] + +[package.dependencies] +numpy = ">=1.26.0,<2.0.0" +pymysql = ">=1.1.1,<2.0.0" +sqlalchemy = ">=2.0.32,<3.0.0" + +[[package]] +name = "pyopenssl" +version = "24.2.1" +description = "Python wrapper module around the OpenSSL library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyOpenSSL-24.2.1-py3-none-any.whl", hash = "sha256:967d5719b12b243588573f39b0c677637145c7a1ffedcd495a487e58177fbb8d"}, + {file = "pyopenssl-24.2.1.tar.gz", hash = "sha256:4247f0dbe3748d560dcbb2ff3ea01af0f9a1a001ef5f7c4c647956ed8cbf0e95"}, +] + +[package.dependencies] +cryptography = ">=41.0.5,<44" + +[package.extras] +docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"] +test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"] + [[package]] name = "pypandoc" -version = "1.13" +version = "1.14" description = "Thin wrapper for pandoc." optional = false python-versions = ">=3.6" files = [ - {file = "pypandoc-1.13-py3-none-any.whl", hash = "sha256:4c7d71bf2f1ed122aac287113b5c4d537a33bbc3c1df5aed11a7d4a7ac074681"}, - {file = "pypandoc-1.13.tar.gz", hash = "sha256:31652073c7960c2b03570bd1e94f602ca9bc3e70099df5ead4cea98ff5151c1e"}, + {file = "pypandoc-1.14-py3-none-any.whl", hash = "sha256:1315c7ad7fac7236dacf69a05b521ed2c3f1d0177f70e9b92bfffce6c023df22"}, + {file = "pypandoc-1.14.tar.gz", hash = "sha256:6b4c45f5f1b9fb5bb562079164806bdbbc3e837b5402bcf3f1139edc5730a197"}, +] + +[[package]] +name = "pyparsing" +version = "3.2.0" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84"}, + {file = "pyparsing-3.2.0.tar.gz", hash = "sha256:cbf74e27246d595d9a74b186b810f6fbb86726dbf3b9532efb343f6d7294fe9c"}, ] +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] -name = "pyparsing" -version = "3.1.2" -description = "pyparsing module - Classes and methods to define and execute parsing grammars" +name = "pypdf" +version = "5.1.0" +description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files" optional = false -python-versions = ">=3.6.8" +python-versions = ">=3.8" files = [ - {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, - {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, + {file = "pypdf-5.1.0-py3-none-any.whl", hash = "sha256:3bd4f503f4ebc58bae40d81e81a9176c400cbbac2ba2d877367595fb524dfdfc"}, + {file = "pypdf-5.1.0.tar.gz", hash = "sha256:425a129abb1614183fd1aca6982f650b47f8026867c0ce7c4b9f281c443d2740"}, ] +[package.dependencies] +typing_extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + [package.extras] -diagrams = ["jinja2", "railroad-diagrams"] +crypto = ["cryptography"] +cryptodome = ["PyCryptodome"] +dev = ["black", "flit", "pip-tools", "pre-commit (<2.18.0)", "pytest-cov", "pytest-socket", "pytest-timeout", "pytest-xdist", "wheel"] +docs = ["myst_parser", "sphinx", "sphinx_rtd_theme"] +full = ["Pillow (>=8.0.0)", 
"cryptography"] +image = ["Pillow (>=8.0.0)"] [[package]] name = "pypdfium2" @@ -6523,35 +7440,38 @@ files = [ [[package]] name = "pyproject-hooks" -version = "1.1.0" +version = "1.2.0" description = "Wrappers to call pyproject.toml-based build backend hooks." optional = false python-versions = ">=3.7" files = [ - {file = "pyproject_hooks-1.1.0-py3-none-any.whl", hash = "sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2"}, - {file = "pyproject_hooks-1.1.0.tar.gz", hash = "sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965"}, + {file = "pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913"}, + {file = "pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8"}, ] [[package]] name = "pyreadline3" -version = "3.4.1" +version = "3.5.4" description = "A python implementation of GNU readline." optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"}, - {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"}, + {file = "pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6"}, + {file = "pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7"}, ] +[package.extras] +dev = ["build", "flake8", "mypy", "pytest", "twine"] + [[package]] name = "pytest" -version = "8.1.2" +version = "8.3.3" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.1.2-py3-none-any.whl", hash = "sha256:6c06dc309ff46a05721e6fd48e492a775ed8165d2ecdf57f156a80c7e95bb142"}, - {file = "pytest-8.1.2.tar.gz", hash = "sha256:f3c45d1d5eed96b01a2aea70dee6a4a366d51d38f9957768083e4fecfc77f3ef"}, + {file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"}, + {file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"}, ] [package.dependencies] @@ -6559,11 +7479,11 @@ colorama = {version = "*", markers = "sys_platform == \"win32\""} exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" -pluggy = ">=1.4,<2.0" +pluggy = ">=1.5,<2" tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-benchmark" @@ -6587,21 +7507,21 @@ histogram = ["pygal", "pygaljs"] [[package]] name = "pytest-env" -version = "1.1.3" +version = "1.1.5" description = "pytest plugin that allows you to add environment variables." 
optional = false python-versions = ">=3.8" files = [ - {file = "pytest_env-1.1.3-py3-none-any.whl", hash = "sha256:aada77e6d09fcfb04540a6e462c58533c37df35fa853da78707b17ec04d17dfc"}, - {file = "pytest_env-1.1.3.tar.gz", hash = "sha256:fcd7dc23bb71efd3d35632bde1bbe5ee8c8dc4489d6617fb010674880d96216b"}, + {file = "pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30"}, + {file = "pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf"}, ] [package.dependencies] -pytest = ">=7.4.3" +pytest = ">=8.3.3" tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} [package.extras] -test = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "pytest-mock (>=3.12)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] [[package]] name = "pytest-mock" @@ -6731,13 +7651,13 @@ files = [ [[package]] name = "python-dateutil" -version = "2.9.0.post0" +version = "2.8.2" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ - {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, - {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] [package.dependencies] @@ -6774,17 +7694,17 @@ cli = ["click (>=5.0)"] [[package]] name = "python-iso639" -version = "2024.4.27" +version = "2024.10.22" description = "ISO 639 language codes, names, and other associated information" optional = false python-versions = ">=3.8" files = [ - {file = "python_iso639-2024.4.27-py3-none-any.whl", hash = "sha256:27526a84cebc4c4d53fea9d1ebbc7209c8d279bebaa343e6765a1fc8780565ab"}, - {file = "python_iso639-2024.4.27.tar.gz", hash = "sha256:97e63b5603e085c6a56a12a95740010e75d9134e0aab767e0978b53fd8824f13"}, + {file = "python_iso639-2024.10.22-py3-none-any.whl", hash = "sha256:02d3ce2e01c6896b30b9cbbd3e1c8ee0d7221250b5d63ea9803e0d2a81fd1047"}, + {file = "python_iso639-2024.10.22.tar.gz", hash = "sha256:750f21b6a0bc6baa24253a3d8aae92b582bf93aa40988361cd96852c2c6d9a52"}, ] [package.extras] -dev = ["black (==24.4.2)", "build (==1.2.1)", "flake8 (==7.0.0)", "pytest (==8.1.2)", "requests (==2.31.0)", "twine (==5.0.0)"] +dev = ["black (==24.10.0)", "build (==1.2.1)", "flake8 (==7.1.1)", "pytest (==8.3.3)", "requests (==2.32.3)", "twine (==5.1.1)"] [[package]] name = "python-magic" @@ -6797,54 +7717,75 @@ files = [ {file = "python_magic-0.4.27-py2.py3-none-any.whl", hash = "sha256:c212960ad306f700aa0d01e5d7a325d20548ff97eb9920dcd29513174f0294d3"}, ] +[[package]] +name = "python-oxmsg" +version = "0.0.1" +description = "Extract attachments from Outlook .msg files." 
+optional = false +python-versions = ">=3.9" +files = [ + {file = "python_oxmsg-0.0.1-py3-none-any.whl", hash = "sha256:8ea7d5dda1bc161a413213da9e18ed152927c1fda2feaf5d1f02192d8ad45eea"}, + {file = "python_oxmsg-0.0.1.tar.gz", hash = "sha256:b65c1f93d688b85a9410afa824192a1ddc39da359b04a0bd2cbd3874e84d4994"}, +] + +[package.dependencies] +click = "*" +olefile = "*" +typing-extensions = ">=4.9.0" + [[package]] name = "python-pptx" -version = "0.6.23" -description = "Generate and manipulate Open XML PowerPoint (.pptx) files" +version = "1.0.2" +description = "Create, read, and update PowerPoint 2007+ (.pptx) files." optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "python-pptx-0.6.23.tar.gz", hash = "sha256:587497ff28e779ab18dbb074f6d4052893c85dedc95ed75df319364f331fedee"}, - {file = "python_pptx-0.6.23-py3-none-any.whl", hash = "sha256:dd0527194627a2b7cc05f3ba23ecaa2d9a0d5ac9b6193a28ed1b7a716f4217d4"}, + {file = "python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba"}, + {file = "python_pptx-1.0.2.tar.gz", hash = "sha256:479a8af0eaf0f0d76b6f00b0887732874ad2e3188230315290cd1f9dd9cc7095"}, ] [package.dependencies] lxml = ">=3.1.0" Pillow = ">=3.3.2" +typing-extensions = ">=4.9.0" XlsxWriter = ">=0.5.7" [[package]] name = "pytz" -version = "2024.1" +version = "2024.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" files = [ - {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, - {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, + {file = "pytz-2024.2-py2.py3-none-any.whl", hash = "sha256:31c7c1817eb7fae7ca4b8c7ee50c72f93aa2dd863de768e1ef4245d426aa0725"}, + {file = "pytz-2024.2.tar.gz", hash = "sha256:2aa355083c50a0f93fa581709deac0c9ad65cca8a9e9beac660adcbd493c798a"}, ] [[package]] name = "pywin32" -version = "306" +version = "308" description = "Python for Window Extensions" optional = false python-versions = "*" files = [ - {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, - {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, - {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, - {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, - {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, - {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, - {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, - {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, - {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, - {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, - {file = "pywin32-306-cp38-cp38-win32.whl", hash = 
"sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, - {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, - {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, - {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, + {file = "pywin32-308-cp310-cp310-win32.whl", hash = "sha256:796ff4426437896550d2981b9c2ac0ffd75238ad9ea2d3bfa67a1abd546d262e"}, + {file = "pywin32-308-cp310-cp310-win_amd64.whl", hash = "sha256:4fc888c59b3c0bef905ce7eb7e2106a07712015ea1c8234b703a088d46110e8e"}, + {file = "pywin32-308-cp310-cp310-win_arm64.whl", hash = "sha256:a5ab5381813b40f264fa3495b98af850098f814a25a63589a8e9eb12560f450c"}, + {file = "pywin32-308-cp311-cp311-win32.whl", hash = "sha256:5d8c8015b24a7d6855b1550d8e660d8daa09983c80e5daf89a273e5c6fb5095a"}, + {file = "pywin32-308-cp311-cp311-win_amd64.whl", hash = "sha256:575621b90f0dc2695fec346b2d6302faebd4f0f45c05ea29404cefe35d89442b"}, + {file = "pywin32-308-cp311-cp311-win_arm64.whl", hash = "sha256:100a5442b7332070983c4cd03f2e906a5648a5104b8a7f50175f7906efd16bb6"}, + {file = "pywin32-308-cp312-cp312-win32.whl", hash = "sha256:587f3e19696f4bf96fde9d8a57cec74a57021ad5f204c9e627e15c33ff568897"}, + {file = "pywin32-308-cp312-cp312-win_amd64.whl", hash = "sha256:00b3e11ef09ede56c6a43c71f2d31857cf7c54b0ab6e78ac659497abd2834f47"}, + {file = "pywin32-308-cp312-cp312-win_arm64.whl", hash = "sha256:9b4de86c8d909aed15b7011182c8cab38c8850de36e6afb1f0db22b8959e3091"}, + {file = "pywin32-308-cp313-cp313-win32.whl", hash = "sha256:1c44539a37a5b7b21d02ab34e6a4d314e0788f1690d65b48e9b0b89f31abbbed"}, + {file = "pywin32-308-cp313-cp313-win_amd64.whl", hash = "sha256:fd380990e792eaf6827fcb7e187b2b4b1cede0585e3d0c9e84201ec27b9905e4"}, + {file = "pywin32-308-cp313-cp313-win_arm64.whl", hash = "sha256:ef313c46d4c18dfb82a2431e3051ac8f112ccee1a34f29c263c583c568db63cd"}, + {file = "pywin32-308-cp37-cp37m-win32.whl", hash = "sha256:1f696ab352a2ddd63bd07430080dd598e6369152ea13a25ebcdd2f503a38f1ff"}, + {file = "pywin32-308-cp37-cp37m-win_amd64.whl", hash = "sha256:13dcb914ed4347019fbec6697a01a0aec61019c1046c2b905410d197856326a6"}, + {file = "pywin32-308-cp38-cp38-win32.whl", hash = "sha256:5794e764ebcabf4ff08c555b31bd348c9025929371763b2183172ff4708152f0"}, + {file = "pywin32-308-cp38-cp38-win_amd64.whl", hash = "sha256:3b92622e29d651c6b783e368ba7d6722b1634b8e70bd376fd7610fe1992e19de"}, + {file = "pywin32-308-cp39-cp39-win32.whl", hash = "sha256:7873ca4dc60ab3287919881a7d4f88baee4a6e639aa6962de25a98ba6b193341"}, + {file = "pywin32-308-cp39-cp39-win_amd64.whl", hash = "sha256:71b3322d949b4cc20776436a9c9ba0eeedcbc9c650daa536df63f0ff111bb920"}, ] [[package]] @@ -6988,123 +7929,103 @@ dev = ["pytest"] [[package]] name = "rapidfuzz" -version = "3.9.6" +version = "3.10.1" description = "rapid fuzzy string matching" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "rapidfuzz-3.9.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a7ed0d0b9c85720f0ae33ac5efc8dc3f60c1489dad5c29d735fbdf2f66f0431f"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f3deff6ab7017ed21b9aec5874a07ad13e6b2a688af055837f88b743c7bfd947"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:5c3f9fc060160507b2704f7d1491bd58453d69689b580cbc85289335b14fe8ca"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e86c2b3827fa6169ad6e7d4b790ce02a20acefb8b78d92fa4249589bbc7a2c"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f982e1aafb4bd8207a5e073b1efef9e68a984e91330e1bbf364f9ed157ed83f0"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9196a51d0ec5eaaaf5bca54a85b7b1e666fc944c332f68e6427503af9fb8c49e"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb5a514064e02585b1cc09da2fe406a6dc1a7e5f3e92dd4f27c53e5f1465ec81"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e3a4244f65dbc3580b1275480118c3763f9dc29fc3dd96610560cb5e140a4d4a"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f6ebb910a702e41641e1e1dada3843bc11ba9107a33c98daef6945a885a40a07"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:624fbe96115fb39addafa288d583b5493bc76dab1d34d0ebba9987d6871afdf9"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1c59f1c1507b7a557cf3c410c76e91f097460da7d97e51c985343798e9df7a3c"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f6f0256cb27b6a0fb2e1918477d1b56473cd04acfa245376a342e7c15806a396"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-win32.whl", hash = "sha256:24d473d00d23a30a85802b502b417a7f5126019c3beec91a6739fe7b95388b24"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-win_amd64.whl", hash = "sha256:248f6d2612e661e2b5f9a22bbd5862a1600e720da7bb6ad8a55bb1548cdfa423"}, - {file = "rapidfuzz-3.9.6-cp310-cp310-win_arm64.whl", hash = "sha256:e03fdf0e74f346ed7e798135df5f2a0fb8d6b96582b00ebef202dcf2171e1d1d"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:52e4675f642fbc85632f691b67115a243cd4d2a47bdcc4a3d9a79e784518ff97"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1f93a2f13038700bd245b927c46a2017db3dcd4d4ff94687d74b5123689b873b"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b70500bca460264b8141d8040caee22e9cf0418c5388104ff0c73fb69ee28f"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a1e037fb89f714a220f68f902fc6300ab7a33349f3ce8ffae668c3b3a40b0b06"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6792f66d59b86ccfad5e247f2912e255c85c575789acdbad8e7f561412ffed8a"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68d9cffe710b67f1969cf996983608cee4490521d96ea91d16bd7ea5dc80ea98"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daaeeea76da17fa0bbe7fb05cba8ed8064bb1a0edf8360636557f8b6511961"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d214e063bffa13e3b771520b74f674b22d309b5720d4df9918ff3e0c0f037720"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ed443a2062460f44c0346cb9d269b586496b808c2419bbd6057f54061c9b9c75"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:5b0c9b227ee0076fb2d58301c505bb837a290ae99ee628beacdb719f0626d749"}, - {file = 
"rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:82c9722b7dfaa71e8b61f8c89fed0482567fb69178e139fe4151fc71ed7df782"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c18897c95c0a288347e29537b63608a8f63a5c3cb6da258ac46fcf89155e723e"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-win32.whl", hash = "sha256:3e910cf08944da381159587709daaad9e59d8ff7bca1f788d15928f3c3d49c2a"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-win_amd64.whl", hash = "sha256:59c4a61fab676d37329fc3a671618a461bfeef53a4d0b8b12e3bc24a14e166f8"}, - {file = "rapidfuzz-3.9.6-cp311-cp311-win_arm64.whl", hash = "sha256:8b4afea244102332973377fddbe54ce844d0916e1c67a5123432291717f32ffa"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:70591b28b218fff351b88cdd7f2359a01a71f9f7f5a2e465ce3715ed4b3c422b"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ee2d8355c7343c631a03e57540ea06e8717c19ecf5ff64ea07e0498f7f161457"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:708fb675de0f47b9635d1cc6fbbf80d52cb710d0a1abbfae5c84c46e3abbddc3"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d66c247c2d3bb7a9b60567c395a15a929d0ebcc5f4ceedb55bfa202c38c6e0c"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:15146301b32e6e3d2b7e8146db1a26747919d8b13690c7f83a4cb5dc111b3a08"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7a03da59b6c7c97e657dd5cd4bcaab5fe4a2affd8193958d6f4d938bee36679"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d2c2fe19e392dbc22695b6c3b2510527e2b774647e79936bbde49db7742d6f1"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:91aaee4c94cb45930684f583ffc4e7c01a52b46610971cede33586cf8a04a12e"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3f5702828c10768f9281180a7ff8597da1e5002803e1304e9519dd0f06d79a85"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ccd1763b608fb4629a0b08f00b3c099d6395e67c14e619f6341b2c8429c2f310"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cc7a0d4b2cb166bc46d02c8c9f7551cde8e2f3c9789df3827309433ee9771163"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7496f53d40560a58964207b52586783633f371683834a8f719d6d965d223a2eb"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-win32.whl", hash = "sha256:5eb1a9272ca71bc72be5415c2fa8448a6302ea4578e181bb7da9db855b367df0"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-win_amd64.whl", hash = "sha256:0d21fc3c0ca507a1180152a6dbd129ebaef48facde3f943db5c1055b6e6be56a"}, - {file = "rapidfuzz-3.9.6-cp312-cp312-win_arm64.whl", hash = "sha256:43bb27a57c29dc5fa754496ba6a1a508480d21ae99ac0d19597646c16407e9f3"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:83a5ac6547a9d6eedaa212975cb8f2ce2aa07e6e30833b40e54a52b9f9999aa4"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:10f06139142ecde67078ebc9a745965446132b998f9feebffd71acdf218acfcc"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:74720c3f24597f76c7c3e2c4abdff55f1664f4766ff5b28aeaa689f8ffba5fab"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:ce2bce52b5c150878e558a0418c2b637fb3dbb6eb38e4eb27d24aa839920483e"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1611199f178793ca9a060c99b284e11f6d7d124998191f1cace9a0245334d219"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0308b2ad161daf502908a6e21a57c78ded0258eba9a8f5e2545e2dafca312507"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3eda91832201b86e3b70835f91522587725bec329ec68f2f7faf5124091e5ca7"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ece873c093aedd87fc07c2a7e333d52e458dc177016afa1edaf157e82b6914d8"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d97d3c9d209d5c30172baea5966f2129e8a198fec4a1aeb2f92abb6e82a2edb1"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:6c4550d0db4931f5ebe9f0678916d1b06f06f5a99ba0b8a48b9457fd8959a7d4"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b6b8dd4af6324fc325d9483bec75ecf9be33e590928c9202d408e4eafff6a0a6"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:16122ae448bc89e2bea9d81ce6cb0f751e4e07da39bd1e70b95cae2493857853"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-win32.whl", hash = "sha256:71cc168c305a4445109cd0d4925406f6e66bcb48fde99a1835387c58af4ecfe9"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-win_amd64.whl", hash = "sha256:59ee78f2ecd53fef8454909cda7400fe2cfcd820f62b8a5d4dfe930102268054"}, - {file = "rapidfuzz-3.9.6-cp313-cp313-win_arm64.whl", hash = "sha256:58b4ce83f223605c358ae37e7a2d19a41b96aa65b1fede99cc664c9053af89ac"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9f469dbc9c4aeaac7dd005992af74b7dff94aa56a3ea063ce64e4b3e6736dd2f"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a9ed7ad9adb68d0fe63a156fe752bbf5f1403ed66961551e749641af2874da92"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39ffe48ffbeedf78d120ddfb9d583f2ca906712159a4e9c3c743c9f33e7b1775"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8502ccdea9084d54b6f737d96a3b60a84e3afed9d016686dc979b49cdac71613"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6a4bec4956e06b170ca896ba055d08d4c457dac745548172443982956a80e118"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2c0488b1c273be39e109ff885ccac0448b2fa74dea4c4dc676bcf756c15f16d6"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0542c036cb6acf24edd2c9e0411a67d7ba71e29e4d3001a082466b86fc34ff30"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0a96b52c9f26857bf009e270dcd829381e7a634f7ddd585fa29b87d4c82146d9"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:6edd3cd7c4aa8c68c716d349f531bd5011f2ca49ddade216bb4429460151559f"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:50b2fb55d7ed58c66d49c9f954acd8fc4a3f0e9fd0ff708299bd8abb68238d0e"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:32848dfe54391636b84cda1823fd23e5a6b1dbb8be0e9a1d80e4ee9903820994"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = 
"sha256:29146cb7a1bf69c87e928b31bffa54f066cb65639d073b36e1425f98cccdebc6"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-win32.whl", hash = "sha256:aed13e5edacb0ecadcc304cc66e93e7e77ff24f059c9792ee602c0381808e10c"}, - {file = "rapidfuzz-3.9.6-cp38-cp38-win_amd64.whl", hash = "sha256:af440e36b828922256d0b4d79443bf2cbe5515fc4b0e9e96017ec789b36bb9fc"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:efa674b407424553024522159296690d99d6e6b1192cafe99ca84592faff16b4"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0b40ff76ee19b03ebf10a0a87938f86814996a822786c41c3312d251b7927849"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16a6c7997cb5927ced6f617122eb116ba514ec6b6f60f4803e7925ef55158891"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3f42504bdc8d770987fc3d99964766d42b2a03e4d5b0f891decdd256236bae0"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9462aa2be9f60b540c19a083471fdf28e7cf6434f068b631525b5e6251b35e"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1629698e68f47609a73bf9e73a6da3a4cac20bc710529215cbdf111ab603665b"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68bc7621843d8e9a7fd1b1a32729465bf94b47b6fb307d906da168413331f8d6"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c6254c50f15bc2fcc33cb93a95a81b702d9e6590f432a7f7822b8c7aba9ae288"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:7e535a114fa575bc143e175e4ca386a467ec8c42909eff500f5f0f13dc84e3e0"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:d50acc0e9d67e4ba7a004a14c42d1b1e8b6ca1c515692746f4f8e7948c673167"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:fa742ec60bec53c5a211632cf1d31b9eb5a3c80f1371a46a23ac25a1fa2ab209"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:c256fa95d29cbe5aa717db790b231a9a5b49e5983d50dc9df29d364a1db5e35b"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-win32.whl", hash = "sha256:89acbf728b764421036c173a10ada436ecca22999851cdc01d0aa904c70d362d"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-win_amd64.whl", hash = "sha256:c608fcba8b14d86c04cb56b203fed31a96e8a1ebb4ce99e7b70313c5bf8cf497"}, - {file = "rapidfuzz-3.9.6-cp39-cp39-win_arm64.whl", hash = "sha256:d41c00ded0e22e9dba88ff23ebe0dc9d2a5f21ba2f88e185ea7374461e61daa9"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:a65c2f63218ea2dedd56fc56361035e189ca123bd9c9ce63a9bef6f99540d681"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:680dc78a5f889d3b89f74824b89fe357f49f88ad10d2c121e9c3ad37bac1e4eb"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8ca862927a0b05bd825e46ddf82d0724ea44b07d898ef639386530bf9b40f15"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2116fa1fbff21fa52cd46f3cfcb1e193ba1d65d81f8b6e123193451cd3d6c15e"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4dcb7d9afd740370a897c15da61d3d57a8d54738d7c764a99cedb5f746d6a003"}, - {file = "rapidfuzz-3.9.6-pp310-pypy310_pp73-win_amd64.whl", hash = 
"sha256:1a5bd6401bb489e14cbb5981c378d53ede850b7cc84b2464cad606149cc4e17d"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:29fda70b9d03e29df6fc45cc27cbcc235534b1b0b2900e0a3ae0b43022aaeef5"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:88144f5f52ae977df9352029488326afadd7a7f42c6779d486d1f82d43b2b1f2"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:715aeaabafba2709b9dd91acb2a44bad59d60b4616ef90c08f4d4402a3bbca60"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:af26ebd3714224fbf9bebbc27bdbac14f334c15f5d7043699cd694635050d6ca"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101bd2df438861a005ed47c032631b7857dfcdb17b82beeeb410307983aac61d"}, - {file = "rapidfuzz-3.9.6-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:2185e8e29809b97ad22a7f99281d1669a89bdf5fa1ef4ef1feca36924e675367"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:9e53c72d08f0e9c6e4a369e52df5971f311305b4487690c62e8dd0846770260c"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:a0cb157162f0cdd62e538c7bd298ff669847fc43a96422811d5ab933f4c16c3a"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4bb5ff2bd48132ed5e7fbb8f619885facb2e023759f2519a448b2c18afe07e5d"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6dc37f601865e8407e3a8037ffbc3afe0b0f837b2146f7632bd29d087385babe"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a657eee4b94668faf1fa2703bdd803654303f7e468eb9ba10a664d867ed9e779"}, - {file = "rapidfuzz-3.9.6-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:51be6ab5b1d5bb32abd39718f2a5e3835502e026a8272d139ead295c224a6f5e"}, - {file = "rapidfuzz-3.9.6.tar.gz", hash = "sha256:5cf2a7d621e4515fee84722e93563bf77ff2cbe832a77a48b81f88f9e23b9e8d"}, -] - -[package.extras] -full = ["numpy"] + {file = "rapidfuzz-3.10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f17d9f21bf2f2f785d74f7b0d407805468b4c173fa3e52c86ec94436b338e74a"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b31f358a70efc143909fb3d75ac6cd3c139cd41339aa8f2a3a0ead8315731f2b"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f4f43f2204b56a61448ec2dd061e26fd344c404da99fb19f3458200c5874ba2"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9d81bf186a453a2757472133b24915768abc7c3964194406ed93e170e16c21cb"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3611c8f45379a12063d70075c75134f2a8bd2e4e9b8a7995112ddae95ca1c982"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3c3b537b97ac30da4b73930fa8a4fe2f79c6d1c10ad535c5c09726612cd6bed9"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:231ef1ec9cf7b59809ce3301006500b9d564ddb324635f4ea8f16b3e2a1780da"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ed4f3adc1294834955b7e74edd3c6bd1aad5831c007f2d91ea839e76461a5879"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_i686.whl", hash 
= "sha256:7b6015da2e707bf632a71772a2dbf0703cff6525732c005ad24987fe86e8ec32"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1b35a118d61d6f008e8e3fb3a77674d10806a8972c7b8be433d6598df4d60b01"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:bc308d79a7e877226f36bdf4e149e3ed398d8277c140be5c1fd892ec41739e6d"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f017dbfecc172e2d0c37cf9e3d519179d71a7f16094b57430dffc496a098aa17"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-win32.whl", hash = "sha256:36c0e1483e21f918d0f2f26799fe5ac91c7b0c34220b73007301c4f831a9c4c7"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-win_amd64.whl", hash = "sha256:10746c1d4c8cd8881c28a87fd7ba0c9c102346dfe7ff1b0d021cdf093e9adbff"}, + {file = "rapidfuzz-3.10.1-cp310-cp310-win_arm64.whl", hash = "sha256:dfa64b89dcb906835e275187569e51aa9d546a444489e97aaf2cc84011565fbe"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:92958ae075c87fef393f835ed02d4fe8d5ee2059a0934c6c447ea3417dfbf0e8"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba7521e072c53e33c384e78615d0718e645cab3c366ecd3cc8cb732befd94967"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00d02cbd75d283c287471b5b3738b3e05c9096150f93f2d2dfa10b3d700f2db9"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:efa1582a397da038e2f2576c9cd49b842f56fde37d84a6b0200ffebc08d82350"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f12912acee1f506f974f58de9fdc2e62eea5667377a7e9156de53241c05fdba8"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666d5d8b17becc3f53447bcb2b6b33ce6c2df78792495d1fa82b2924cd48701a"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26f71582c0d62445067ee338ddad99b655a8f4e4ed517a90dcbfbb7d19310474"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8a2ef08b27167bcff230ffbfeedd4c4fa6353563d6aaa015d725dd3632fc3de7"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:365e4fc1a2b95082c890f5e98489b894e6bf8c338c6ac89bb6523c2ca6e9f086"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:1996feb7a61609fa842e6b5e0c549983222ffdedaf29644cc67e479902846dfe"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:cf654702f144beaa093103841a2ea6910d617d0bb3fccb1d1fd63c54dde2cd49"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ec108bf25de674781d0a9a935030ba090c78d49def3d60f8724f3fc1e8e75024"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-win32.whl", hash = "sha256:031f8b367e5d92f7a1e27f7322012f3c321c3110137b43cc3bf678505583ef48"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:f98f36c6a1bb9a6c8bbec99ad87c8c0e364f34761739b5ea9adf7b48129ae8cf"}, + {file = "rapidfuzz-3.10.1-cp311-cp311-win_arm64.whl", hash = "sha256:f1da2028cb4e41be55ee797a82d6c1cf589442504244249dfeb32efc608edee7"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:1340b56340896bede246f612b6ecf685f661a56aabef3d2512481bfe23ac5835"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2316515169b7b5a453f0ce3adbc46c42aa332cae9f2edb668e24d1fc92b2f2bb"}, + 
{file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e06fe6a12241ec1b72c0566c6b28cda714d61965d86569595ad24793d1ab259"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d99c1cd9443b19164ec185a7d752f4b4db19c066c136f028991a480720472e23"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1d9aa156ed52d3446388ba4c2f335e312191d1ca9d1f5762ee983cf23e4ecf6"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:54bcf4efaaee8e015822be0c2c28214815f4f6b4f70d8362cfecbd58a71188ac"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0c955e32afdbfdf6e9ee663d24afb25210152d98c26d22d399712d29a9b976b"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:191633722203f5b7717efcb73a14f76f3b124877d0608c070b827c5226d0b972"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:195baad28057ec9609e40385991004e470af9ef87401e24ebe72c064431524ab"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0fff4a6b87c07366662b62ae994ffbeadc472e72f725923f94b72a3db49f4671"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4ffed25f9fdc0b287f30a98467493d1e1ce5b583f6317f70ec0263b3c97dbba6"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d02cf8e5af89a9ac8f53c438ddff6d773f62c25c6619b29db96f4aae248177c0"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-win32.whl", hash = "sha256:f3bb81d4fe6a5d20650f8c0afcc8f6e1941f6fecdb434f11b874c42467baded0"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-win_amd64.whl", hash = "sha256:aaf83e9170cb1338922ae42d320699dccbbdca8ffed07faeb0b9257822c26e24"}, + {file = "rapidfuzz-3.10.1-cp312-cp312-win_arm64.whl", hash = "sha256:c5da802a0d085ad81b0f62828fb55557996c497b2d0b551bbdfeafd6d447892f"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fc22d69a1c9cccd560a5c434c0371b2df0f47c309c635a01a913e03bbf183710"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:38b0dac2c8e057562b8f0d8ae5b663d2d6a28c5ab624de5b73cef9abb6129a24"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fde3bbb14e92ce8fcb5c2edfff72e474d0080cadda1c97785bf4822f037a309"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9141fb0592e55f98fe9ac0f3ce883199b9c13e262e0bf40c5b18cdf926109d16"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:237bec5dd1bfc9b40bbd786cd27949ef0c0eb5fab5eb491904c6b5df59d39d3c"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18123168cba156ab5794ea6de66db50f21bb3c66ae748d03316e71b27d907b95"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b75fe506c8e02769cc47f5ab21ce3e09b6211d3edaa8f8f27331cb6988779be"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9da82aa4b46973aaf9e03bb4c3d6977004648c8638febfc0f9d237e865761270"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c34c022d5ad564f1a5a57a4a89793bd70d7bad428150fb8ff2760b223407cdcf"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = 
"sha256:1e96c84d6c2a0ca94e15acb5399118fff669f4306beb98a6d8ec6f5dccab4412"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:e8e154b84a311263e1aca86818c962e1fa9eefdd643d1d5d197fcd2738f88cb9"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:335fee93188f8cd585552bb8057228ce0111bd227fa81bfd40b7df6b75def8ab"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-win32.whl", hash = "sha256:6729b856166a9e95c278410f73683957ea6100c8a9d0a8dbe434c49663689255"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-win_amd64.whl", hash = "sha256:0e06d99ad1ad97cb2ef7f51ec6b1fedd74a3a700e4949353871cf331d07b382a"}, + {file = "rapidfuzz-3.10.1-cp313-cp313-win_arm64.whl", hash = "sha256:8d1b7082104d596a3eb012e0549b2634ed15015b569f48879701e9d8db959dbb"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:779027d3307e1a2b1dc0c03c34df87a470a368a1a0840a9d2908baf2d4067956"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:440b5608ab12650d0390128d6858bc839ae77ffe5edf0b33a1551f2fa9860651"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82cac41a411e07a6f3dc80dfbd33f6be70ea0abd72e99c59310819d09f07d945"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:958473c9f0bca250590200fd520b75be0dbdbc4a7327dc87a55b6d7dc8d68552"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ef60dfa73749ef91cb6073be1a3e135f4846ec809cc115f3cbfc6fe283a5584"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7fbac18f2c19fc983838a60611e67e3262e36859994c26f2ee85bb268de2355"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a0d519ff39db887cd73f4e297922786d548f5c05d6b51f4e6754f452a7f4296"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bebb7bc6aeb91cc57e4881b222484c26759ca865794187217c9dcea6c33adae6"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:fe07f8b9c3bb5c5ad1d2c66884253e03800f4189a60eb6acd6119ebaf3eb9894"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:bfa48a4a2d45a41457f0840c48e579db157a927f4e97acf6e20df8fc521c79de"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:2cf44d01bfe8ee605b7eaeecbc2b9ca64fc55765f17b304b40ed8995f69d7716"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1e6bbca9246d9eedaa1c84e04a7f555493ba324d52ae4d9f3d9ddd1b740dcd87"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-win32.whl", hash = "sha256:567f88180f2c1423b4fe3f3ad6e6310fc97b85bdba574801548597287fc07028"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6b2cd7c29d6ecdf0b780deb587198f13213ac01c430ada6913452fd0c40190fc"}, + {file = "rapidfuzz-3.10.1-cp39-cp39-win_arm64.whl", hash = "sha256:9f912d459e46607ce276128f52bea21ebc3e9a5ccf4cccfef30dd5bddcf47be8"}, + {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ac4452f182243cfab30ba4668ef2de101effaedc30f9faabb06a095a8c90fd16"}, + {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:565c2bd4f7d23c32834652b27b51dd711814ab614b4e12add8476be4e20d1cf5"}, + {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:187d9747149321607be4ccd6f9f366730078bed806178ec3eeb31d05545e9e8f"}, + {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:616290fb9a8fa87e48cb0326d26f98d4e29f17c3b762c2d586f2b35c1fd2034b"}, + {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:073a5b107e17ebd264198b78614c0206fa438cce749692af5bc5f8f484883f50"}, + {file = "rapidfuzz-3.10.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:39c4983e2e2ccb9732f3ac7d81617088822f4a12291d416b09b8a1eadebb3e29"}, + {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ac7adee6bcf0c6fee495d877edad1540a7e0f5fc208da03ccb64734b43522d7a"}, + {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:425f4ac80b22153d391ee3f94bc854668a0c6c129f05cf2eaf5ee74474ddb69e"}, + {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:65a2fa13e8a219f9b5dcb9e74abe3ced5838a7327e629f426d333dfc8c5a6e66"}, + {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75561f3df9a906aaa23787e9992b228b1ab69007932dc42070f747103e177ba8"}, + {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:edd062490537e97ca125bc6c7f2b7331c2b73d21dc304615afe61ad1691e15d5"}, + {file = "rapidfuzz-3.10.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cfcc8feccf63245a22dfdd16e222f1a39771a44b870beb748117a0e09cbb4a62"}, + {file = "rapidfuzz-3.10.1.tar.gz", hash = "sha256:5a15546d847a915b3f42dc79ef9b0c78b998b4e2c53b252e7166284066585979"}, +] + +[package.extras] +all = ["numpy"] [[package]] name = "readabilipy" @@ -7128,6 +8049,23 @@ dev = ["coveralls", "m2r", "pycodestyle", "pyflakes", "pylint", "pytest", "pytes docs = ["m2r", "sphinx"] test = ["coveralls", "pycodestyle", "pyflakes", "pylint", "pytest", "pytest-benchmark", "pytest-cov"] +[[package]] +name = "realtime" +version = "2.0.2" +description = "" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"}, + {file = "realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"}, +] + +[package.dependencies] +aiohttp = ">=3.10.2,<4.0.0" +python-dateutil = ">=2.8.1,<3.0.0" +typing-extensions = ">=4.12.2,<5.0.0" +websockets = ">=11,<13" + [[package]] name = "redis" version = "5.0.8" @@ -7147,92 +8085,122 @@ hiredis = {version = ">1.0.0", optional = true, markers = "extra == \"hiredis\"" hiredis = ["hiredis (>1.0.0)"] ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] +[[package]] +name = "referencing" +version = "0.35.1" +description = "JSON Referencing + Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "referencing-0.35.1-py3-none-any.whl", hash = "sha256:eda6d3234d62814d1c64e305c1331c9a3a6132da475ab6382eaa997b21ee75de"}, + {file = "referencing-0.35.1.tar.gz", hash = "sha256:25b42124a6c8b632a425174f24087783efb348a6f1e0008e63cd4466fedf703c"}, +] + +[package.dependencies] +attrs = ">=22.2.0" +rpds-py = ">=0.7.0" + [[package]] name = "regex" -version = "2024.7.24" +version = "2024.9.11" description = "Alternative regular expression module, to replace re." 
optional = false python-versions = ">=3.8" files = [ - {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b0d3f567fafa0633aee87f08b9276c7062da9616931382993c03808bb68ce"}, - {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3426de3b91d1bc73249042742f45c2148803c111d1175b283270177fdf669024"}, - {file = "regex-2024.7.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f273674b445bcb6e4409bf8d1be67bc4b58e8b46fd0d560055d515b8830063cd"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23acc72f0f4e1a9e6e9843d6328177ae3074b4182167e34119ec7233dfeccf53"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65fd3d2e228cae024c411c5ccdffae4c315271eee4a8b839291f84f796b34eca"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c414cbda77dbf13c3bc88b073a1a9f375c7b0cb5e115e15d4b73ec3a2fbc6f59"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7a89eef64b5455835f5ed30254ec19bf41f7541cd94f266ab7cbd463f00c41"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19c65b00d42804e3fbea9708f0937d157e53429a39b7c61253ff15670ff62cb5"}, - {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7a5486ca56c8869070a966321d5ab416ff0f83f30e0e2da1ab48815c8d165d46"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6f51f9556785e5a203713f5efd9c085b4a45aecd2a42573e2b5041881b588d1f"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a4997716674d36a82eab3e86f8fa77080a5d8d96a389a61ea1d0e3a94a582cf7"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c0abb5e4e8ce71a61d9446040c1e86d4e6d23f9097275c5bd49ed978755ff0fe"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:18300a1d78cf1290fa583cd8b7cde26ecb73e9f5916690cf9d42de569c89b1ce"}, - {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:416c0e4f56308f34cdb18c3f59849479dde5b19febdcd6e6fa4d04b6c31c9faa"}, - {file = "regex-2024.7.24-cp310-cp310-win32.whl", hash = "sha256:fb168b5924bef397b5ba13aabd8cf5df7d3d93f10218d7b925e360d436863f66"}, - {file = "regex-2024.7.24-cp310-cp310-win_amd64.whl", hash = "sha256:6b9fc7e9cc983e75e2518496ba1afc524227c163e43d706688a6bb9eca41617e"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:382281306e3adaaa7b8b9ebbb3ffb43358a7bbf585fa93821300a418bb975281"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fdd1384619f406ad9037fe6b6eaa3de2749e2e12084abc80169e8e075377d3b"}, - {file = "regex-2024.7.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3d974d24edb231446f708c455fd08f94c41c1ff4f04bcf06e5f36df5ef50b95a"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec4419a3fe6cf8a4795752596dfe0adb4aea40d3683a132bae9c30b81e8d73"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb563dd3aea54c797adf513eeec819c4213d7dbfc311874eb4fd28d10f2ff0f2"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:45104baae8b9f67569f0f1dca5e1f1ed77a54ae1cd8b0b07aba89272710db61e"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:994448ee01864501912abf2bad9203bffc34158e80fe8bfb5b031f4f8e16da51"}, - {file = "regex-2024.7.24-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fac296f99283ac232d8125be932c5cd7644084a30748fda013028c815ba3364"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7e37e809b9303ec3a179085415cb5f418ecf65ec98cdfe34f6a078b46ef823ee"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:01b689e887f612610c869421241e075c02f2e3d1ae93a037cb14f88ab6a8934c"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f6442f0f0ff81775eaa5b05af8a0ffa1dda36e9cf6ec1e0d3d245e8564b684ce"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:871e3ab2838fbcb4e0865a6e01233975df3a15e6fce93b6f99d75cacbd9862d1"}, - {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c918b7a1e26b4ab40409820ddccc5d49871a82329640f5005f73572d5eaa9b5e"}, - {file = "regex-2024.7.24-cp311-cp311-win32.whl", hash = "sha256:2dfbb8baf8ba2c2b9aa2807f44ed272f0913eeeba002478c4577b8d29cde215c"}, - {file = "regex-2024.7.24-cp311-cp311-win_amd64.whl", hash = "sha256:538d30cd96ed7d1416d3956f94d54e426a8daf7c14527f6e0d6d425fcb4cca52"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fe4ebef608553aff8deb845c7f4f1d0740ff76fa672c011cc0bacb2a00fbde86"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:74007a5b25b7a678459f06559504f1eec2f0f17bca218c9d56f6a0a12bfffdad"}, - {file = "regex-2024.7.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7df9ea48641da022c2a3c9c641650cd09f0cd15e8908bf931ad538f5ca7919c9"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1141a1dcc32904c47f6846b040275c6e5de0bf73f17d7a409035d55b76f289"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80c811cfcb5c331237d9bad3bea2c391114588cf4131707e84d9493064d267f9"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7214477bf9bd195894cf24005b1e7b496f46833337b5dedb7b2a6e33f66d962c"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d55588cba7553f0b6ec33130bc3e114b355570b45785cebdc9daed8c637dd440"}, - {file = "regex-2024.7.24-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558a57cfc32adcf19d3f791f62b5ff564922942e389e3cfdb538a23d65a6b610"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a512eed9dfd4117110b1881ba9a59b31433caed0c4101b361f768e7bcbaf93c5"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:86b17ba823ea76256b1885652e3a141a99a5c4422f4a869189db328321b73799"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5eefee9bfe23f6df09ffb6dfb23809f4d74a78acef004aa904dc7c88b9944b05"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:731fcd76bbdbf225e2eb85b7c38da9633ad3073822f5ab32379381e8c3c12e94"}, - {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eaef80eac3b4cfbdd6de53c6e108b4c534c21ae055d1dbea2de6b3b8ff3def38"}, - {file = 
"regex-2024.7.24-cp312-cp312-win32.whl", hash = "sha256:185e029368d6f89f36e526764cf12bf8d6f0e3a2a7737da625a76f594bdfcbfc"}, - {file = "regex-2024.7.24-cp312-cp312-win_amd64.whl", hash = "sha256:2f1baff13cc2521bea83ab2528e7a80cbe0ebb2c6f0bfad15be7da3aed443908"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:66b4c0731a5c81921e938dcf1a88e978264e26e6ac4ec96a4d21ae0354581ae0"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:88ecc3afd7e776967fa16c80f974cb79399ee8dc6c96423321d6f7d4b881c92b"}, - {file = "regex-2024.7.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64bd50cf16bcc54b274e20235bf8edbb64184a30e1e53873ff8d444e7ac656b2"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb462f0e346fcf41a901a126b50f8781e9a474d3927930f3490f38a6e73b6950"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a82465ebbc9b1c5c50738536fdfa7cab639a261a99b469c9d4c7dcbb2b3f1e57"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68a8f8c046c6466ac61a36b65bb2395c74451df2ffb8458492ef49900efed293"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac8e84fff5d27420f3c1e879ce9929108e873667ec87e0c8eeb413a5311adfe"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba2537ef2163db9e6ccdbeb6f6424282ae4dea43177402152c67ef869cf3978b"}, - {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:43affe33137fcd679bdae93fb25924979517e011f9dea99163f80b82eadc7e53"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:c9bb87fdf2ab2370f21e4d5636e5317775e5d51ff32ebff2cf389f71b9b13750"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:945352286a541406f99b2655c973852da7911b3f4264e010218bbc1cc73168f2"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:8bc593dcce679206b60a538c302d03c29b18e3d862609317cb560e18b66d10cf"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:3f3b6ca8eae6d6c75a6cff525c8530c60e909a71a15e1b731723233331de4169"}, - {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c51edc3541e11fbe83f0c4d9412ef6c79f664a3745fab261457e84465ec9d5a8"}, - {file = "regex-2024.7.24-cp38-cp38-win32.whl", hash = "sha256:d0a07763776188b4db4c9c7fb1b8c494049f84659bb387b71c73bbc07f189e96"}, - {file = "regex-2024.7.24-cp38-cp38-win_amd64.whl", hash = "sha256:8fd5afd101dcf86a270d254364e0e8dddedebe6bd1ab9d5f732f274fa00499a5"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0ffe3f9d430cd37d8fa5632ff6fb36d5b24818c5c986893063b4e5bdb84cdf24"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:25419b70ba00a16abc90ee5fce061228206173231f004437730b67ac77323f0d"}, - {file = "regex-2024.7.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:33e2614a7ce627f0cdf2ad104797d1f68342d967de3695678c0cb84f530709f8"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d33a0021893ede5969876052796165bab6006559ab845fd7b515a30abdd990dc"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:04ce29e2c5fedf296b1a1b0acc1724ba93a36fb14031f3abfb7abda2806c1535"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b16582783f44fbca6fcf46f61347340c787d7530d88b4d590a397a47583f31dd"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:836d3cc225b3e8a943d0b02633fb2f28a66e281290302a79df0e1eaa984ff7c1"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:438d9f0f4bc64e8dea78274caa5af971ceff0f8771e1a2333620969936ba10be"}, - {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:973335b1624859cb0e52f96062a28aa18f3a5fc77a96e4a3d6d76e29811a0e6e"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c5e69fd3eb0b409432b537fe3c6f44ac089c458ab6b78dcec14478422879ec5f"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:fbf8c2f00904eaf63ff37718eb13acf8e178cb940520e47b2f05027f5bb34ce3"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2757ace61bc4061b69af19e4689fa4416e1a04840f33b441034202b5cd02d4"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:44fc61b99035fd9b3b9453f1713234e5a7c92a04f3577252b45feefe1b327759"}, - {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:84c312cdf839e8b579f504afcd7b65f35d60b6285d892b19adea16355e8343c9"}, - {file = "regex-2024.7.24-cp39-cp39-win32.whl", hash = "sha256:ca5b2028c2f7af4e13fb9fc29b28d0ce767c38c7facdf64f6c2cd040413055f1"}, - {file = "regex-2024.7.24-cp39-cp39-win_amd64.whl", hash = "sha256:7c479f5ae937ec9985ecaf42e2e10631551d909f203e31308c12d703922742f9"}, - {file = "regex-2024.7.24.tar.gz", hash = "sha256:9cfd009eed1a46b27c14039ad5bbc5e71b6367c5b2e6d5f5da0ea91600817506"}, + {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1494fa8725c285a81d01dc8c06b55287a1ee5e0e382d8413adc0a9197aac6408"}, + {file = "regex-2024.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0e12c481ad92d129c78f13a2a3662317e46ee7ef96c94fd332e1c29131875b7d"}, + {file = "regex-2024.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:16e13a7929791ac1216afde26f712802e3df7bf0360b32e4914dca3ab8baeea5"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46989629904bad940bbec2106528140a218b4a36bb3042d8406980be1941429c"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a906ed5e47a0ce5f04b2c981af1c9acf9e8696066900bf03b9d7879a6f679fc8"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a091b0550b3b0207784a7d6d0f1a00d1d1c8a11699c1a4d93db3fbefc3ad35"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ddcd9a179c0a6fa8add279a4444015acddcd7f232a49071ae57fa6e278f1f71"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6b41e1adc61fa347662b09398e31ad446afadff932a24807d3ceb955ed865cc8"}, + {file = "regex-2024.9.11-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ced479f601cd2f8ca1fd7b23925a7e0ad512a56d6e9476f79b8f381d9d37090a"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash 
= "sha256:635a1d96665f84b292e401c3d62775851aedc31d4f8784117b3c68c4fcd4118d"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c0256beda696edcf7d97ef16b2a33a8e5a875affd6fa6567b54f7c577b30a137"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ce4f1185db3fbde8ed8aa223fc9620f276c58de8b0d4f8cc86fd1360829edb6"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:09d77559e80dcc9d24570da3745ab859a9cf91953062e4ab126ba9d5993688ca"}, + {file = "regex-2024.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a22ccefd4db3f12b526eccb129390942fe874a3a9fdbdd24cf55773a1faab1a"}, + {file = "regex-2024.9.11-cp310-cp310-win32.whl", hash = "sha256:f745ec09bc1b0bd15cfc73df6fa4f726dcc26bb16c23a03f9e3367d357eeedd0"}, + {file = "regex-2024.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:01c2acb51f8a7d6494c8c5eafe3d8e06d76563d8a8a4643b37e9b2dd8a2ff623"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2cce2449e5927a0bf084d346da6cd5eb016b2beca10d0013ab50e3c226ffc0df"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b37fa423beefa44919e009745ccbf353d8c981516e807995b2bd11c2c77d268"}, + {file = "regex-2024.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:64ce2799bd75039b480cc0360907c4fb2f50022f030bf9e7a8705b636e408fad"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4cc92bb6db56ab0c1cbd17294e14f5e9224f0cc6521167ef388332604e92679"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d05ac6fa06959c4172eccd99a222e1fbf17b5670c4d596cb1e5cde99600674c4"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:040562757795eeea356394a7fb13076ad4f99d3c62ab0f8bdfb21f99a1f85664"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6113c008a7780792efc80f9dfe10ba0cd043cbf8dc9a76ef757850f51b4edc50"}, + {file = "regex-2024.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e5fb5f77c8745a60105403a774fe2c1759b71d3e7b4ca237a5e67ad066c7199"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:54d9ff35d4515debf14bc27f1e3b38bfc453eff3220f5bce159642fa762fe5d4"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:df5cbb1fbc74a8305b6065d4ade43b993be03dbe0f8b30032cced0d7740994bd"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7fb89ee5d106e4a7a51bce305ac4efb981536301895f7bdcf93ec92ae0d91c7f"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a738b937d512b30bf75995c0159c0ddf9eec0775c9d72ac0202076c72f24aa96"}, + {file = "regex-2024.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e28f9faeb14b6f23ac55bfbbfd3643f5c7c18ede093977f1df249f73fd22c7b1"}, + {file = "regex-2024.9.11-cp311-cp311-win32.whl", hash = "sha256:18e707ce6c92d7282dfce370cd205098384b8ee21544e7cb29b8aab955b66fa9"}, + {file = "regex-2024.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:313ea15e5ff2a8cbbad96ccef6be638393041b0a7863183c2d31e0c6116688cf"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b0d0a6c64fcc4ef9c69bd5b3b3626cc3776520a1637d8abaa62b9edc147a58f7"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:49b0e06786ea663f933f3710a51e9385ce0cba0ea56b67107fd841a55d56a231"}, + {file = "regex-2024.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5b513b6997a0b2f10e4fd3a1313568e373926e8c252bd76c960f96fd039cd28d"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee439691d8c23e76f9802c42a95cfeebf9d47cf4ffd06f18489122dbb0a7ad64"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a8f877c89719d759e52783f7fe6e1c67121076b87b40542966c02de5503ace42"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:23b30c62d0f16827f2ae9f2bb87619bc4fba2044911e2e6c2eb1af0161cdb766"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85ab7824093d8f10d44330fe1e6493f756f252d145323dd17ab6b48733ff6c0a"}, + {file = "regex-2024.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8dee5b4810a89447151999428fe096977346cf2f29f4d5e29609d2e19e0199c9"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98eeee2f2e63edae2181c886d7911ce502e1292794f4c5ee71e60e23e8d26b5d"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:57fdd2e0b2694ce6fc2e5ccf189789c3e2962916fb38779d3e3521ff8fe7a822"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d552c78411f60b1fdaafd117a1fca2f02e562e309223b9d44b7de8be451ec5e0"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a0b2b80321c2ed3fcf0385ec9e51a12253c50f146fddb2abbb10f033fe3d049a"}, + {file = "regex-2024.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:18406efb2f5a0e57e3a5881cd9354c1512d3bb4f5c45d96d110a66114d84d23a"}, + {file = "regex-2024.9.11-cp312-cp312-win32.whl", hash = "sha256:e464b467f1588e2c42d26814231edecbcfe77f5ac414d92cbf4e7b55b2c2a776"}, + {file = "regex-2024.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:9e8719792ca63c6b8340380352c24dcb8cd7ec49dae36e963742a275dfae6009"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c157bb447303070f256e084668b702073db99bbb61d44f85d811025fcf38f784"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4db21ece84dfeefc5d8a3863f101995de646c6cb0536952c321a2650aa202c36"}, + {file = "regex-2024.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:220e92a30b426daf23bb67a7962900ed4613589bab80382be09b48896d211e92"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb1ae19e64c14c7ec1995f40bd932448713d3c73509e82d8cd7744dc00e29e86"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f47cd43a5bfa48f86925fe26fbdd0a488ff15b62468abb5d2a1e092a4fb10e85"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d4a76b96f398697fe01117093613166e6aa8195d63f1b4ec3f21ab637632963"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ea51dcc0835eea2ea31d66456210a4e01a076d820e9039b04ae8d17ac11dee6"}, + {file = "regex-2024.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7aaa315101c6567a9a45d2839322c51c8d6e81f67683d529512f5bcfb99c802"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:c57d08ad67aba97af57a7263c2d9006d5c404d721c5f7542f077f109ec2a4a29"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8404bf61298bb6f8224bb9176c1424548ee1181130818fcd2cbffddc768bed8"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:dd4490a33eb909ef5078ab20f5f000087afa2a4daa27b4c072ccb3cb3050ad84"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:eee9130eaad130649fd73e5cd92f60e55708952260ede70da64de420cdcad554"}, + {file = "regex-2024.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6a2644a93da36c784e546de579ec1806bfd2763ef47babc1b03d765fe560c9f8"}, + {file = "regex-2024.9.11-cp313-cp313-win32.whl", hash = "sha256:e997fd30430c57138adc06bba4c7c2968fb13d101e57dd5bb9355bf8ce3fa7e8"}, + {file = "regex-2024.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:042c55879cfeb21a8adacc84ea347721d3d83a159da6acdf1116859e2427c43f"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:35f4a6f96aa6cb3f2f7247027b07b15a374f0d5b912c0001418d1d55024d5cb4"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:55b96e7ce3a69a8449a66984c268062fbaa0d8ae437b285428e12797baefce7e"}, + {file = "regex-2024.9.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb130fccd1a37ed894824b8c046321540263013da72745d755f2d35114b81a60"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:323c1f04be6b2968944d730e5c2091c8c89767903ecaa135203eec4565ed2b2b"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:be1c8ed48c4c4065ecb19d882a0ce1afe0745dfad8ce48c49586b90a55f02366"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b5b029322e6e7b94fff16cd120ab35a253236a5f99a79fb04fda7ae71ca20ae8"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6fff13ef6b5f29221d6904aa816c34701462956aa72a77f1f151a8ec4f56aeb"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:587d4af3979376652010e400accc30404e6c16b7df574048ab1f581af82065e4"}, + {file = "regex-2024.9.11-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:079400a8269544b955ffa9e31f186f01d96829110a3bf79dc338e9910f794fca"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:f9268774428ec173654985ce55fc6caf4c6d11ade0f6f914d48ef4719eb05ebb"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:23f9985c8784e544d53fc2930fc1ac1a7319f5d5332d228437acc9f418f2f168"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2941333154baff9838e88aa71c1d84f4438189ecc6021a12c7573728b5838e"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:e93f1c331ca8e86fe877a48ad64e77882c0c4da0097f2212873a69bbfea95d0c"}, + {file = "regex-2024.9.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:846bc79ee753acf93aef4184c040d709940c9d001029ceb7b7a52747b80ed2dd"}, + {file = "regex-2024.9.11-cp38-cp38-win32.whl", hash = "sha256:c94bb0a9f1db10a1d16c00880bdebd5f9faf267273b8f5bd1878126e0fbde771"}, + {file = "regex-2024.9.11-cp38-cp38-win_amd64.whl", hash = "sha256:2b08fce89fbd45664d3df6ad93e554b6c16933ffa9d55cb7e01182baaf971508"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_10_9_universal2.whl", hash = 
"sha256:07f45f287469039ffc2c53caf6803cd506eb5f5f637f1d4acb37a738f71dd066"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4838e24ee015101d9f901988001038f7f0d90dc0c3b115541a1365fb439add62"}, + {file = "regex-2024.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6edd623bae6a737f10ce853ea076f56f507fd7726bee96a41ee3d68d347e4d16"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c69ada171c2d0e97a4b5aa78fbb835e0ffbb6b13fc5da968c09811346564f0d3"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02087ea0a03b4af1ed6ebab2c54d7118127fee8d71b26398e8e4b05b78963199"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69dee6a020693d12a3cf892aba4808fe168d2a4cef368eb9bf74f5398bfd4ee8"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:297f54910247508e6e5cae669f2bc308985c60540a4edd1c77203ef19bfa63ca"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ecea58b43a67b1b79805f1a0255730edaf5191ecef84dbc4cc85eb30bc8b63b9"}, + {file = "regex-2024.9.11-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eab4bb380f15e189d1313195b062a6aa908f5bd687a0ceccd47c8211e9cf0d4a"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0cbff728659ce4bbf4c30b2a1be040faafaa9eca6ecde40aaff86f7889f4ab39"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:54c4a097b8bc5bb0dfc83ae498061d53ad7b5762e00f4adaa23bee22b012e6ba"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:73d6d2f64f4d894c96626a75578b0bf7d9e56dcda8c3d037a2118fdfe9b1c664"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:e53b5fbab5d675aec9f0c501274c467c0f9a5d23696cfc94247e1fb56501ed89"}, + {file = "regex-2024.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ffbcf9221e04502fc35e54d1ce9567541979c3fdfb93d2c554f0ca583a19b35"}, + {file = "regex-2024.9.11-cp39-cp39-win32.whl", hash = "sha256:e4c22e1ac1f1ec1e09f72e6c44d8f2244173db7eb9629cc3a346a8d7ccc31142"}, + {file = "regex-2024.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:faa3c142464efec496967359ca99696c896c591c56c53506bac1ad465f66e919"}, + {file = "regex-2024.9.11.tar.gz", hash = "sha256:6c188c307e8433bcb63dc1915022deb553b4203a70722fc542c363bf120a01fd"}, ] [[package]] @@ -7336,24 +8304,152 @@ files = [ [package.dependencies] requests = "2.31.0" +[[package]] +name = "retry" +version = "0.9.2" +description = "Easy to use retry decorator." 
+optional = false +python-versions = "*" +files = [ + {file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"}, + {file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"}, +] + +[package.dependencies] +decorator = ">=3.4.2" +py = ">=1.4.26,<2.0.0" + [[package]] name = "rich" -version = "13.7.1" +version = "13.9.3" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false -python-versions = ">=3.7.0" +python-versions = ">=3.8.0" files = [ - {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"}, - {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"}, + {file = "rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283"}, + {file = "rich-13.9.3.tar.gz", hash = "sha256:bc1e01b899537598cf02579d2b9f4a415104d3fc439313a7a2c165d76557a08e"}, ] [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rpds-py" +version = "0.20.0" +description = "Python bindings to Rust's persistent data structures (rpds)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "rpds_py-0.20.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3ad0fda1635f8439cde85c700f964b23ed5fc2d28016b32b9ee5fe30da5c84e2"}, + {file = "rpds_py-0.20.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9bb4a0d90fdb03437c109a17eade42dfbf6190408f29b2744114d11586611d6f"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6377e647bbfd0a0b159fe557f2c6c602c159fc752fa316572f012fc0bf67150"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb851b7df9dda52dc1415ebee12362047ce771fc36914586b2e9fcbd7d293b3e"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1e0f80b739e5a8f54837be5d5c924483996b603d5502bfff79bf33da06164ee2"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5a8c94dad2e45324fc74dce25e1645d4d14df9a4e54a30fa0ae8bad9a63928e3"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e604fe73ba048c06085beaf51147eaec7df856824bfe7b98657cf436623daf"}, + {file = "rpds_py-0.20.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:df3de6b7726b52966edf29663e57306b23ef775faf0ac01a3e9f4012a24a4140"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf258ede5bc22a45c8e726b29835b9303c285ab46fc7c3a4cc770736b5304c9f"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:55fea87029cded5df854ca7e192ec7bdb7ecd1d9a3f63d5c4eb09148acf4a7ce"}, + {file = "rpds_py-0.20.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae94bd0b2f02c28e199e9bc51485d0c5601f58780636185660f86bf80c89af94"}, + {file = "rpds_py-0.20.0-cp310-none-win32.whl", hash = "sha256:28527c685f237c05445efec62426d285e47a58fb05ba0090a4340b73ecda6dee"}, + {file = "rpds_py-0.20.0-cp310-none-win_amd64.whl", hash = "sha256:238a2d5b1cad28cdc6ed15faf93a998336eb041c4e440dd7f902528b8891b399"}, + {file = 
"rpds_py-0.20.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac2f4f7a98934c2ed6505aead07b979e6f999389f16b714448fb39bbaa86a489"}, + {file = "rpds_py-0.20.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:220002c1b846db9afd83371d08d239fdc865e8f8c5795bbaec20916a76db3318"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d7919548df3f25374a1f5d01fbcd38dacab338ef5f33e044744b5c36729c8db"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:758406267907b3781beee0f0edfe4a179fbd97c0be2e9b1154d7f0a1279cf8e5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3d61339e9f84a3f0767b1995adfb171a0d00a1185192718a17af6e124728e0f5"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1259c7b3705ac0a0bd38197565a5d603218591d3f6cee6e614e380b6ba61c6f6"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1dc0f53856b9cc9a0ccca0a7cc61d3d20a7088201c0937f3f4048c1718a209"}, + {file = "rpds_py-0.20.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7e60cb630f674a31f0368ed32b2a6b4331b8350d67de53c0359992444b116dd3"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dbe982f38565bb50cb7fb061ebf762c2f254ca3d8c20d4006878766e84266272"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:514b3293b64187172bc77c8fb0cdae26981618021053b30d8371c3a902d4d5ad"}, + {file = "rpds_py-0.20.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d0a26ffe9d4dd35e4dfdd1e71f46401cff0181c75ac174711ccff0459135fa58"}, + {file = "rpds_py-0.20.0-cp311-none-win32.whl", hash = "sha256:89c19a494bf3ad08c1da49445cc5d13d8fefc265f48ee7e7556839acdacf69d0"}, + {file = "rpds_py-0.20.0-cp311-none-win_amd64.whl", hash = "sha256:c638144ce971df84650d3ed0096e2ae7af8e62ecbbb7b201c8935c370df00a2c"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a84ab91cbe7aab97f7446652d0ed37d35b68a465aeef8fc41932a9d7eee2c1a6"}, + {file = "rpds_py-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:56e27147a5a4c2c21633ff8475d185734c0e4befd1c989b5b95a5d0db699b21b"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2580b0c34583b85efec8c5c5ec9edf2dfe817330cc882ee972ae650e7b5ef739"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b80d4a7900cf6b66bb9cee5c352b2d708e29e5a37fe9bf784fa97fc11504bf6c"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50eccbf054e62a7b2209b28dc7a22d6254860209d6753e6b78cfaeb0075d7bee"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:49a8063ea4296b3a7e81a5dfb8f7b2d73f0b1c20c2af401fb0cdf22e14711a96"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea438162a9fcbee3ecf36c23e6c68237479f89f962f82dae83dc15feeceb37e4"}, + {file = "rpds_py-0.20.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:18d7585c463087bddcfa74c2ba267339f14f2515158ac4db30b1f9cbdb62c8ef"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d4c7d1a051eeb39f5c9547e82ea27cbcc28338482242e3e0b7768033cb083821"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_i686.whl", hash = 
"sha256:e4df1e3b3bec320790f699890d41c59d250f6beda159ea3c44c3f5bac1976940"}, + {file = "rpds_py-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2cf126d33a91ee6eedc7f3197b53e87a2acdac63602c0f03a02dd69e4b138174"}, + {file = "rpds_py-0.20.0-cp312-none-win32.whl", hash = "sha256:8bc7690f7caee50b04a79bf017a8d020c1f48c2a1077ffe172abec59870f1139"}, + {file = "rpds_py-0.20.0-cp312-none-win_amd64.whl", hash = "sha256:0e13e6952ef264c40587d510ad676a988df19adea20444c2b295e536457bc585"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:aa9a0521aeca7d4941499a73ad7d4f8ffa3d1affc50b9ea11d992cd7eff18a29"}, + {file = "rpds_py-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4a1f1d51eccb7e6c32ae89243cb352389228ea62f89cd80823ea7dd1b98e0b91"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a86a9b96070674fc88b6f9f71a97d2c1d3e5165574615d1f9168ecba4cecb24"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c8ef2ebf76df43f5750b46851ed1cdf8f109d7787ca40035fe19fbdc1acc5a7"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b74b25f024b421d5859d156750ea9a65651793d51b76a2e9238c05c9d5f203a9"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57eb94a8c16ab08fef6404301c38318e2c5a32216bf5de453e2714c964c125c8"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1940dae14e715e2e02dfd5b0f64a52e8374a517a1e531ad9412319dc3ac7879"}, + {file = "rpds_py-0.20.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d20277fd62e1b992a50c43f13fbe13277a31f8c9f70d59759c88f644d66c619f"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:06db23d43f26478303e954c34c75182356ca9aa7797d22c5345b16871ab9c45c"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b2a5db5397d82fa847e4c624b0c98fe59d2d9b7cf0ce6de09e4d2e80f8f5b3f2"}, + {file = "rpds_py-0.20.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a35df9f5548fd79cb2f52d27182108c3e6641a4feb0f39067911bf2adaa3e57"}, + {file = "rpds_py-0.20.0-cp313-none-win32.whl", hash = "sha256:fd2d84f40633bc475ef2d5490b9c19543fbf18596dcb1b291e3a12ea5d722f7a"}, + {file = "rpds_py-0.20.0-cp313-none-win_amd64.whl", hash = "sha256:9bc2d153989e3216b0559251b0c260cfd168ec78b1fac33dd485750a228db5a2"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f2fbf7db2012d4876fb0d66b5b9ba6591197b0f165db8d99371d976546472a24"}, + {file = "rpds_py-0.20.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1e5f3cd7397c8f86c8cc72d5a791071431c108edd79872cdd96e00abd8497d29"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce9845054c13696f7af7f2b353e6b4f676dab1b4b215d7fe5e05c6f8bb06f965"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c3e130fd0ec56cb76eb49ef52faead8ff09d13f4527e9b0c400307ff72b408e1"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b16aa0107ecb512b568244ef461f27697164d9a68d8b35090e9b0c1c8b27752"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa7f429242aae2947246587d2964fad750b79e8c233a2367f71b554e9447949c"}, + {file = 
"rpds_py-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af0fc424a5842a11e28956e69395fbbeab2c97c42253169d87e90aac2886d751"}, + {file = "rpds_py-0.20.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b8c00a3b1e70c1d3891f0db1b05292747f0dbcfb49c43f9244d04c70fbc40eb8"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:40ce74fc86ee4645d0a225498d091d8bc61f39b709ebef8204cb8b5a464d3c0e"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4fe84294c7019456e56d93e8ababdad5a329cd25975be749c3f5f558abb48253"}, + {file = "rpds_py-0.20.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:338ca4539aad4ce70a656e5187a3a31c5204f261aef9f6ab50e50bcdffaf050a"}, + {file = "rpds_py-0.20.0-cp38-none-win32.whl", hash = "sha256:54b43a2b07db18314669092bb2de584524d1ef414588780261e31e85846c26a5"}, + {file = "rpds_py-0.20.0-cp38-none-win_amd64.whl", hash = "sha256:a1862d2d7ce1674cffa6d186d53ca95c6e17ed2b06b3f4c476173565c862d232"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3fde368e9140312b6e8b6c09fb9f8c8c2f00999d1823403ae90cc00480221b22"}, + {file = "rpds_py-0.20.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9824fb430c9cf9af743cf7aaf6707bf14323fb51ee74425c380f4c846ea70789"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11ef6ce74616342888b69878d45e9f779b95d4bd48b382a229fe624a409b72c5"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c52d3f2f82b763a24ef52f5d24358553e8403ce05f893b5347098014f2d9eff2"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d35cef91e59ebbeaa45214861874bc6f19eb35de96db73e467a8358d701a96c"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d72278a30111e5b5525c1dd96120d9e958464316f55adb030433ea905866f4de"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c29cbbba378759ac5786730d1c3cb4ec6f8ababf5c42a9ce303dc4b3d08cda"}, + {file = "rpds_py-0.20.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6632f2d04f15d1bd6fe0eedd3b86d9061b836ddca4c03d5cf5c7e9e6b7c14580"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:d0b67d87bb45ed1cd020e8fbf2307d449b68abc45402fe1a4ac9e46c3c8b192b"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:ec31a99ca63bf3cd7f1a5ac9fe95c5e2d060d3c768a09bc1d16e235840861420"}, + {file = "rpds_py-0.20.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:22e6c9976e38f4d8c4a63bd8a8edac5307dffd3ee7e6026d97f3cc3a2dc02a0b"}, + {file = "rpds_py-0.20.0-cp39-none-win32.whl", hash = "sha256:569b3ea770c2717b730b61998b6c54996adee3cef69fc28d444f3e7920313cf7"}, + {file = "rpds_py-0.20.0-cp39-none-win_amd64.whl", hash = "sha256:e6900ecdd50ce0facf703f7a00df12374b74bbc8ad9fe0f6559947fb20f82364"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:617c7357272c67696fd052811e352ac54ed1d9b49ab370261a80d3b6ce385045"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9426133526f69fcaba6e42146b4e12d6bc6c839b8b555097020e2b78ce908dcc"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:deb62214c42a261cb3eb04d474f7155279c1a8a8c30ac89b7dcb1721d92c3c02"}, + {file = 
"rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcaeb7b57f1a1e071ebd748984359fef83ecb026325b9d4ca847c95bc7311c92"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d454b8749b4bd70dd0a79f428731ee263fa6995f83ccb8bada706e8d1d3ff89d"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d807dc2051abe041b6649681dce568f8e10668e3c1c6543ebae58f2d7e617855"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c20f0ddeb6e29126d45f89206b8291352b8c5b44384e78a6499d68b52ae511"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b7f19250ceef892adf27f0399b9e5afad019288e9be756d6919cb58892129f51"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:4f1ed4749a08379555cebf4650453f14452eaa9c43d0a95c49db50c18b7da075"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = "sha256:dcedf0b42bcb4cfff4101d7771a10532415a6106062f005ab97d1d0ab5681c60"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:39ed0d010457a78f54090fafb5d108501b5aa5604cc22408fc1c0c77eac14344"}, + {file = "rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f918a1a130a6dfe1d7fe0f105064141342e7dd1611f2e6a21cd2f5c8cb1cfb3e"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f60012a73aa396be721558caa3a6fd49b3dd0033d1675c6d59c4502e870fcf0c"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d2b1ad682a3dfda2a4e8ad8572f3100f95fad98cb99faf37ff0ddfe9cbf9d03"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:614fdafe9f5f19c63ea02817fa4861c606a59a604a77c8cdef5aa01d28b97921"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fa518bcd7600c584bf42e6617ee8132869e877db2f76bcdc281ec6a4113a53ab"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f0475242f447cc6cb8a9dd486d68b2ef7fbee84427124c232bff5f63b1fe11e5"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f90a4cd061914a60bd51c68bcb4357086991bd0bb93d8aa66a6da7701370708f"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:def7400461c3a3f26e49078302e1c1b38f6752342c77e3cf72ce91ca69fb1bc1"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:65794e4048ee837494aea3c21a28ad5fc080994dfba5b036cf84de37f7ad5074"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_i686.whl", hash = "sha256:faefcc78f53a88f3076b7f8be0a8f8d35133a3ecf7f3770895c25f8813460f08"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:5b4f105deeffa28bbcdff6c49b34e74903139afa690e35d2d9e3c2c2fba18cec"}, + {file = "rpds_py-0.20.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdfc3a892927458d98f3d55428ae46b921d1f7543b89382fdb483f5640daaec8"}, + {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, +] + 
[[package]] name = "rsa" version = "4.9" @@ -7370,40 +8466,40 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.5.7" +version = "0.6.9" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a"}, - {file = "ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be"}, - {file = "ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb"}, - {file = "ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5"}, - {file = "ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e"}, - {file = "ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a"}, - {file = "ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3"}, - {file = "ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4"}, - {file = "ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5"}, + {file = "ruff-0.6.9-py3-none-linux_armv6l.whl", hash = "sha256:064df58d84ccc0ac0fcd63bc3090b251d90e2a372558c0f057c3f75ed73e1ccd"}, + {file = "ruff-0.6.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:140d4b5c9f5fc7a7b074908a78ab8d384dd7f6510402267bc76c37195c02a7ec"}, + {file = "ruff-0.6.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53fd8ca5e82bdee8da7f506d7b03a261f24cd43d090ea9db9a1dc59d9313914c"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645d7d8761f915e48a00d4ecc3686969761df69fb561dd914a773c1a8266e14e"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", 
hash = "sha256:eae02b700763e3847595b9d2891488989cac00214da7f845f4bcf2989007d577"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d5ccc9e58112441de8ad4b29dcb7a86dc25c5f770e3c06a9d57e0e5eba48829"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:417b81aa1c9b60b2f8edc463c58363075412866ae4e2b9ab0f690dc1e87ac1b5"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c866b631f5fbce896a74a6e4383407ba7507b815ccc52bcedabb6810fdb3ef7"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b118afbb3202f5911486ad52da86d1d52305b59e7ef2031cea3425142b97d6f"}, + {file = "ruff-0.6.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67267654edc23c97335586774790cde402fb6bbdb3c2314f1fc087dee320bfa"}, + {file = "ruff-0.6.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3ef0cc774b00fec123f635ce5c547dac263f6ee9fb9cc83437c5904183b55ceb"}, + {file = "ruff-0.6.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:12edd2af0c60fa61ff31cefb90aef4288ac4d372b4962c2864aeea3a1a2460c0"}, + {file = "ruff-0.6.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:55bb01caeaf3a60b2b2bba07308a02fca6ab56233302406ed5245180a05c5625"}, + {file = "ruff-0.6.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:925d26471fa24b0ce5a6cdfab1bb526fb4159952385f386bdcc643813d472039"}, + {file = "ruff-0.6.9-py3-none-win32.whl", hash = "sha256:eb61ec9bdb2506cffd492e05ac40e5bc6284873aceb605503d8494180d6fc84d"}, + {file = "ruff-0.6.9-py3-none-win_amd64.whl", hash = "sha256:785d31851c1ae91f45b3d8fe23b8ae4b5170089021fbb42402d811135f0b7117"}, + {file = "ruff-0.6.9-py3-none-win_arm64.whl", hash = "sha256:a9641e31476d601f83cd602608739a0840e348bda93fec9f1ee816f8b6798b93"}, + {file = "ruff-0.6.9.tar.gz", hash = "sha256:b076ef717a8e5bc819514ee1d602bbdca5b4420ae13a9cf61a0c0a4f53a2baa2"}, ] [[package]] name = "s3transfer" -version = "0.10.2" +version = "0.10.3" description = "An Amazon S3 Transfer Manager" optional = false python-versions = ">=3.8" files = [ - {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"}, - {file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"}, + {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"}, + {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"}, ] [package.dependencies] @@ -7414,121 +8510,121 @@ crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] [[package]] name = "safetensors" -version = "0.4.4" +version = "0.4.5" description = "" optional = false python-versions = ">=3.7" files = [ - {file = "safetensors-0.4.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2adb497ada13097f30e386e88c959c0fda855a5f6f98845710f5bb2c57e14f12"}, - {file = "safetensors-0.4.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7db7fdc2d71fd1444d85ca3f3d682ba2df7d61a637dfc6d80793f439eae264ab"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d4f0eed76b430f009fbefca1a0028ddb112891b03cb556d7440d5cd68eb89a9"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:57d216fab0b5c432aabf7170883d7c11671622bde8bd1436c46d633163a703f6"}, - {file = 
"safetensors-0.4.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d9b76322e49c056bcc819f8bdca37a2daa5a6d42c07f30927b501088db03309"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32f0d1f6243e90ee43bc6ee3e8c30ac5b09ca63f5dd35dbc985a1fc5208c451a"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44d464bdc384874601a177375028012a5f177f1505279f9456fea84bbc575c7f"}, - {file = "safetensors-0.4.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:63144e36209ad8e4e65384dbf2d52dd5b1866986079c00a72335402a38aacdc5"}, - {file = "safetensors-0.4.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:051d5ecd490af7245258000304b812825974d5e56f14a3ff7e1b8b2ba6dc2ed4"}, - {file = "safetensors-0.4.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:51bc8429d9376224cd3cf7e8ce4f208b4c930cd10e515b6ac6a72cbc3370f0d9"}, - {file = "safetensors-0.4.4-cp310-none-win32.whl", hash = "sha256:fb7b54830cee8cf9923d969e2df87ce20e625b1af2fd194222ab902d3adcc29c"}, - {file = "safetensors-0.4.4-cp310-none-win_amd64.whl", hash = "sha256:4b3e8aa8226d6560de8c2b9d5ff8555ea482599c670610758afdc97f3e021e9c"}, - {file = "safetensors-0.4.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:bbaa31f2cb49013818bde319232ccd72da62ee40f7d2aa532083eda5664e85ff"}, - {file = "safetensors-0.4.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9fdcb80f4e9fbb33b58e9bf95e7dbbedff505d1bcd1c05f7c7ce883632710006"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55c14c20be247b8a1aeaf3ab4476265e3ca83096bb8e09bb1a7aa806088def4f"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:949aaa1118660f992dbf0968487b3e3cfdad67f948658ab08c6b5762e90cc8b6"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c11a4ab7debc456326a2bac67f35ee0ac792bcf812c7562a4a28559a5c795e27"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c0cea44bba5c5601b297bc8307e4075535b95163402e4906b2e9b82788a2a6df"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9d752c97f6bbe327352f76e5b86442d776abc789249fc5e72eacb49e6916482"}, - {file = "safetensors-0.4.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:03f2bb92e61b055ef6cc22883ad1ae898010a95730fa988c60a23800eb742c2c"}, - {file = "safetensors-0.4.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:87bf3f91a9328a941acc44eceffd4e1f5f89b030985b2966637e582157173b98"}, - {file = "safetensors-0.4.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:20d218ec2b6899d29d6895419a58b6e44cc5ff8f0cc29fac8d236a8978ab702e"}, - {file = "safetensors-0.4.4-cp311-none-win32.whl", hash = "sha256:8079486118919f600c603536e2490ca37b3dbd3280e3ad6eaacfe6264605ac8a"}, - {file = "safetensors-0.4.4-cp311-none-win_amd64.whl", hash = "sha256:2f8c2eb0615e2e64ee27d478c7c13f51e5329d7972d9e15528d3e4cfc4a08f0d"}, - {file = "safetensors-0.4.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:baec5675944b4a47749c93c01c73d826ef7d42d36ba8d0dba36336fa80c76426"}, - {file = "safetensors-0.4.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f15117b96866401825f3e94543145028a2947d19974429246ce59403f49e77c6"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:6a13a9caea485df164c51be4eb0c87f97f790b7c3213d635eba2314d959fe929"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b54bc4ca5f9b9bba8cd4fb91c24b2446a86b5ae7f8975cf3b7a277353c3127c"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08332c22e03b651c8eb7bf5fc2de90044f3672f43403b3d9ac7e7e0f4f76495e"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bb62841e839ee992c37bb75e75891c7f4904e772db3691c59daaca5b4ab960e1"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e5b927acc5f2f59547270b0309a46d983edc44be64e1ca27a7fcb0474d6cd67"}, - {file = "safetensors-0.4.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2a69c71b1ae98a8021a09a0b43363b0143b0ce74e7c0e83cacba691b62655fb8"}, - {file = "safetensors-0.4.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23654ad162c02a5636f0cd520a0310902c4421aab1d91a0b667722a4937cc445"}, - {file = "safetensors-0.4.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0677c109d949cf53756859160b955b2e75b0eefe952189c184d7be30ecf7e858"}, - {file = "safetensors-0.4.4-cp312-none-win32.whl", hash = "sha256:a51d0ddd4deb8871c6de15a772ef40b3dbd26a3c0451bb9e66bc76fc5a784e5b"}, - {file = "safetensors-0.4.4-cp312-none-win_amd64.whl", hash = "sha256:2d065059e75a798bc1933c293b68d04d79b586bb7f8c921e0ca1e82759d0dbb1"}, - {file = "safetensors-0.4.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:9d625692578dd40a112df30c02a1adf068027566abd8e6a74893bb13d441c150"}, - {file = "safetensors-0.4.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7cabcf39c81e5b988d0adefdaea2eb9b4fd9bd62d5ed6559988c62f36bfa9a89"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8359bef65f49d51476e9811d59c015f0ddae618ee0e44144f5595278c9f8268c"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1a32c662e7df9226fd850f054a3ead0e4213a96a70b5ce37b2d26ba27004e013"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c329a4dcc395364a1c0d2d1574d725fe81a840783dda64c31c5a60fc7d41472c"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:239ee093b1db877c9f8fe2d71331a97f3b9c7c0d3ab9f09c4851004a11f44b65"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd574145d930cf9405a64f9923600879a5ce51d9f315443a5f706374841327b6"}, - {file = "safetensors-0.4.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f6784eed29f9e036acb0b7769d9e78a0dc2c72c2d8ba7903005350d817e287a4"}, - {file = "safetensors-0.4.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:65a4a6072436bf0a4825b1c295d248cc17e5f4651e60ee62427a5bcaa8622a7a"}, - {file = "safetensors-0.4.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:df81e3407630de060ae8313da49509c3caa33b1a9415562284eaf3d0c7705f9f"}, - {file = "safetensors-0.4.4-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:e4a0f374200e8443d9746e947ebb346c40f83a3970e75a685ade0adbba5c48d9"}, - {file = "safetensors-0.4.4-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:181fb5f3dee78dae7fd7ec57d02e58f7936498d587c6b7c1c8049ef448c8d285"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:2cb4ac1d8f6b65ec84ddfacd275079e89d9df7c92f95675ba96c4f790a64df6e"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:76897944cd9239e8a70955679b531b9a0619f76e25476e57ed373322d9c2075d"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2a9e9d1a27e51a0f69e761a3d581c3af46729ec1c988fa1f839e04743026ae35"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:005ef9fc0f47cb9821c40793eb029f712e97278dae84de91cb2b4809b856685d"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26987dac3752688c696c77c3576f951dbbdb8c57f0957a41fb6f933cf84c0b62"}, - {file = "safetensors-0.4.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c05270b290acd8d249739f40d272a64dd597d5a4b90f27d830e538bc2549303c"}, - {file = "safetensors-0.4.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:068d3a33711fc4d93659c825a04480ff5a3854e1d78632cdc8f37fee917e8a60"}, - {file = "safetensors-0.4.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:063421ef08ca1021feea8b46951251b90ae91f899234dd78297cbe7c1db73b99"}, - {file = "safetensors-0.4.4-cp37-none-win32.whl", hash = "sha256:d52f5d0615ea83fd853d4e1d8acf93cc2e0223ad4568ba1e1f6ca72e94ea7b9d"}, - {file = "safetensors-0.4.4-cp37-none-win_amd64.whl", hash = "sha256:88a5ac3280232d4ed8e994cbc03b46a1807ce0aa123867b40c4a41f226c61f94"}, - {file = "safetensors-0.4.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:3467ab511bfe3360967d7dc53b49f272d59309e57a067dd2405b4d35e7dcf9dc"}, - {file = "safetensors-0.4.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2ab4c96d922e53670ce25fbb9b63d5ea972e244de4fa1dd97b590d9fd66aacef"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87df18fce4440477c3ef1fd7ae17c704a69a74a77e705a12be135ee0651a0c2d"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0e5fe345b2bc7d88587149ac11def1f629d2671c4c34f5df38aed0ba59dc37f8"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9f1a3e01dce3cd54060791e7e24588417c98b941baa5974700eeb0b8eb65b0a0"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c6bf35e9a8998d8339fd9a05ac4ce465a4d2a2956cc0d837b67c4642ed9e947"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:166c0c52f6488b8538b2a9f3fbc6aad61a7261e170698779b371e81b45f0440d"}, - {file = "safetensors-0.4.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:87e9903b8668a16ef02c08ba4ebc91e57a49c481e9b5866e31d798632805014b"}, - {file = "safetensors-0.4.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a9c421153aa23c323bd8483d4155b4eee82c9a50ac11cccd83539104a8279c64"}, - {file = "safetensors-0.4.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a4b8617499b2371c7353302c5116a7e0a3a12da66389ce53140e607d3bf7b3d3"}, - {file = "safetensors-0.4.4-cp38-none-win32.whl", hash = "sha256:c6280f5aeafa1731f0a3709463ab33d8e0624321593951aefada5472f0b313fd"}, - {file = "safetensors-0.4.4-cp38-none-win_amd64.whl", hash = "sha256:6ceed6247fc2d33b2a7b7d25d8a0fe645b68798856e0bc7a9800c5fd945eb80f"}, - {file = "safetensors-0.4.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:5cf6c6f6193797372adf50c91d0171743d16299491c75acad8650107dffa9269"}, - {file = 
"safetensors-0.4.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:419010156b914a3e5da4e4adf992bee050924d0fe423c4b329e523e2c14c3547"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88f6fd5a5c1302ce79993cc5feeadcc795a70f953c762544d01fb02b2db4ea33"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d468cffb82d90789696d5b4d8b6ab8843052cba58a15296691a7a3df55143cd2"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9353c2af2dd467333d4850a16edb66855e795561cd170685178f706c80d2c71e"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:83c155b4a33368d9b9c2543e78f2452090fb030c52401ca608ef16fa58c98353"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9850754c434e636ce3dc586f534bb23bcbd78940c304775bee9005bf610e98f1"}, - {file = "safetensors-0.4.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:275f500b4d26f67b6ec05629a4600645231bd75e4ed42087a7c1801bff04f4b3"}, - {file = "safetensors-0.4.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:5c2308de665b7130cd0e40a2329278226e4cf083f7400c51ca7e19ccfb3886f3"}, - {file = "safetensors-0.4.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e06a9ebc8656e030ccfe44634f2a541b4b1801cd52e390a53ad8bacbd65f8518"}, - {file = "safetensors-0.4.4-cp39-none-win32.whl", hash = "sha256:ef73df487b7c14b477016947c92708c2d929e1dee2bacdd6fff5a82ed4539537"}, - {file = "safetensors-0.4.4-cp39-none-win_amd64.whl", hash = "sha256:83d054818a8d1198d8bd8bc3ea2aac112a2c19def2bf73758321976788706398"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:1d1f34c71371f0e034004a0b583284b45d233dd0b5f64a9125e16b8a01d15067"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1a8043a33d58bc9b30dfac90f75712134ca34733ec3d8267b1bd682afe7194f5"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8db8f0c59c84792c12661f8efa85de160f80efe16b87a9d5de91b93f9e0bce3c"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cfc1fc38e37630dd12d519bdec9dcd4b345aec9930bb9ce0ed04461f49e58b52"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e5c9d86d9b13b18aafa88303e2cd21e677f5da2a14c828d2c460fe513af2e9a5"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:43251d7f29a59120a26f5a0d9583b9e112999e500afabcfdcb91606d3c5c89e3"}, - {file = "safetensors-0.4.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:2c42e9b277513b81cf507e6121c7b432b3235f980cac04f39f435b7902857f91"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3daacc9a4e3f428a84dd56bf31f20b768eb0b204af891ed68e1f06db9edf546f"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:218bbb9b883596715fc9997bb42470bf9f21bb832c3b34c2bf744d6fa8f2bbba"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bd5efc26b39f7fc82d4ab1d86a7f0644c8e34f3699c33f85bfa9a717a030e1b"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:56ad9776b65d8743f86698a1973292c966cf3abff627efc44ed60e66cc538ddd"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:30f23e6253c5f43a809dea02dc28a9f5fa747735dc819f10c073fe1b605e97d4"}, - {file = "safetensors-0.4.4-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:5512078d00263de6cb04e9d26c9ae17611098f52357fea856213e38dc462f81f"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b96c3d9266439d17f35fc2173111d93afc1162f168e95aed122c1ca517b1f8f1"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:08d464aa72a9a13826946b4fb9094bb4b16554bbea2e069e20bd903289b6ced9"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:210160816d5a36cf41f48f38473b6f70d7bcb4b0527bedf0889cc0b4c3bb07db"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb276a53717f2bcfb6df0bcf284d8a12069002508d4c1ca715799226024ccd45"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a2c28c6487f17d8db0089e8b2cdc13de859366b94cc6cdc50e1b0a4147b56551"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7915f0c60e4e6e65d90f136d85dd3b429ae9191c36b380e626064694563dbd9f"}, - {file = "safetensors-0.4.4-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:00eea99ae422fbfa0b46065acbc58b46bfafadfcec179d4b4a32d5c45006af6c"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:bb1ed4fcb0b3c2f3ea2c5767434622fe5d660e5752f21ac2e8d737b1e5e480bb"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:73fc9a0a4343188bdb421783e600bfaf81d0793cd4cce6bafb3c2ed567a74cd5"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c37e6b714200824c73ca6eaf007382de76f39466a46e97558b8dc4cf643cfbf"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f75698c5c5c542417ac4956acfc420f7d4a2396adca63a015fd66641ea751759"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca1a209157f242eb183e209040097118472e169f2e069bfbd40c303e24866543"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:177f2b60a058f92a3cec7a1786c9106c29eca8987ecdfb79ee88126e5f47fa31"}, - {file = "safetensors-0.4.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ee9622e84fe6e4cd4f020e5fda70d6206feff3157731df7151d457fdae18e541"}, - {file = "safetensors-0.4.4.tar.gz", hash = "sha256:5fe3e9b705250d0172ed4e100a811543108653fb2b66b9e702a088ad03772a07"}, + {file = "safetensors-0.4.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a63eaccd22243c67e4f2b1c3e258b257effc4acd78f3b9d397edc8cf8f1298a7"}, + {file = "safetensors-0.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:23fc9b4ec7b602915cbb4ec1a7c1ad96d2743c322f20ab709e2c35d1b66dad27"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6885016f34bef80ea1085b7e99b3c1f92cb1be78a49839203060f67b40aee761"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:133620f443450429322f238fda74d512c4008621227fccf2f8cf4a76206fea7c"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:4fb3e0609ec12d2a77e882f07cced530b8262027f64b75d399f1504ffec0ba56"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0f1dd769f064adc33831f5e97ad07babbd728427f98e3e1db6902e369122737"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6d156bdb26732feada84f9388a9f135528c1ef5b05fae153da365ad4319c4c5"}, + {file = "safetensors-0.4.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9e347d77e2c77eb7624400ccd09bed69d35c0332f417ce8c048d404a096c593b"}, + {file = "safetensors-0.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9f556eea3aec1d3d955403159fe2123ddd68e880f83954ee9b4a3f2e15e716b6"}, + {file = "safetensors-0.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9483f42be3b6bc8ff77dd67302de8ae411c4db39f7224dec66b0eb95822e4163"}, + {file = "safetensors-0.4.5-cp310-none-win32.whl", hash = "sha256:7389129c03fadd1ccc37fd1ebbc773f2b031483b04700923c3511d2a939252cc"}, + {file = "safetensors-0.4.5-cp310-none-win_amd64.whl", hash = "sha256:e98ef5524f8b6620c8cdef97220c0b6a5c1cef69852fcd2f174bb96c2bb316b1"}, + {file = "safetensors-0.4.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:21f848d7aebd5954f92538552d6d75f7c1b4500f51664078b5b49720d180e47c"}, + {file = "safetensors-0.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb07000b19d41e35eecef9a454f31a8b4718a185293f0d0b1c4b61d6e4487971"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09dedf7c2fda934ee68143202acff6e9e8eb0ddeeb4cfc24182bef999efa9f42"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59b77e4b7a708988d84f26de3ebead61ef1659c73dcbc9946c18f3b1786d2688"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d3bc83e14d67adc2e9387e511097f254bd1b43c3020440e708858c684cbac68"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39371fc551c1072976073ab258c3119395294cf49cdc1f8476794627de3130df"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6c19feda32b931cae0acd42748a670bdf56bee6476a046af20181ad3fee4090"}, + {file = "safetensors-0.4.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a659467495de201e2f282063808a41170448c78bada1e62707b07a27b05e6943"}, + {file = "safetensors-0.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bad5e4b2476949bcd638a89f71b6916fa9a5cae5c1ae7eede337aca2100435c0"}, + {file = "safetensors-0.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a3a315a6d0054bc6889a17f5668a73f94f7fe55121ff59e0a199e3519c08565f"}, + {file = "safetensors-0.4.5-cp311-none-win32.whl", hash = "sha256:a01e232e6d3d5cf8b1667bc3b657a77bdab73f0743c26c1d3c5dd7ce86bd3a92"}, + {file = "safetensors-0.4.5-cp311-none-win_amd64.whl", hash = "sha256:cbd39cae1ad3e3ef6f63a6f07296b080c951f24cec60188378e43d3713000c04"}, + {file = "safetensors-0.4.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:473300314e026bd1043cef391bb16a8689453363381561b8a3e443870937cc1e"}, + {file = "safetensors-0.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:801183a0f76dc647f51a2d9141ad341f9665602a7899a693207a82fb102cc53e"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1524b54246e422ad6fb6aea1ac71edeeb77666efa67230e1faf6999df9b2e27f"}, + {file = 
"safetensors-0.4.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b3139098e3e8b2ad7afbca96d30ad29157b50c90861084e69fcb80dec7430461"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65573dc35be9059770808e276b017256fa30058802c29e1038eb1c00028502ea"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd33da8e9407559f8779c82a0448e2133737f922d71f884da27184549416bfed"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3685ce7ed036f916316b567152482b7e959dc754fcc4a8342333d222e05f407c"}, + {file = "safetensors-0.4.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dde2bf390d25f67908278d6f5d59e46211ef98e44108727084d4637ee70ab4f1"}, + {file = "safetensors-0.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7469d70d3de970b1698d47c11ebbf296a308702cbaae7fcb993944751cf985f4"}, + {file = "safetensors-0.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a6ba28118636a130ccbb968bc33d4684c48678695dba2590169d5ab03a45646"}, + {file = "safetensors-0.4.5-cp312-none-win32.whl", hash = "sha256:c859c7ed90b0047f58ee27751c8e56951452ed36a67afee1b0a87847d065eec6"}, + {file = "safetensors-0.4.5-cp312-none-win_amd64.whl", hash = "sha256:b5a8810ad6a6f933fff6c276eae92c1da217b39b4d8b1bc1c0b8af2d270dc532"}, + {file = "safetensors-0.4.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:25e5f8e2e92a74f05b4ca55686234c32aac19927903792b30ee6d7bd5653d54e"}, + {file = "safetensors-0.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:81efb124b58af39fcd684254c645e35692fea81c51627259cdf6d67ff4458916"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:585f1703a518b437f5103aa9cf70e9bd437cb78eea9c51024329e4fb8a3e3679"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b99fbf72e3faf0b2f5f16e5e3458b93b7d0a83984fe8d5364c60aa169f2da89"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b17b299ca9966ca983ecda1c0791a3f07f9ca6ab5ded8ef3d283fff45f6bcd5f"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76ded72f69209c9780fdb23ea89e56d35c54ae6abcdec67ccb22af8e696e449a"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2783956926303dcfeb1de91a4d1204cd4089ab441e622e7caee0642281109db3"}, + {file = "safetensors-0.4.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d94581aab8c6b204def4d7320f07534d6ee34cd4855688004a4354e63b639a35"}, + {file = "safetensors-0.4.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:67e1e7cb8678bb1b37ac48ec0df04faf689e2f4e9e81e566b5c63d9f23748523"}, + {file = "safetensors-0.4.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:dbd280b07e6054ea68b0cb4b16ad9703e7d63cd6890f577cb98acc5354780142"}, + {file = "safetensors-0.4.5-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:77d9b228da8374c7262046a36c1f656ba32a93df6cc51cd4453af932011e77f1"}, + {file = "safetensors-0.4.5-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:500cac01d50b301ab7bb192353317035011c5ceeef0fca652f9f43c000bb7f8d"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:75331c0c746f03158ded32465b7d0b0e24c5a22121743662a2393439c43a45cf"}, + {file = 
"safetensors-0.4.5-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670e95fe34e0d591d0529e5e59fd9d3d72bc77b1444fcaa14dccda4f36b5a38b"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:098923e2574ff237c517d6e840acada8e5b311cb1fa226019105ed82e9c3b62f"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13ca0902d2648775089fa6a0c8fc9e6390c5f8ee576517d33f9261656f851e3f"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f0032bedc869c56f8d26259fe39cd21c5199cd57f2228d817a0e23e8370af25"}, + {file = "safetensors-0.4.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4b15f51b4f8f2a512341d9ce3475cacc19c5fdfc5db1f0e19449e75f95c7dc8"}, + {file = "safetensors-0.4.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f6594d130d0ad933d885c6a7b75c5183cb0e8450f799b80a39eae2b8508955eb"}, + {file = "safetensors-0.4.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:60c828a27e852ded2c85fc0f87bf1ec20e464c5cd4d56ff0e0711855cc2e17f8"}, + {file = "safetensors-0.4.5-cp37-none-win32.whl", hash = "sha256:6d3de65718b86c3eeaa8b73a9c3d123f9307a96bbd7be9698e21e76a56443af5"}, + {file = "safetensors-0.4.5-cp37-none-win_amd64.whl", hash = "sha256:5a2d68a523a4cefd791156a4174189a4114cf0bf9c50ceb89f261600f3b2b81a"}, + {file = "safetensors-0.4.5-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:e7a97058f96340850da0601a3309f3d29d6191b0702b2da201e54c6e3e44ccf0"}, + {file = "safetensors-0.4.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:63bfd425e25f5c733f572e2246e08a1c38bd6f2e027d3f7c87e2e43f228d1345"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3664ac565d0e809b0b929dae7ccd74e4d3273cd0c6d1220c6430035befb678e"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:313514b0b9b73ff4ddfb4edd71860696dbe3c1c9dc4d5cc13dbd74da283d2cbf"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31fa33ee326f750a2f2134a6174773c281d9a266ccd000bd4686d8021f1f3dac"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:09566792588d77b68abe53754c9f1308fadd35c9f87be939e22c623eaacbed6b"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:309aaec9b66cbf07ad3a2e5cb8a03205663324fea024ba391594423d0f00d9fe"}, + {file = "safetensors-0.4.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:53946c5813b8f9e26103c5efff4a931cc45d874f45229edd68557ffb35ffb9f8"}, + {file = "safetensors-0.4.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:868f9df9e99ad1e7f38c52194063a982bc88fedc7d05096f4f8160403aaf4bd6"}, + {file = "safetensors-0.4.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9cc9449bd0b0bc538bd5e268221f0c5590bc5c14c1934a6ae359d44410dc68c4"}, + {file = "safetensors-0.4.5-cp38-none-win32.whl", hash = "sha256:83c4f13a9e687335c3928f615cd63a37e3f8ef072a3f2a0599fa09f863fb06a2"}, + {file = "safetensors-0.4.5-cp38-none-win_amd64.whl", hash = "sha256:b98d40a2ffa560653f6274e15b27b3544e8e3713a44627ce268f419f35c49478"}, + {file = "safetensors-0.4.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cf727bb1281d66699bef5683b04d98c894a2803442c490a8d45cd365abfbdeb2"}, + {file = "safetensors-0.4.5-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:96f1d038c827cdc552d97e71f522e1049fef0542be575421f7684756a748e457"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:139fbee92570ecea774e6344fee908907db79646d00b12c535f66bc78bd5ea2c"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c36302c1c69eebb383775a89645a32b9d266878fab619819ce660309d6176c9b"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d641f5b8149ea98deb5ffcf604d764aad1de38a8285f86771ce1abf8e74c4891"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b4db6a61d968de73722b858038c616a1bebd4a86abe2688e46ca0cc2d17558f2"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b75a616e02f21b6f1d5785b20cecbab5e2bd3f6358a90e8925b813d557666ec1"}, + {file = "safetensors-0.4.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:788ee7d04cc0e0e7f944c52ff05f52a4415b312f5efd2ee66389fb7685ee030c"}, + {file = "safetensors-0.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:87bc42bd04fd9ca31396d3ca0433db0be1411b6b53ac5a32b7845a85d01ffc2e"}, + {file = "safetensors-0.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4037676c86365a721a8c9510323a51861d703b399b78a6b4486a54a65a975fca"}, + {file = "safetensors-0.4.5-cp39-none-win32.whl", hash = "sha256:1500418454529d0ed5c1564bda376c4ddff43f30fce9517d9bee7bcce5a8ef50"}, + {file = "safetensors-0.4.5-cp39-none-win_amd64.whl", hash = "sha256:9d1a94b9d793ed8fe35ab6d5cea28d540a46559bafc6aae98f30ee0867000cab"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fdadf66b5a22ceb645d5435a0be7a0292ce59648ca1d46b352f13cff3ea80410"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d42ffd4c2259f31832cb17ff866c111684c87bd930892a1ba53fed28370c918c"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd8a1f6d2063a92cd04145c7fd9e31a1c7d85fbec20113a14b487563fdbc0597"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:951d2fcf1817f4fb0ef0b48f6696688a4e852a95922a042b3f96aaa67eedc920"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ac85d9a8c1af0e3132371d9f2d134695a06a96993c2e2f0bbe25debb9e3f67a"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e3cec4a29eb7fe8da0b1c7988bc3828183080439dd559f720414450de076fcab"}, + {file = "safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c7db3006a4915151ce1913652e907cdede299b974641a83fbc092102ac41b644"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f68bf99ea970960a237f416ea394e266e0361895753df06e3e06e6ea7907d98b"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8158938cf3324172df024da511839d373c40fbfaa83e9abf467174b2910d7b4c"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:540ce6c4bf6b58cb0fd93fa5f143bc0ee341c93bb4f9287ccd92cf898cc1b0dd"}, + {file = 
"safetensors-0.4.5-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bfeaa1a699c6b9ed514bd15e6a91e74738b71125a9292159e3d6b7f0a53d2cde"}, + {file = "safetensors-0.4.5-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:01c8f00da537af711979e1b42a69a8ec9e1d7112f208e0e9b8a35d2c381085ef"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a0dd565f83b30f2ca79b5d35748d0d99dd4b3454f80e03dfb41f0038e3bdf180"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:023b6e5facda76989f4cba95a861b7e656b87e225f61811065d5c501f78cdb3f"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9633b663393d5796f0b60249549371e392b75a0b955c07e9c6f8708a87fc841f"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78dd8adfb48716233c45f676d6e48534d34b4bceb50162c13d1f0bdf6f78590a"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8e8deb16c4321d61ae72533b8451ec4a9af8656d1c61ff81aa49f966406e4b68"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:52452fa5999dc50c4decaf0c53aa28371f7f1e0fe5c2dd9129059fbe1e1599c7"}, + {file = "safetensors-0.4.5-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d5f23198821e227cfc52d50fa989813513db381255c6d100927b012f0cfec63d"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f4beb84b6073b1247a773141a6331117e35d07134b3bb0383003f39971d414bb"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:68814d599d25ed2fdd045ed54d370d1d03cf35e02dce56de44c651f828fb9b7b"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0b6453c54c57c1781292c46593f8a37254b8b99004c68d6c3ce229688931a22"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:adaa9c6dead67e2dd90d634f89131e43162012479d86e25618e821a03d1eb1dc"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:73e7d408e9012cd17511b382b43547850969c7979efc2bc353f317abaf23c84c"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:775409ce0fcc58b10773fdb4221ed1eb007de10fe7adbdf8f5e8a56096b6f0bc"}, + {file = "safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:834001bed193e4440c4a3950a31059523ee5090605c907c66808664c932b549c"}, + {file = "safetensors-0.4.5.tar.gz", hash = "sha256:d73de19682deabb02524b3d5d1f8b3aaba94c72f1bbfc7911b9b9d5d391c0310"}, ] [package.extras] @@ -7544,34 +8640,117 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "sagemaker" +version = "2.231.0" +description = "Open source library for training and deploying models on Amazon SageMaker." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "sagemaker-2.231.0-py3-none-any.whl", hash = "sha256:5b6d84484a58c6ac8b22af42c6c5e0ea3c5f42d719345fe6aafba42f93635000"}, + {file = "sagemaker-2.231.0.tar.gz", hash = "sha256:d49ee9c35725832dd9810708938af723201b831e82924a3a6ac1c4260a3d8239"}, +] + +[package.dependencies] +attrs = ">=23.1.0,<24" +boto3 = ">=1.34.142,<2.0" +cloudpickle = "2.2.1" +docker = "*" +google-pasta = "*" +importlib-metadata = ">=1.4.0,<7.0" +jsonschema = "*" +numpy = ">=1.9.0,<2.0" +packaging = ">=20.0" +pandas = "*" +pathos = "*" +platformdirs = "*" +protobuf = ">=3.12,<5.0" +psutil = "*" +pyyaml = ">=6.0,<7.0" +requests = "*" +sagemaker-core = ">=1.0.0,<2.0.0" +schema = "*" +smdebug-rulesconfig = "1.0.1" +tblib = ">=1.7.0,<4" +tqdm = "*" +urllib3 = ">=1.26.8,<3.0.0" + +[package.extras] +all = ["accelerate (>=0.24.1,<=0.27.0)", "docker (>=5.0.2,<8.0.0)", "fastapi (>=0.111.0)", "nest-asyncio", "pyspark (==3.3.1)", "pyyaml (>=5.4.1,<7)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "scipy (==1.10.1)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)"] +feature-processor = ["pyspark (==3.3.1)", "sagemaker-feature-store-pyspark-3-3"] +huggingface = ["accelerate (>=0.24.1,<=0.27.0)", "fastapi (>=0.111.0)", "nest-asyncio", "sagemaker-schema-inference-artifacts (>=0.0.5)", "uvicorn (>=0.30.1)"] +local = ["docker (>=5.0.2,<8.0.0)", "pyyaml (>=5.4.1,<7)", "urllib3 (>=1.26.8,<3.0.0)"] +scipy = ["scipy (==1.10.1)"] +test = ["accelerate (>=0.24.1,<=0.27.0)", "apache-airflow (==2.9.3)", "apache-airflow-providers-amazon (==7.2.1)", "attrs (>=23.1.0,<24)", "awslogs (==0.14.0)", "black (==24.3.0)", "build[virtualenv] (==1.2.1)", "cloudpickle (==2.2.1)", "contextlib2 (==21.6.0)", "coverage (>=5.2,<6.2)", "docker (>=5.0.2,<8.0.0)", "fabric (==2.6.0)", "fastapi (>=0.111.0)", "flake8 (==4.0.1)", "huggingface-hub (>=0.23.4)", "jinja2 (==3.1.4)", "mlflow (>=2.12.2,<2.13)", "mock (==4.0.3)", "nbformat (>=5.9,<6)", "nest-asyncio", "numpy (>=1.24.0)", "onnx (>=1.15.0)", "pandas (>=1.3.5,<1.5)", "pillow (>=10.0.1,<=11)", "pyspark (==3.3.1)", "pytest (==6.2.5)", "pytest-cov (==3.0.0)", "pytest-rerunfailures (==10.2)", "pytest-timeout (==2.1.0)", "pytest-xdist (==2.4.0)", "pyvis (==0.2.1)", "pyyaml (==6.0)", "pyyaml (>=5.4.1,<7)", "requests (==2.32.2)", "sagemaker-experiments (==0.1.35)", "sagemaker-feature-store-pyspark-3-3", "sagemaker-schema-inference-artifacts (>=0.0.5)", "schema (==0.7.5)", "scikit-learn (==1.3.0)", "scipy (==1.10.1)", "stopit (==1.1.2)", "tensorflow (>=2.1,<=2.16)", "tox (==3.24.5)", "tritonclient[http] (<2.37.0)", "urllib3 (>=1.26.8,<3.0.0)", "uvicorn (>=0.30.1)", "xgboost (>=1.6.2,<=1.7.6)"] + +[[package]] +name = "sagemaker-core" +version = "1.0.11" +description = "An python package for sagemaker core functionalities" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sagemaker_core-1.0.11-py3-none-any.whl", hash = "sha256:d8ee3db83759073aa8c9f2bd4899113088a7c2acf340597e76cf9934e384d915"}, + {file = "sagemaker_core-1.0.11.tar.gz", hash = "sha256:fb48a5dcb859a54de7461c71cf58562a3be259294dcd39c317020a9b018f5016"}, +] + +[package.dependencies] +boto3 = ">=1.34.0,<2.0.0" +importlib-metadata = ">=1.4.0,<7.0" +jsonschema = "<5.0.0" +mock = ">4.0,<5.0" +platformdirs = ">=4.0.0,<5.0.0" +pydantic = ">=1.7.0,<3.0.0" +PyYAML = ">=6.0,<7.0" +rich = ">=13.0.0,<14.0.0" + +[package.extras] +codegen = ["black (>=24.3.0,<25.0.0)", "pandas (>=2.0.0,<3.0.0)", "pylint (>=3.0.0,<4.0.0)", "pytest (>=8.0.0,<9.0.0)"] + 
+[[package]] +name = "schema" +version = "0.7.7" +description = "Simple data validation library" +optional = false +python-versions = "*" +files = [ + {file = "schema-0.7.7-py2.py3-none-any.whl", hash = "sha256:5d976a5b50f36e74e2157b47097b60002bd4d42e65425fcc9c9befadb4255dde"}, + {file = "schema-0.7.7.tar.gz", hash = "sha256:7da553abd2958a19dc2547c388cde53398b39196175a9be59ea1caf5ab0a1807"}, +] + [[package]] name = "scikit-learn" -version = "1.5.1" +version = "1.5.2" description = "A set of python modules for machine learning and data mining" optional = false python-versions = ">=3.9" files = [ - {file = "scikit_learn-1.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:781586c414f8cc58e71da4f3d7af311e0505a683e112f2f62919e3019abd3745"}, - {file = "scikit_learn-1.5.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5b213bc29cc30a89a3130393b0e39c847a15d769d6e59539cd86b75d276b1a7"}, - {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ff4ba34c2abff5ec59c803ed1d97d61b036f659a17f55be102679e88f926fac"}, - {file = "scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161808750c267b77b4a9603cf9c93579c7a74ba8486b1336034c2f1579546d21"}, - {file = "scikit_learn-1.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:10e49170691514a94bb2e03787aa921b82dbc507a4ea1f20fd95557862c98dc1"}, - {file = "scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2"}, - {file = "scikit_learn-1.5.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe"}, - {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4"}, - {file = "scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf"}, - {file = "scikit_learn-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b"}, - {file = "scikit_learn-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395"}, - {file = "scikit_learn-1.5.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1"}, - {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915"}, - {file = "scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b"}, - {file = "scikit_learn-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74"}, - {file = "scikit_learn-1.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:88e0672c7ac21eb149d409c74cc29f1d611d5158175846e7a9c2427bd12b3956"}, - {file = "scikit_learn-1.5.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:7b073a27797a283187a4ef4ee149959defc350b46cbf63a84d8514fe16b69855"}, - {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b59e3e62d2be870e5c74af4e793293753565c7383ae82943b83383fdcf5cc5c1"}, - {file = "scikit_learn-1.5.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:1bd8d3a19d4bd6dc5a7d4f358c8c3a60934dc058f363c34c0ac1e9e12a31421d"}, - {file = "scikit_learn-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:5f57428de0c900a98389c4a433d4a3cf89de979b3aa24d1c1d251802aa15e44d"}, - {file = "scikit_learn-1.5.1.tar.gz", hash = "sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, + {file = 
"scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, ] [package.dependencies] @@ -7583,44 +8762,52 @@ threadpoolctl = ">=3.1.0" [package.extras] benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] -docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] maintenance = ["conda-lock (==2.5.6)"] -tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] [[package]] name = "scipy" -version = "1.14.0" +version = "1.14.1" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.10" files = [ - {file = "scipy-1.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e911933d54ead4d557c02402710c2396529540b81dd554fc1ba270eb7308484"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:687af0a35462402dd851726295c1a5ae5f987bd6e9026f52e9505994e2f84ef6"}, - {file = 
"scipy-1.14.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:07e179dc0205a50721022344fb85074f772eadbda1e1b3eecdc483f8033709b7"}, - {file = "scipy-1.14.0-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:6a9c9a9b226d9a21e0a208bdb024c3982932e43811b62d202aaf1bb59af264b1"}, - {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:076c27284c768b84a45dcf2e914d4000aac537da74236a0d45d82c6fa4b7b3c0"}, - {file = "scipy-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42470ea0195336df319741e230626b6225a740fd9dce9642ca13e98f667047c0"}, - {file = "scipy-1.14.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:176c6f0d0470a32f1b2efaf40c3d37a24876cebf447498a4cefb947a79c21e9d"}, - {file = "scipy-1.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:ad36af9626d27a4326c8e884917b7ec321d8a1841cd6dacc67d2a9e90c2f0359"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6d056a8709ccda6cf36cdd2eac597d13bc03dba38360f418560a93050c76a16e"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:f0a50da861a7ec4573b7c716b2ebdcdf142b66b756a0d392c236ae568b3a93fb"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:94c164a9e2498e68308e6e148646e486d979f7fcdb8b4cf34b5441894bdb9caf"}, - {file = "scipy-1.14.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a7d46c3e0aea5c064e734c3eac5cf9eb1f8c4ceee756262f2c7327c4c2691c86"}, - {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eee2989868e274aae26125345584254d97c56194c072ed96cb433f32f692ed8"}, - {file = "scipy-1.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3154691b9f7ed73778d746da2df67a19d046a6c8087c8b385bc4cdb2cfca74"}, - {file = "scipy-1.14.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c40003d880f39c11c1edbae8144e3813904b10514cd3d3d00c277ae996488cdb"}, - {file = "scipy-1.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:5b083c8940028bb7e0b4172acafda6df762da1927b9091f9611b0bcd8676f2bc"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:bff2438ea1330e06e53c424893ec0072640dac00f29c6a43a575cbae4c99b2b9"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bbc0471b5f22c11c389075d091d3885693fd3f5e9a54ce051b46308bc787e5d4"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:64b2ff514a98cf2bb734a9f90d32dc89dc6ad4a4a36a312cd0d6327170339eb0"}, - {file = "scipy-1.14.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:7d3da42fbbbb860211a811782504f38ae7aaec9de8764a9bef6b262de7a2b50f"}, - {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d91db2c41dd6c20646af280355d41dfa1ec7eead235642178bd57635a3f82209"}, - {file = "scipy-1.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a01cc03bcdc777c9da3cfdcc74b5a75caffb48a6c39c8450a9a05f82c4250a14"}, - {file = "scipy-1.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:65df4da3c12a2bb9ad52b86b4dcf46813e869afb006e58be0f516bc370165159"}, - {file = "scipy-1.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:4c4161597c75043f7154238ef419c29a64ac4a7c889d588ea77690ac4d0d9b20"}, - {file = "scipy-1.14.0.tar.gz", hash = "sha256:b5923f48cb840380f9854339176ef21763118a7300a88203ccd0bdd26e58527b"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = 
"scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, ] [package.dependencies] @@ -7628,8 +8815,8 @@ numpy = ">=1.23.5,<2.3" [package.extras] dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] -doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] -test = ["Cython", "array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "sentry-sdk" @@ -7683,19 +8870,23 @@ tornado = ["tornado (>=5)"] [[package]] name = "setuptools" -version = "72.1.0" +version = "75.3.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "setuptools-72.1.0-py3-none-any.whl", hash = "sha256:5a03e1860cf56bb6ef48ce186b0e557fdba433237481a9a625176c2831be15d1"}, - {file = "setuptools-72.1.0.tar.gz", hash = "sha256:8d243eff56d095e5817f796ede6ae32941278f542e0f941867cc05ae52b162ec"}, + {file = "setuptools-75.3.0-py3-none-any.whl", hash = "sha256:f2504966861356aa38616760c0f66568e535562374995367b4e69c7143cf6bcd"}, + {file = "setuptools-75.3.0.tar.gz", hash = "sha256:fba5dd4d766e97be1b1681d98712680ae8f2f26d7881245f2ce9e40714f1a686"}, ] [package.extras] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "ordered-set (>=3.1.1)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", 
"rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test", "mypy (==1.11.*)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (<0.4)", "pytest-ruff (>=0.2.1)", "pytest-ruff (>=0.3.2)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12.*)", "pytest-mypy"] [[package]] name = "sgmllib3k" @@ -7709,47 +8900,53 @@ files = [ [[package]] name = "shapely" -version = "2.0.5" +version = "2.0.6" description = "Manipulation and analysis of geometric objects" optional = false python-versions = ">=3.7" files = [ - {file = "shapely-2.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:89d34787c44f77a7d37d55ae821f3a784fa33592b9d217a45053a93ade899375"}, - {file = "shapely-2.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:798090b426142df2c5258779c1d8d5734ec6942f778dab6c6c30cfe7f3bf64ff"}, - {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45211276900c4790d6bfc6105cbf1030742da67594ea4161a9ce6812a6721e68"}, - {file = "shapely-2.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e119444bc27ca33e786772b81760f2028d930ac55dafe9bc50ef538b794a8e1"}, - {file = "shapely-2.0.5-cp310-cp310-win32.whl", hash = "sha256:9a4492a2b2ccbeaebf181e7310d2dfff4fdd505aef59d6cb0f217607cb042fb3"}, - {file = "shapely-2.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:1e5cb5ee72f1bc7ace737c9ecd30dc174a5295fae412972d3879bac2e82c8fae"}, - {file = "shapely-2.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5bbfb048a74cf273db9091ff3155d373020852805a37dfc846ab71dde4be93ec"}, - {file = "shapely-2.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93be600cbe2fbaa86c8eb70656369f2f7104cd231f0d6585c7d0aa555d6878b8"}, - {file = 
"shapely-2.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f8e71bb9a46814019f6644c4e2560a09d44b80100e46e371578f35eaaa9da1c"}, - {file = "shapely-2.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5251c28a29012e92de01d2e84f11637eb1d48184ee8f22e2df6c8c578d26760"}, - {file = "shapely-2.0.5-cp311-cp311-win32.whl", hash = "sha256:35110e80070d664781ec7955c7de557456b25727a0257b354830abb759bf8311"}, - {file = "shapely-2.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:6c6b78c0007a34ce7144f98b7418800e0a6a5d9a762f2244b00ea560525290c9"}, - {file = "shapely-2.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:03bd7b5fa5deb44795cc0a503999d10ae9d8a22df54ae8d4a4cd2e8a93466195"}, - {file = "shapely-2.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ff9521991ed9e201c2e923da014e766c1aa04771bc93e6fe97c27dcf0d40ace"}, - {file = "shapely-2.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b65365cfbf657604e50d15161ffcc68de5cdb22a601bbf7823540ab4918a98d"}, - {file = "shapely-2.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21f64e647a025b61b19585d2247137b3a38a35314ea68c66aaf507a1c03ef6fe"}, - {file = "shapely-2.0.5-cp312-cp312-win32.whl", hash = "sha256:3ac7dc1350700c139c956b03d9c3df49a5b34aaf91d024d1510a09717ea39199"}, - {file = "shapely-2.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:30e8737983c9d954cd17feb49eb169f02f1da49e24e5171122cf2c2b62d65c95"}, - {file = "shapely-2.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ff7731fea5face9ec08a861ed351734a79475631b7540ceb0b66fb9732a5f529"}, - {file = "shapely-2.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff9e520af0c5a578e174bca3c18713cd47a6c6a15b6cf1f50ac17dc8bb8db6a2"}, - {file = "shapely-2.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49b299b91557b04acb75e9732645428470825061f871a2edc36b9417d66c1fc5"}, - {file = "shapely-2.0.5-cp37-cp37m-win32.whl", hash = "sha256:b5870633f8e684bf6d1ae4df527ddcb6f3895f7b12bced5c13266ac04f47d231"}, - {file = "shapely-2.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:401cb794c5067598f50518e5a997e270cd7642c4992645479b915c503866abed"}, - {file = "shapely-2.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e91ee179af539100eb520281ba5394919067c6b51824e6ab132ad4b3b3e76dd0"}, - {file = "shapely-2.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8af6f7260f809c0862741ad08b1b89cb60c130ae30efab62320bbf4ee9cc71fa"}, - {file = "shapely-2.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5456dd522800306ba3faef77c5ba847ec30a0bd73ab087a25e0acdd4db2514f"}, - {file = "shapely-2.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b714a840402cde66fd7b663bb08cacb7211fa4412ea2a209688f671e0d0631fd"}, - {file = "shapely-2.0.5-cp38-cp38-win32.whl", hash = "sha256:7e8cf5c252fac1ea51b3162be2ec3faddedc82c256a1160fc0e8ddbec81b06d2"}, - {file = "shapely-2.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:4461509afdb15051e73ab178fae79974387f39c47ab635a7330d7fee02c68a3f"}, - {file = "shapely-2.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7545a39c55cad1562be302d74c74586f79e07b592df8ada56b79a209731c0219"}, - {file = "shapely-2.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4c83a36f12ec8dee2066946d98d4d841ab6512a6ed7eb742e026a64854019b5f"}, - {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:89e640c2cd37378480caf2eeda9a51be64201f01f786d127e78eaeff091ec897"}, - {file = "shapely-2.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06efe39beafde3a18a21dde169d32f315c57da962826a6d7d22630025200c5e6"}, - {file = "shapely-2.0.5-cp39-cp39-win32.whl", hash = "sha256:8203a8b2d44dcb366becbc8c3d553670320e4acf0616c39e218c9561dd738d92"}, - {file = "shapely-2.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:7fed9dbfbcfec2682d9a047b9699db8dcc890dfca857ecba872c42185fc9e64e"}, - {file = "shapely-2.0.5.tar.gz", hash = "sha256:bff2366bc786bfa6cb353d6b47d0443c570c32776612e527ee47b6df63fcfe32"}, + {file = "shapely-2.0.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29a34e068da2d321e926b5073539fd2a1d4429a2c656bd63f0bd4c8f5b236d0b"}, + {file = "shapely-2.0.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1c84c3f53144febf6af909d6b581bc05e8785d57e27f35ebaa5c1ab9baba13b"}, + {file = "shapely-2.0.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ad2fae12dca8d2b727fa12b007e46fbc522148a584f5d6546c539f3464dccde"}, + {file = "shapely-2.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3304883bd82d44be1b27a9d17f1167fda8c7f5a02a897958d86c59ec69b705e"}, + {file = "shapely-2.0.6-cp310-cp310-win32.whl", hash = "sha256:3ec3a0eab496b5e04633a39fa3d5eb5454628228201fb24903d38174ee34565e"}, + {file = "shapely-2.0.6-cp310-cp310-win_amd64.whl", hash = "sha256:28f87cdf5308a514763a5c38de295544cb27429cfa655d50ed8431a4796090c4"}, + {file = "shapely-2.0.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5aeb0f51a9db176da9a30cb2f4329b6fbd1e26d359012bb0ac3d3c7781667a9e"}, + {file = "shapely-2.0.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9a7a78b0d51257a367ee115f4d41ca4d46edbd0dd280f697a8092dd3989867b2"}, + {file = "shapely-2.0.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f32c23d2f43d54029f986479f7c1f6e09c6b3a19353a3833c2ffb226fb63a855"}, + {file = "shapely-2.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3dc9fb0eb56498912025f5eb352b5126f04801ed0e8bdbd867d21bdbfd7cbd0"}, + {file = "shapely-2.0.6-cp311-cp311-win32.whl", hash = "sha256:d93b7e0e71c9f095e09454bf18dad5ea716fb6ced5df3cb044564a00723f339d"}, + {file = "shapely-2.0.6-cp311-cp311-win_amd64.whl", hash = "sha256:c02eb6bf4cfb9fe6568502e85bb2647921ee49171bcd2d4116c7b3109724ef9b"}, + {file = "shapely-2.0.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cec9193519940e9d1b86a3b4f5af9eb6910197d24af02f247afbfb47bcb3fab0"}, + {file = "shapely-2.0.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:83b94a44ab04a90e88be69e7ddcc6f332da7c0a0ebb1156e1c4f568bbec983c3"}, + {file = "shapely-2.0.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:537c4b2716d22c92036d00b34aac9d3775e3691f80c7aa517c2c290351f42cd8"}, + {file = "shapely-2.0.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98fea108334be345c283ce74bf064fa00cfdd718048a8af7343c59eb40f59726"}, + {file = "shapely-2.0.6-cp312-cp312-win32.whl", hash = "sha256:42fd4cd4834747e4990227e4cbafb02242c0cffe9ce7ef9971f53ac52d80d55f"}, + {file = "shapely-2.0.6-cp312-cp312-win_amd64.whl", hash = "sha256:665990c84aece05efb68a21b3523a6b2057e84a1afbef426ad287f0796ef8a48"}, + {file = "shapely-2.0.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:42805ef90783ce689a4dde2b6b2f261e2c52609226a0438d882e3ced40bb3013"}, + {file = "shapely-2.0.6-cp313-cp313-macosx_11_0_arm64.whl", hash = 
"sha256:6d2cb146191a47bd0cee8ff5f90b47547b82b6345c0d02dd8b25b88b68af62d7"}, + {file = "shapely-2.0.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3fdef0a1794a8fe70dc1f514440aa34426cc0ae98d9a1027fb299d45741c381"}, + {file = "shapely-2.0.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c665a0301c645615a107ff7f52adafa2153beab51daf34587170d85e8ba6805"}, + {file = "shapely-2.0.6-cp313-cp313-win32.whl", hash = "sha256:0334bd51828f68cd54b87d80b3e7cee93f249d82ae55a0faf3ea21c9be7b323a"}, + {file = "shapely-2.0.6-cp313-cp313-win_amd64.whl", hash = "sha256:d37d070da9e0e0f0a530a621e17c0b8c3c9d04105655132a87cfff8bd77cc4c2"}, + {file = "shapely-2.0.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:fa7468e4f5b92049c0f36d63c3e309f85f2775752e076378e36c6387245c5462"}, + {file = "shapely-2.0.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed5867e598a9e8ac3291da6cc9baa62ca25706eea186117034e8ec0ea4355653"}, + {file = "shapely-2.0.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81d9dfe155f371f78c8d895a7b7f323bb241fb148d848a2bf2244f79213123fe"}, + {file = "shapely-2.0.6-cp37-cp37m-win32.whl", hash = "sha256:fbb7bf02a7542dba55129062570211cfb0defa05386409b3e306c39612e7fbcc"}, + {file = "shapely-2.0.6-cp37-cp37m-win_amd64.whl", hash = "sha256:837d395fac58aa01aa544495b97940995211e3e25f9aaf87bc3ba5b3a8cd1ac7"}, + {file = "shapely-2.0.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c6d88ade96bf02f6bfd667ddd3626913098e243e419a0325ebef2bbd481d1eb6"}, + {file = "shapely-2.0.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8b3b818c4407eaa0b4cb376fd2305e20ff6df757bf1356651589eadc14aab41b"}, + {file = "shapely-2.0.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1bbc783529a21f2bd50c79cef90761f72d41c45622b3e57acf78d984c50a5d13"}, + {file = "shapely-2.0.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2423f6c0903ebe5df6d32e0066b3d94029aab18425ad4b07bf98c3972a6e25a1"}, + {file = "shapely-2.0.6-cp38-cp38-win32.whl", hash = "sha256:2de00c3bfa80d6750832bde1d9487e302a6dd21d90cb2f210515cefdb616e5f5"}, + {file = "shapely-2.0.6-cp38-cp38-win_amd64.whl", hash = "sha256:3a82d58a1134d5e975f19268710e53bddd9c473743356c90d97ce04b73e101ee"}, + {file = "shapely-2.0.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:392f66f458a0a2c706254f473290418236e52aa4c9b476a072539d63a2460595"}, + {file = "shapely-2.0.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:eba5bae271d523c938274c61658ebc34de6c4b33fdf43ef7e938b5776388c1be"}, + {file = "shapely-2.0.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7060566bc4888b0c8ed14b5d57df8a0ead5c28f9b69fb6bed4476df31c51b0af"}, + {file = "shapely-2.0.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b02154b3e9d076a29a8513dffcb80f047a5ea63c897c0cd3d3679f29363cf7e5"}, + {file = "shapely-2.0.6-cp39-cp39-win32.whl", hash = "sha256:44246d30124a4f1a638a7d5419149959532b99dfa25b54393512e6acc9c211ac"}, + {file = "shapely-2.0.6-cp39-cp39-win_amd64.whl", hash = "sha256:2b542d7f1dbb89192d3512c52b679c822ba916f93479fa5d4fc2fe4fa0b3c9e8"}, + {file = "shapely-2.0.6.tar.gz", hash = "sha256:997f6159b1484059ec239cacaa53467fd8b5564dabe186cd84ac2944663b0bf6"}, ] [package.dependencies] @@ -7772,19 +8969,20 @@ files = [ [[package]] name = "simple-websocket" -version = "1.0.0" +version = "1.1.0" description = "Simple WebSocket server and client for Python" optional = false python-versions = 
">=3.6" files = [ - {file = "simple-websocket-1.0.0.tar.gz", hash = "sha256:17d2c72f4a2bd85174a97e3e4c88b01c40c3f81b7b648b0cc3ce1305968928c8"}, - {file = "simple_websocket-1.0.0-py3-none-any.whl", hash = "sha256:1d5bf585e415eaa2083e2bcf02a3ecf91f9712e7b3e6b9fa0b461ad04e0837bc"}, + {file = "simple_websocket-1.1.0-py3-none-any.whl", hash = "sha256:4af6069630a38ed6c561010f0e11a5bc0d4ca569b36306eb257cd9a192497c8c"}, + {file = "simple_websocket-1.1.0.tar.gz", hash = "sha256:7939234e7aa067c534abdab3a9ed933ec9ce4691b0713c78acb195560aa52ae4"}, ] [package.dependencies] wsproto = "*" [package.extras] +dev = ["flake8", "pytest", "pytest-cov", "tox"] docs = ["sphinx"] [[package]] @@ -7798,6 +8996,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "smdebug-rulesconfig" +version = "1.0.1" +description = "SMDebug RulesConfig" +optional = false +python-versions = ">=2.7" +files = [ + {file = "smdebug_rulesconfig-1.0.1-py2.py3-none-any.whl", hash = "sha256:104da3e6931ecf879dfc687ca4bbb3bee5ea2bc27f4478e9dbb3ee3655f1ae61"}, + {file = "smdebug_rulesconfig-1.0.1.tar.gz", hash = "sha256:7a19e6eb2e6bcfefbc07e4a86ef7a88f32495001a038bf28c7d8e77ab793fcd6"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -7822,71 +9031,79 @@ files = [ [[package]] name = "soupsieve" -version = "2.5" +version = "2.6" description = "A modern CSS selector implementation for Beautiful Soup." optional = false python-versions = ">=3.8" files = [ - {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, - {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, + {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, + {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, ] [[package]] name = "sqlalchemy" -version = "2.0.32" +version = "2.0.36" description = "Database Abstraction Library" optional = false python-versions = ">=3.7" files = [ - {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0c9045ecc2e4db59bfc97b20516dfdf8e41d910ac6fb667ebd3a79ea54084619"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1467940318e4a860afd546ef61fefb98a14d935cd6817ed07a228c7f7c62f389"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5954463675cb15db8d4b521f3566a017c8789222b8316b1e6934c811018ee08b"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:167e7497035c303ae50651b351c28dc22a40bb98fbdb8468cdc971821b1ae533"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b27dfb676ac02529fb6e343b3a482303f16e6bc3a4d868b73935b8792edb52d0"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bf2360a5e0f7bd75fa80431bf8ebcfb920c9f885e7956c7efde89031695cafb8"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-win32.whl", hash = "sha256:306fe44e754a91cd9d600a6b070c1f2fadbb4a1a257b8781ccf33c7067fd3e4d"}, - {file = "SQLAlchemy-2.0.32-cp310-cp310-win_amd64.whl", hash = "sha256:99db65e6f3ab42e06c318f15c98f59a436f1c78179e6a6f40f529c8cc7100b22"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:21b053be28a8a414f2ddd401f1be8361e41032d2ef5884b2f31d31cb723e559f"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b178e875a7a25b5938b53b006598ee7645172fccafe1c291a706e93f48499ff5"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723a40ee2cc7ea653645bd4cf024326dea2076673fc9d3d33f20f6c81db83e1d"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:295ff8689544f7ee7e819529633d058bd458c1fd7f7e3eebd0f9268ebc56c2a0"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:49496b68cd190a147118af585173ee624114dfb2e0297558c460ad7495f9dfe2"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:acd9b73c5c15f0ec5ce18128b1fe9157ddd0044abc373e6ecd5ba376a7e5d961"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-win32.whl", hash = "sha256:9365a3da32dabd3e69e06b972b1ffb0c89668994c7e8e75ce21d3e5e69ddef28"}, - {file = "SQLAlchemy-2.0.32-cp311-cp311-win_amd64.whl", hash = "sha256:8bd63d051f4f313b102a2af1cbc8b80f061bf78f3d5bd0843ff70b5859e27924"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6bab3db192a0c35e3c9d1560eb8332463e29e5507dbd822e29a0a3c48c0a8d92"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:19d98f4f58b13900d8dec4ed09dd09ef292208ee44cc9c2fe01c1f0a2fe440e9"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cd33c61513cb1b7371fd40cf221256456d26a56284e7d19d1f0b9f1eb7dd7e8"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d6ba0497c1d066dd004e0f02a92426ca2df20fac08728d03f67f6960271feec"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2b6be53e4fde0065524f1a0a7929b10e9280987b320716c1509478b712a7688c"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:916a798f62f410c0b80b63683c8061f5ebe237b0f4ad778739304253353bc1cb"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-win32.whl", hash = "sha256:31983018b74908ebc6c996a16ad3690301a23befb643093fcfe85efd292e384d"}, - {file = "SQLAlchemy-2.0.32-cp312-cp312-win_amd64.whl", hash = "sha256:4363ed245a6231f2e2957cccdda3c776265a75851f4753c60f3004b90e69bfeb"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b8afd5b26570bf41c35c0121801479958b4446751a3971fb9a480c1afd85558e"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c750987fc876813f27b60d619b987b057eb4896b81117f73bb8d9918c14f1cad"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada0102afff4890f651ed91120c1120065663506b760da4e7823913ebd3258be"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:78c03d0f8a5ab4f3034c0e8482cfcc415a3ec6193491cfa1c643ed707d476f16"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:3bd1cae7519283ff525e64645ebd7a3e0283f3c038f461ecc1c7b040a0c932a1"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-win32.whl", hash = "sha256:01438ebcdc566d58c93af0171c74ec28efe6a29184b773e378a385e6215389da"}, - {file = "SQLAlchemy-2.0.32-cp37-cp37m-win_amd64.whl", hash = "sha256:4979dc80fbbc9d2ef569e71e0896990bc94df2b9fdbd878290bd129b65ab579c"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:6c742be912f57586ac43af38b3848f7688863a403dfb220193a882ea60e1ec3a"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:62e23d0ac103bcf1c5555b6c88c114089587bc64d048fef5bbdb58dfd26f96da"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:251f0d1108aab8ea7b9aadbd07fb47fb8e3a5838dde34aa95a3349876b5a1f1d"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ef18a84e5116340e38eca3e7f9eeaaef62738891422e7c2a0b80feab165905f"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:3eb6a97a1d39976f360b10ff208c73afb6a4de86dd2a6212ddf65c4a6a2347d5"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0c1c9b673d21477cec17ab10bc4decb1322843ba35b481585facd88203754fc5"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-win32.whl", hash = "sha256:c41a2b9ca80ee555decc605bd3c4520cc6fef9abde8fd66b1cf65126a6922d65"}, - {file = "SQLAlchemy-2.0.32-cp38-cp38-win_amd64.whl", hash = "sha256:8a37e4d265033c897892279e8adf505c8b6b4075f2b40d77afb31f7185cd6ecd"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:52fec964fba2ef46476312a03ec8c425956b05c20220a1a03703537824b5e8e1"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:328429aecaba2aee3d71e11f2477c14eec5990fb6d0e884107935f7fb6001632"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85a01b5599e790e76ac3fe3aa2f26e1feba56270023d6afd5550ed63c68552b3"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aaf04784797dcdf4c0aa952c8d234fa01974c4729db55c45732520ce12dd95b4"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4488120becf9b71b3ac718f4138269a6be99a42fe023ec457896ba4f80749525"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:14e09e083a5796d513918a66f3d6aedbc131e39e80875afe81d98a03312889e6"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-win32.whl", hash = "sha256:0d322cc9c9b2154ba7e82f7bf25ecc7c36fbe2d82e2933b3642fc095a52cfc78"}, - {file = "SQLAlchemy-2.0.32-cp39-cp39-win_amd64.whl", hash = "sha256:7dd8583df2f98dea28b5cd53a1beac963f4f9d087888d75f22fcc93a07cf8d84"}, - {file = "SQLAlchemy-2.0.32-py3-none-any.whl", hash = "sha256:e567a8793a692451f706b363ccf3c45e056b67d90ead58c3bc9471af5d212202"}, - {file = "SQLAlchemy-2.0.32.tar.gz", hash = "sha256:c1b88cc8b02b6a5f0efb0345a03672d4c897dc7d92585176f88c67346f565ea8"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59b8f3adb3971929a3e660337f5dacc5942c2cdb760afcabb2614ffbda9f9f72"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:37350015056a553e442ff672c2d20e6f4b6d0b2495691fa239d8aa18bb3bc908"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8318f4776c85abc3f40ab185e388bee7a6ea99e7fa3a30686580b209eaa35c08"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c245b1fbade9c35e5bd3b64270ab49ce990369018289ecfde3f9c318411aaa07"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:69f93723edbca7342624d09f6704e7126b152eaed3cdbb634cb657a54332a3c5"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f9511d8dd4a6e9271d07d150fb2f81874a3c8c95e11ff9af3a2dfc35fe42ee44"}, + {file = 
"SQLAlchemy-2.0.36-cp310-cp310-win32.whl", hash = "sha256:c3f3631693003d8e585d4200730616b78fafd5a01ef8b698f6967da5c605b3fa"}, + {file = "SQLAlchemy-2.0.36-cp310-cp310-win_amd64.whl", hash = "sha256:a86bfab2ef46d63300c0f06936bd6e6c0105faa11d509083ba8f2f9d237fb5b5"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fd3a55deef00f689ce931d4d1b23fa9f04c880a48ee97af488fd215cf24e2a6c"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4f5e9cd989b45b73bd359f693b935364f7e1f79486e29015813c338450aa5a71"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0ddd9db6e59c44875211bc4c7953a9f6638b937b0a88ae6d09eb46cced54eff"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2519f3a5d0517fc159afab1015e54bb81b4406c278749779be57a569d8d1bb0d"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59b1ee96617135f6e1d6f275bbe988f419c5178016f3d41d3c0abb0c819f75bb"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:39769a115f730d683b0eb7b694db9789267bcd027326cccc3125e862eb03bfd8"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-win32.whl", hash = "sha256:66bffbad8d6271bb1cc2f9a4ea4f86f80fe5e2e3e501a5ae2a3dc6a76e604e6f"}, + {file = "SQLAlchemy-2.0.36-cp311-cp311-win_amd64.whl", hash = "sha256:23623166bfefe1487d81b698c423f8678e80df8b54614c2bf4b4cfcd7c711959"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7b64e6ec3f02c35647be6b4851008b26cff592a95ecb13b6788a54ef80bbdd4"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:46331b00096a6db1fdc052d55b101dbbfc99155a548e20a0e4a8e5e4d1362855"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdf3386a801ea5aba17c6410dd1dc8d39cf454ca2565541b5ac42a84e1e28f53"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac9dfa18ff2a67b09b372d5db8743c27966abf0e5344c555d86cc7199f7ad83a"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:90812a8933df713fdf748b355527e3af257a11e415b613dd794512461eb8a686"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1bc330d9d29c7f06f003ab10e1eaced295e87940405afe1b110f2eb93a233588"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win32.whl", hash = "sha256:79d2e78abc26d871875b419e1fd3c0bca31a1cb0043277d0d850014599626c2e"}, + {file = "SQLAlchemy-2.0.36-cp312-cp312-win_amd64.whl", hash = "sha256:b544ad1935a8541d177cb402948b94e871067656b3a0b9e91dbec136b06a2ff5"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b5cc79df7f4bc3d11e4b542596c03826063092611e481fcf1c9dfee3c94355ef"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3c01117dd36800f2ecaa238c65365b7b16497adc1522bf84906e5710ee9ba0e8"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9bc633f4ee4b4c46e7adcb3a9b5ec083bf1d9a97c1d3854b92749d935de40b9b"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e46ed38affdfc95d2c958de328d037d87801cfcbea6d421000859e9789e61c2"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b2985c0b06e989c043f1dc09d4fe89e1616aadd35392aea2844f0458a989eacf"}, + {file = 
"SQLAlchemy-2.0.36-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a121d62ebe7d26fec9155f83f8be5189ef1405f5973ea4874a26fab9f1e262c"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win32.whl", hash = "sha256:0572f4bd6f94752167adfd7c1bed84f4b240ee6203a95e05d1e208d488d0d436"}, + {file = "SQLAlchemy-2.0.36-cp313-cp313-win_amd64.whl", hash = "sha256:8c78ac40bde930c60e0f78b3cd184c580f89456dd87fc08f9e3ee3ce8765ce88"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:be9812b766cad94a25bc63bec11f88c4ad3629a0cec1cd5d4ba48dc23860486b"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50aae840ebbd6cdd41af1c14590e5741665e5272d2fee999306673a1bb1fdb4d"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4557e1f11c5f653ebfdd924f3f9d5ebfc718283b0b9beebaa5dd6b77ec290971"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:07b441f7d03b9a66299ce7ccf3ef2900abc81c0db434f42a5694a37bd73870f2"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:28120ef39c92c2dd60f2721af9328479516844c6b550b077ca450c7d7dc68575"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win32.whl", hash = "sha256:b81ee3d84803fd42d0b154cb6892ae57ea6b7c55d8359a02379965706c7efe6c"}, + {file = "SQLAlchemy-2.0.36-cp37-cp37m-win_amd64.whl", hash = "sha256:f942a799516184c855e1a32fbc7b29d7e571b52612647866d4ec1c3242578fcb"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3d6718667da04294d7df1670d70eeddd414f313738d20a6f1d1f379e3139a545"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:72c28b84b174ce8af8504ca28ae9347d317f9dba3999e5981a3cd441f3712e24"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b11d0cfdd2b095e7b0686cf5fabeb9c67fae5b06d265d8180715b8cfa86522e3"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e32092c47011d113dc01ab3e1d3ce9f006a47223b18422c5c0d150af13a00687"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6a440293d802d3011028e14e4226da1434b373cbaf4a4bbb63f845761a708346"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c54a1e53a0c308a8e8a7dffb59097bff7facda27c70c286f005327f21b2bd6b1"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win32.whl", hash = "sha256:1e0d612a17581b6616ff03c8e3d5eff7452f34655c901f75d62bd86449d9750e"}, + {file = "SQLAlchemy-2.0.36-cp38-cp38-win_amd64.whl", hash = "sha256:8958b10490125124463095bbdadda5aa22ec799f91958e410438ad6c97a7b793"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dc022184d3e5cacc9579e41805a681187650e170eb2fd70e28b86192a479dcaa"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b817d41d692bf286abc181f8af476c4fbef3fd05e798777492618378448ee689"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4e46a888b54be23d03a89be510f24a7652fe6ff660787b96cd0e57a4ebcb46d"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4ae3005ed83f5967f961fd091f2f8c5329161f69ce8480aa8168b2d7fe37f06"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:03e08af7a5f9386a43919eda9de33ffda16b44eb11f3b313e6822243770e9763"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-musllinux_1_2_x86_64.whl", hash = 
"sha256:3dbb986bad3ed5ceaf090200eba750b5245150bd97d3e67343a3cfed06feecf7"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win32.whl", hash = "sha256:9fe53b404f24789b5ea9003fc25b9a3988feddebd7e7b369c8fac27ad6f52f28"}, + {file = "SQLAlchemy-2.0.36-cp39-cp39-win_amd64.whl", hash = "sha256:af148a33ff0349f53512a049c6406923e4e02bf2f26c5fb285f143faf4f0e46a"}, + {file = "SQLAlchemy-2.0.36-py3-none-any.whl", hash = "sha256:fddbe92b4760c6f5d48162aef14824add991aeda8ddadb3c31d56eb15ca69f8e"}, + {file = "sqlalchemy-2.0.36.tar.gz", hash = "sha256:7f2767680b6d2398aea7082e45a774b2b0767b5c8d8ffb9c8b683088ea9b29c5"}, ] [package.dependencies] @@ -7899,7 +9116,7 @@ aioodbc = ["aioodbc", "greenlet (!=0.4.17)"] aiosqlite = ["aiosqlite", "greenlet (!=0.4.17)", "typing_extensions (!=3.10.0.1)"] asyncio = ["greenlet (!=0.4.17)"] asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (!=0.4.17)"] -mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10)"] mssql = ["pyodbc"] mssql-pymssql = ["pymssql"] mssql-pyodbc = ["pyodbc"] @@ -7935,13 +9152,13 @@ doc = ["sphinx"] [[package]] name = "starlette" -version = "0.37.2" +version = "0.41.0" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, - {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, + {file = "starlette-0.41.0-py3-none-any.whl", hash = "sha256:a0193a3c413ebc9c78bff1c3546a45bb8c8bcb4a84cae8747d650a65bd37210a"}, + {file = "starlette-0.41.0.tar.gz", hash = "sha256:39cbd8768b107d68bfe1ff1672b38a2c38b49777de46d2a592841d58e3bf7c2a"}, ] [package.dependencies] @@ -7950,15 +9167,95 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] +[[package]] +name = "storage3" +version = "0.8.2" +description = "Supabase Storage client for Python." +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "storage3-0.8.2-py3-none-any.whl", hash = "sha256:f2e995b18c77a2a9265d1a33047d43e4d6abb11eb3ca5067959f68281c305de3"}, + {file = "storage3-0.8.2.tar.gz", hash = "sha256:db05d3fe8fb73bd30c814c4c4749664f37a5dfc78b629e8c058ef558c2b89f5a"}, +] + +[package.dependencies] +httpx = {version = ">=0.26,<0.28", extras = ["http2"]} +python-dateutil = ">=2.8.2,<3.0.0" +typing-extensions = ">=4.2.0,<5.0.0" + +[[package]] +name = "strenum" +version = "0.4.15" +description = "An Enum that inherits from str." 
+optional = false +python-versions = "*" +files = [ + {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, + {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, +] + +[package.extras] +docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] +release = ["twine"] +test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] + +[[package]] +name = "strictyaml" +version = "1.7.3" +description = "Strict, typed YAML parser" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "strictyaml-1.7.3-py3-none-any.whl", hash = "sha256:fb5c8a4edb43bebb765959e420f9b3978d7f1af88c80606c03fb420888f5d1c7"}, + {file = "strictyaml-1.7.3.tar.gz", hash = "sha256:22f854a5fcab42b5ddba8030a0e4be51ca89af0267961c8d6cfa86395586c407"}, +] + +[package.dependencies] +python-dateutil = ">=2.6.0" + +[[package]] +name = "supabase" +version = "2.8.1" +description = "Supabase client for Python." +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"}, + {file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"}, +] + +[package.dependencies] +gotrue = ">=2.7.0,<3.0.0" +httpx = ">=0.24,<0.28" +postgrest = ">=0.17.0,<0.18.0" +realtime = ">=2.0.0,<3.0.0" +storage3 = ">=0.8.0,<0.9.0" +supafunc = ">=0.6.0,<0.7.0" +typing-extensions = ">=4.12.2,<5.0.0" + +[[package]] +name = "supafunc" +version = "0.6.2" +description = "Library for Supabase Functions" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "supafunc-0.6.2-py3-none-any.whl", hash = "sha256:101b30616b0a1ce8cf938eca1df362fa4cf1deacb0271f53ebbd674190fb0da5"}, + {file = "supafunc-0.6.2.tar.gz", hash = "sha256:c7dfa20db7182f7fe4ae436e94e05c06cd7ed98d697fed75d68c7b9792822adc"}, +] + +[package.dependencies] +httpx = {version = ">=0.26,<0.28", extras = ["http2"]} + [[package]] name = "sympy" -version = "1.13.1" +version = "1.13.3" description = "Computer algebra system (CAS) in Python" optional = false python-versions = ">=3.8" files = [ - {file = "sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8"}, - {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, + {file = "sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73"}, + {file = "sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9"}, ] [package.dependencies] @@ -7981,6 +9278,17 @@ files = [ [package.extras] widechars = ["wcwidth"] +[[package]] +name = "tblib" +version = "3.0.0" +description = "Traceback serialization library." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"}, + {file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"}, +] + [[package]] name = "tcvectordb" version = "1.3.2" @@ -8013,13 +9321,13 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tencentcloud-sdk-python-common" -version = "3.0.1206" +version = "3.0.1257" description = "Tencent Cloud Common SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-common-3.0.1206.tar.gz", hash = "sha256:e32745e6d46b94b2c2c33cd68c7e70bff3d63e8e5e5d314bb0b41616521c90f2"}, - {file = "tencentcloud_sdk_python_common-3.0.1206-py2.py3-none-any.whl", hash = "sha256:2100697933d62135b093bae43eee0f8862b45ca0597da72779e304c9b392ac96"}, + {file = "tencentcloud-sdk-python-common-3.0.1257.tar.gz", hash = "sha256:e10b155d598a60c43a491be10f40f7dae5774a2187d55f2da83bdb559434f3c4"}, + {file = "tencentcloud_sdk_python_common-3.0.1257-py2.py3-none-any.whl", hash = "sha256:f474a2969f3cbff91f45780f18bfbb90ab53f66c0085c4e9b4e07c2fcf0e71d9"}, ] [package.dependencies] @@ -8027,17 +9335,31 @@ requests = ">=2.16.0" [[package]] name = "tencentcloud-sdk-python-hunyuan" -version = "3.0.1206" +version = "3.0.1257" description = "Tencent Cloud Hunyuan SDK for Python" optional = false python-versions = "*" files = [ - {file = "tencentcloud-sdk-python-hunyuan-3.0.1206.tar.gz", hash = "sha256:2c37f2f50e54d23905d91d7a511a217317d944c701127daae548b7275cc32968"}, - {file = "tencentcloud_sdk_python_hunyuan-3.0.1206-py2.py3-none-any.whl", hash = "sha256:c650315bb5863f28d410fa1062122550d8015600947d04d95e2bff55d0590acc"}, + {file = "tencentcloud-sdk-python-hunyuan-3.0.1257.tar.gz", hash = "sha256:4d38505089bed70dda1f806f8c4835f8a8c520efa86dcecfef444045c21b695d"}, + {file = "tencentcloud_sdk_python_hunyuan-3.0.1257-py2.py3-none-any.whl", hash = "sha256:c9089d3e49304c9c20e7465c82372b2cd234e67f63efdffb6798a4093b3a97c6"}, ] [package.dependencies] -tencentcloud-sdk-python-common = "3.0.1206" +tencentcloud-sdk-python-common = "3.0.1257" + +[[package]] +name = "termcolor" +version = "2.5.0" +description = "ANSI color formatting for output in terminal" +optional = false +python-versions = ">=3.9" +files = [ + {file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"}, + {file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] [[package]] name = "threadpoolctl" @@ -8069,47 +9391,42 @@ client = ["SQLAlchemy (>=1.4,<3)"] [[package]] name = "tiktoken" -version = "0.7.0" +version = "0.8.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "tiktoken-0.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485f3cc6aba7c6b6ce388ba634fbba656d9ee27f766216f45146beb4ac18b25f"}, - {file = "tiktoken-0.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e54be9a2cd2f6d6ffa3517b064983fb695c9a9d8aa7d574d1ef3c3f931a99225"}, - {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79383a6e2c654c6040e5f8506f3750db9ddd71b550c724e673203b4f6b4b4590"}, - {file = 
"tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d4511c52caacf3c4981d1ae2df85908bd31853f33d30b345c8b6830763f769c"}, - {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:13c94efacdd3de9aff824a788353aa5749c0faee1fbe3816df365ea450b82311"}, - {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8e58c7eb29d2ab35a7a8929cbeea60216a4ccdf42efa8974d8e176d50c9a3df5"}, - {file = "tiktoken-0.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:21a20c3bd1dd3e55b91c1331bf25f4af522c525e771691adbc9a69336fa7f702"}, - {file = "tiktoken-0.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:10c7674f81e6e350fcbed7c09a65bca9356eaab27fb2dac65a1e440f2bcfe30f"}, - {file = "tiktoken-0.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:084cec29713bc9d4189a937f8a35dbdfa785bd1235a34c1124fe2323821ee93f"}, - {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:811229fde1652fedcca7c6dfe76724d0908775b353556d8a71ed74d866f73f7b"}, - {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b6e7dc2e7ad1b3757e8a24597415bafcfb454cebf9a33a01f2e6ba2e663992"}, - {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1063c5748be36344c7e18c7913c53e2cca116764c2080177e57d62c7ad4576d1"}, - {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:20295d21419bfcca092644f7e2f2138ff947a6eb8cfc732c09cc7d76988d4a89"}, - {file = "tiktoken-0.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:959d993749b083acc57a317cbc643fb85c014d055b2119b739487288f4e5d1cb"}, - {file = "tiktoken-0.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:71c55d066388c55a9c00f61d2c456a6086673ab7dec22dd739c23f77195b1908"}, - {file = "tiktoken-0.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09ed925bccaa8043e34c519fbb2f99110bd07c6fd67714793c21ac298e449410"}, - {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03c6c40ff1db0f48a7b4d2dafeae73a5607aacb472fa11f125e7baf9dce73704"}, - {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20b5c6af30e621b4aca094ee61777a44118f52d886dbe4f02b70dfe05c15350"}, - {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d427614c3e074004efa2f2411e16c826f9df427d3c70a54725cae860f09e4bf4"}, - {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8c46d7af7b8c6987fac9b9f61041b452afe92eb087d29c9ce54951280f899a97"}, - {file = "tiktoken-0.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:0bc603c30b9e371e7c4c7935aba02af5994a909fc3c0fe66e7004070858d3f8f"}, - {file = "tiktoken-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2398fecd38c921bcd68418675a6d155fad5f5e14c2e92fcf5fe566fa5485a858"}, - {file = "tiktoken-0.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8f5f6afb52fb8a7ea1c811e435e4188f2bef81b5e0f7a8635cc79b0eef0193d6"}, - {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:861f9ee616766d736be4147abac500732b505bf7013cfaf019b85892637f235e"}, - {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54031f95c6939f6b78122c0aa03a93273a96365103793a22e1793ee86da31685"}, - {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:fffdcb319b614cf14f04d02a52e26b1d1ae14a570f90e9b55461a72672f7b13d"}, - {file = 
"tiktoken-0.7.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c72baaeaefa03ff9ba9688624143c858d1f6b755bb85d456d59e529e17234769"}, - {file = "tiktoken-0.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:131b8aeb043a8f112aad9f46011dced25d62629091e51d9dc1adbf4a1cc6aa98"}, - {file = "tiktoken-0.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cabc6dc77460df44ec5b879e68692c63551ae4fae7460dd4ff17181df75f1db7"}, - {file = "tiktoken-0.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8d57f29171255f74c0aeacd0651e29aa47dff6f070cb9f35ebc14c82278f3b25"}, - {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ee92776fdbb3efa02a83f968c19d4997a55c8e9ce7be821ceee04a1d1ee149c"}, - {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e215292e99cb41fbc96988ef62ea63bb0ce1e15f2c147a61acc319f8b4cbe5bf"}, - {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a81bac94769cab437dd3ab0b8a4bc4e0f9cf6835bcaa88de71f39af1791727a"}, - {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d6d73ea93e91d5ca771256dfc9d1d29f5a554b83821a1dc0891987636e0ae226"}, - {file = "tiktoken-0.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:2bcb28ddf79ffa424f171dfeef9a4daff61a94c631ca6813f43967cb263b83b9"}, - {file = "tiktoken-0.7.0.tar.gz", hash = "sha256:1077266e949c24e0291f6c350433c6f0971365ece2b173a23bc3b9f9defef6b6"}, + {file = "tiktoken-0.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b07e33283463089c81ef1467180e3e00ab00d46c2c4bbcef0acab5f771d6695e"}, + {file = "tiktoken-0.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9269348cb650726f44dd3bbb3f9110ac19a8dcc8f54949ad3ef652ca22a38e21"}, + {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e13f37bc4ef2d012731e93e0fef21dc3b7aea5bb9009618de9a4026844e560"}, + {file = "tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f13d13c981511331eac0d01a59b5df7c0d4060a8be1e378672822213da51e0a2"}, + {file = "tiktoken-0.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:6b2ddbc79a22621ce8b1166afa9f9a888a664a579350dc7c09346a3b5de837d9"}, + {file = "tiktoken-0.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:d8c2d0e5ba6453a290b86cd65fc51fedf247e1ba170191715b049dac1f628005"}, + {file = "tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1"}, + {file = "tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a"}, + {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d"}, + {file = "tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47"}, + {file = "tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419"}, + {file = "tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99"}, + {file = "tiktoken-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:881839cfeae051b3628d9823b2e56b5cc93a9e2efb435f4cf15f17dc45f21586"}, + {file = "tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:fe9399bdc3f29d428f16a2f86c3c8ec20be3eac5f53693ce4980371c3245729b"}, + {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a58deb7075d5b69237a3ff4bb51a726670419db6ea62bdcd8bd80c78497d7ab"}, + {file = "tiktoken-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2908c0d043a7d03ebd80347266b0e58440bdef5564f84f4d29fb235b5df3b04"}, + {file = "tiktoken-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:294440d21a2a51e12d4238e68a5972095534fe9878be57d905c476017bff99fc"}, + {file = "tiktoken-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d8f3192733ac4d77977432947d563d7e1b310b96497acd3c196c9bddb36ed9db"}, + {file = "tiktoken-0.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:02be1666096aff7da6cbd7cdaa8e7917bfed3467cd64b38b1f112e96d3b06a24"}, + {file = "tiktoken-0.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c94ff53c5c74b535b2cbf431d907fc13c678bbd009ee633a2aca269a04389f9a"}, + {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b231f5e8982c245ee3065cd84a4712d64692348bc609d84467c57b4b72dcbc5"}, + {file = "tiktoken-0.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4177faa809bd55f699e88c96d9bb4635d22e3f59d635ba6fd9ffedf7150b9953"}, + {file = "tiktoken-0.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5376b6f8dc4753cd81ead935c5f518fa0fbe7e133d9e25f648d8c4dabdd4bad7"}, + {file = "tiktoken-0.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:18228d624807d66c87acd8f25fc135665617cab220671eb65b50f5d70fa51f69"}, + {file = "tiktoken-0.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7e17807445f0cf1f25771c9d86496bd8b5c376f7419912519699f3cc4dc5c12e"}, + {file = "tiktoken-0.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:886f80bd339578bbdba6ed6d0567a0d5c6cfe198d9e587ba6c447654c65b8edc"}, + {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6adc8323016d7758d6de7313527f755b0fc6c72985b7d9291be5d96d73ecd1e1"}, + {file = "tiktoken-0.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b591fb2b30d6a72121a80be24ec7a0e9eb51c5500ddc7e4c2496516dd5e3816b"}, + {file = "tiktoken-0.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:845287b9798e476b4d762c3ebda5102be87ca26e5d2c9854002825d60cdb815d"}, + {file = "tiktoken-0.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:1473cfe584252dc3fa62adceb5b1c763c1874e04511b197da4e6de51d6ce5a02"}, + {file = "tiktoken-0.8.0.tar.gz", hash = "sha256:9ccbb2740f24542534369c5635cfd9b2b3c2490754a78ac8831d99f89f94eeb2"}, ] [package.dependencies] @@ -8290,24 +9607,41 @@ files = [ [[package]] name = "tomli" -version = "2.0.1" +version = "2.0.2" description = "A lil' TOML parser" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" +files = [ + {file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"}, + {file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"}, +] + +[[package]] +name = "tos" +version = "2.7.2" +description = "Volc TOS (Tinder Object Storage) SDK" +optional = false +python-versions = "*" files = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, + {file = "tos-2.7.2.tar.gz", 
hash = "sha256:3c31257716785bca7b2cac51474ff32543cda94075a7b7aff70d769c15c7b7ed"}, ] +[package.dependencies] +crcmod = ">=1.7" +Deprecated = ">=1.2.13,<2.0.0" +pytz = "*" +requests = ">=2.19.1,<3.dev0" +six = "*" + [[package]] name = "tqdm" -version = "4.66.5" +version = "4.66.6" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, - {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"}, + {file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"}, + {file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"}, ] [package.dependencies] @@ -8406,13 +9740,13 @@ requests = ">=2.0.0" [[package]] name = "typer" -version = "0.12.3" +version = "0.12.5" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." optional = false python-versions = ">=3.7" files = [ - {file = "typer-0.12.3-py3-none-any.whl", hash = "sha256:070d7ca53f785acbccba8e7d28b08dcd88f79f1fbda035ade0aecec71ca5c914"}, - {file = "typer-0.12.3.tar.gz", hash = "sha256:49e73131481d804288ef62598d97a1ceef3058905aa536a1134f90891ba35482"}, + {file = "typer-0.12.5-py3-none-any.whl", hash = "sha256:62fe4e471711b147e3365034133904df3e235698399bc4de2b36c8579298d52b"}, + {file = "typer-0.12.5.tar.gz", hash = "sha256:f592f089bedcc8ec1b974125d64851029c3b1af145f04aca64d69410f0c9b722"}, ] [package.dependencies] @@ -8423,13 +9757,13 @@ typing-extensions = ">=3.7.4.3" [[package]] name = "types-requests" -version = "2.32.0.20240712" +version = "2.32.0.20241016" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" files = [ - {file = "types-requests-2.32.0.20240712.tar.gz", hash = "sha256:90c079ff05e549f6bf50e02e910210b98b8ff1ebdd18e19c873cd237737c1358"}, - {file = "types_requests-2.32.0.20240712-py3-none-any.whl", hash = "sha256:f754283e152c752e46e70942fa2a146b5bc70393522257bb85bd1ef7e019dcc3"}, + {file = "types-requests-2.32.0.20241016.tar.gz", hash = "sha256:0d9cad2f27515d0e3e3da7134a1b6f28fb97129d86b867f24d9c726452634d95"}, + {file = "types_requests-2.32.0.20241016-py3-none-any.whl", hash = "sha256:4195d62d6d3e043a4eaaf08ff8a62184584d2e8684e9d2aa178c7915a7da3747"}, ] [package.dependencies] @@ -8463,32 +9797,15 @@ typing-extensions = ">=3.7.4" [[package]] name = "tzdata" -version = "2024.1" +version = "2024.2" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" files = [ - {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, - {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, -] - -[[package]] -name = "tzlocal" -version = "5.2" -description = "tzinfo object for the local timezone" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tzlocal-5.2-py3-none-any.whl", hash = "sha256:49816ef2fe65ea8ac19d19aa7a1ae0551c834303d5014c6d5a62e4cbda8047b8"}, - {file = "tzlocal-5.2.tar.gz", hash = "sha256:8d399205578f1a9342816409cc1e46a93ebd5755e39ea2d85334bea911bf0e6e"}, + {file = "tzdata-2024.2-py2.py3-none-any.whl", hash = "sha256:a48093786cdcde33cad18c2555e8532f34422074448fbc874186f0abd79565cd"}, + {file = "tzdata-2024.2.tar.gz", hash = 
"sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, ] -[package.dependencies] -tzdata = {version = "*", markers = "platform_system == \"Windows\""} - -[package.extras] -devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"] - [[package]] name = "ujson" version = "5.10.0" @@ -8578,13 +9895,13 @@ files = [ [[package]] name = "unstructured" -version = "0.10.30" +version = "0.16.3" description = "A library that prepares raw documents for downstream ML tasks." optional = false -python-versions = ">=3.7.0" +python-versions = "<3.13,>=3.9.0" files = [ - {file = "unstructured-0.10.30-py3-none-any.whl", hash = "sha256:0615f14daa37450e9c0fcf3c3fd178c3a06b6b8d006a36d1a5e54dbe487aa6b6"}, - {file = "unstructured-0.10.30.tar.gz", hash = "sha256:a86c3d15c572a28322d83cb5ecf0ac7a24f1c36864fb7c68df096de8a1acc106"}, + {file = "unstructured-0.16.3-py3-none-any.whl", hash = "sha256:e0e3b56531b44e62154d17cbfdae7fd7fa1d795b7cf510fb654c6714d4257655"}, + {file = "unstructured-0.16.3.tar.gz", hash = "sha256:f9528636773c910a53c8a34e32d4733ea54b79cbd507d0e956e299ab1da3003f"}, ] [package.dependencies] @@ -8594,71 +9911,84 @@ chardet = "*" dataclasses-json = "*" emoji = "*" filetype = "*" +html5lib = "*" langdetect = "*" lxml = "*" markdown = {version = "*", optional = true, markers = "extra == \"md\""} -msg-parser = {version = "*", optional = true, markers = "extra == \"msg\""} nltk = "*" -numpy = "*" +numpy = "<2" +psutil = "*" pypandoc = {version = "*", optional = true, markers = "extra == \"epub\""} -python-docx = {version = ">=1.1.0", optional = true, markers = "extra == \"docx\""} +python-docx = {version = ">=1.1.2", optional = true, markers = "extra == \"docx\""} python-iso639 = "*" python-magic = "*" -python-pptx = {version = "<=0.6.23", optional = true, markers = "extra == \"ppt\" or extra == \"pptx\""} +python-oxmsg = "*" +python-pptx = {version = ">=1.0.1", optional = true, markers = "extra == \"ppt\" or extra == \"pptx\""} rapidfuzz = "*" requests = "*" -tabulate = "*" +tqdm = "*" typing-extensions = "*" +unstructured-client = "*" +wrapt = "*" [package.extras] -airtable = ["pyairtable"] -all-docs = ["markdown", "msg-parser", "networkx", "onnx", "openpyxl", "pandas", "pdf2image", "pdfminer.six", "pypandoc", "python-docx (>=1.1.0)", "python-pptx (<=0.6.23)", "unstructured-inference (==0.7.11)", "unstructured.pytesseract (>=0.3.12)", "xlrd"] -azure = ["adlfs", "fsspec (==2023.9.1)"] -azure-cognitive-search = ["azure-search-documents"] -bedrock = ["boto3", "langchain"] -biomed = ["bs4"] -box = ["boxfs", "fsspec (==2023.9.1)"] -confluence = ["atlassian-python-api"] +all-docs = ["effdet", "google-cloud-vision", "markdown", "networkx", "onnx", "openpyxl", "pandas", "pdf2image", "pdfminer.six", "pi-heif", "pikepdf", "pypandoc", "pypdf", "python-docx (>=1.1.2)", "python-pptx (>=1.0.1)", "unstructured-inference (==0.8.1)", "unstructured.pytesseract (>=0.3.12)", "xlrd"] csv = ["pandas"] -delta-table = ["deltalake", "fsspec (==2023.9.1)"] -discord = ["discord-py"] -doc = ["python-docx (>=1.1.0)"] -docx = ["python-docx (>=1.1.0)"] -dropbox = ["dropboxdrivefs", "fsspec (==2023.9.1)"] -elasticsearch = ["elasticsearch", "jq"] -embed-huggingface = ["huggingface", "langchain", "sentence-transformers"] +doc = ["python-docx (>=1.1.2)"] +docx = ["python-docx (>=1.1.2)"] epub = ["pypandoc"] -gcs = ["bs4", "fsspec (==2023.9.1)", "gcsfs"] -github = ["pygithub (>1.58.0)"] -gitlab = ["python-gitlab"] -google-drive = ["google-api-python-client"] huggingface = 
["langdetect", "sacremoses", "sentencepiece", "torch", "transformers"] -image = ["onnx", "pdf2image", "pdfminer.six", "unstructured-inference (==0.7.11)", "unstructured.pytesseract (>=0.3.12)"] -jira = ["atlassian-python-api"] -local-inference = ["markdown", "msg-parser", "networkx", "onnx", "openpyxl", "pandas", "pdf2image", "pdfminer.six", "pypandoc", "python-docx (>=1.1.0)", "python-pptx (<=0.6.23)", "unstructured-inference (==0.7.11)", "unstructured.pytesseract (>=0.3.12)", "xlrd"] +image = ["effdet", "google-cloud-vision", "onnx", "pdf2image", "pdfminer.six", "pi-heif", "pikepdf", "pypdf", "unstructured-inference (==0.8.1)", "unstructured.pytesseract (>=0.3.12)"] +local-inference = ["effdet", "google-cloud-vision", "markdown", "networkx", "onnx", "openpyxl", "pandas", "pdf2image", "pdfminer.six", "pi-heif", "pikepdf", "pypandoc", "pypdf", "python-docx (>=1.1.2)", "python-pptx (>=1.0.1)", "unstructured-inference (==0.8.1)", "unstructured.pytesseract (>=0.3.12)", "xlrd"] md = ["markdown"] -msg = ["msg-parser"] -notion = ["htmlBuilder", "notion-client"] -odt = ["pypandoc", "python-docx (>=1.1.0)"] -onedrive = ["Office365-REST-Python-Client (<2.4.3)", "bs4", "msal"] -openai = ["langchain", "openai", "tiktoken"] +odt = ["pypandoc", "python-docx (>=1.1.2)"] org = ["pypandoc"] -outlook = ["Office365-REST-Python-Client (<2.4.3)", "msal"] -paddleocr = ["unstructured.paddleocr (==2.6.1.3)"] -pdf = ["onnx", "pdf2image", "pdfminer.six", "unstructured-inference (==0.7.11)", "unstructured.pytesseract (>=0.3.12)"] -ppt = ["python-pptx (<=0.6.23)"] -pptx = ["python-pptx (<=0.6.23)"] -reddit = ["praw"] +paddleocr = ["paddlepaddle (==3.0.0b1)", "unstructured.paddleocr (==2.8.1.0)"] +pdf = ["effdet", "google-cloud-vision", "onnx", "pdf2image", "pdfminer.six", "pi-heif", "pikepdf", "pypdf", "unstructured-inference (==0.8.1)", "unstructured.pytesseract (>=0.3.12)"] +ppt = ["python-pptx (>=1.0.1)"] +pptx = ["python-pptx (>=1.0.1)"] rst = ["pypandoc"] rtf = ["pypandoc"] -s3 = ["fsspec (==2023.9.1)", "s3fs"] -salesforce = ["simple-salesforce"] -sharepoint = ["Office365-REST-Python-Client (<2.4.3)", "msal"] -slack = ["slack-sdk"] tsv = ["pandas"] -wikipedia = ["wikipedia"] xlsx = ["networkx", "openpyxl", "pandas", "xlrd"] +[[package]] +name = "unstructured-client" +version = "0.26.2" +description = "Python Client SDK for Unstructured API" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "unstructured_client-0.26.2-py3-none-any.whl", hash = "sha256:0adb22b7d175814f333ee2425a279005f253220a55f459fd5830a6779b679780"}, + {file = "unstructured_client-0.26.2.tar.gz", hash = "sha256:02f7183ab16db6ec48ad1ac75c01b05967c87c561a89e96d9ffb836baed902d7"}, +] + +[package.dependencies] +cryptography = ">=3.1" +eval-type-backport = ">=0.2.0,<0.3.0" +httpx = ">=0.27.0" +jsonpath-python = ">=1.0.6,<2.0.0" +nest-asyncio = ">=1.6.0" +pydantic = ">=2.9.0,<2.10.0" +pypdf = ">=4.0" +python-dateutil = "2.8.2" +requests-toolbelt = ">=1.0.0" +typing-inspect = ">=0.9.0,<0.10.0" + +[[package]] +name = "upstash-vector" +version = "0.6.0" +description = "Serverless Vector SDK from Upstash" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "upstash_vector-0.6.0-py3-none-any.whl", hash = "sha256:d0bdad7765b8a7f5c205b7a9c81ca4b9a4cee3ee4952afc7d5ea5fb76c3f3c3c"}, + {file = "upstash_vector-0.6.0.tar.gz", hash = "sha256:a716ed4d0251362208518db8b194158a616d37d1ccbb1155f619df690599e39b"}, +] + +[package.dependencies] +httpx = ">=0.23.0,<1" + [[package]] name = "uritemplate" version = "4.1.1" @@ -8672,13 
+10002,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.2" +version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, - {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, + {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, + {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, ] [package.extras] @@ -8689,13 +10019,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.30.5" +version = "0.32.0" description = "The lightning-fast ASGI server." optional = false python-versions = ">=3.8" files = [ - {file = "uvicorn-0.30.5-py3-none-any.whl", hash = "sha256:b2d86de274726e9878188fa07576c9ceeff90a839e2b6e25c917fe05f5a6c835"}, - {file = "uvicorn-0.30.5.tar.gz", hash = "sha256:ac6fdbd4425c5fd17a9fe39daf4d4d075da6fdc80f653e5894cdc2fd98752bee"}, + {file = "uvicorn-0.32.0-py3-none-any.whl", hash = "sha256:60b8f3a5ac027dcd31448f411ced12b5ef452c646f76f02f8cc3f25d8d26fd82"}, + {file = "uvicorn-0.32.0.tar.gz", hash = "sha256:f78b36b143c16f54ccdb8190d0a26b5f1901fe5a3c777e1ab29f26391af8551e"}, ] [package.dependencies] @@ -8715,47 +10045,54 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", [[package]] name = "uvloop" -version = "0.19.0" +version = "0.21.0" description = "Fast implementation of asyncio event loop on top of libuv" optional = false python-versions = ">=3.8.0" files = [ - {file = "uvloop-0.19.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:de4313d7f575474c8f5a12e163f6d89c0a878bc49219641d49e6f1444369a90e"}, - {file = "uvloop-0.19.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5588bd21cf1fcf06bded085f37e43ce0e00424197e7c10e77afd4bbefffef428"}, - {file = "uvloop-0.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b1fd71c3843327f3bbc3237bedcdb6504fd50368ab3e04d0410e52ec293f5b8"}, - {file = "uvloop-0.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a05128d315e2912791de6088c34136bfcdd0c7cbc1cf85fd6fd1bb321b7c849"}, - {file = "uvloop-0.19.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cd81bdc2b8219cb4b2556eea39d2e36bfa375a2dd021404f90a62e44efaaf957"}, - {file = "uvloop-0.19.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5f17766fb6da94135526273080f3455a112f82570b2ee5daa64d682387fe0dcd"}, - {file = "uvloop-0.19.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4ce6b0af8f2729a02a5d1575feacb2a94fc7b2e983868b009d51c9a9d2149bef"}, - {file = "uvloop-0.19.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:31e672bb38b45abc4f26e273be83b72a0d28d074d5b370fc4dcf4c4eb15417d2"}, - {file = "uvloop-0.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:570fc0ed613883d8d30ee40397b79207eedd2624891692471808a95069a007c1"}, - {file = "uvloop-0.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5138821e40b0c3e6c9478643b4660bd44372ae1e16a322b8fc07478f92684e24"}, - {file = "uvloop-0.19.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:91ab01c6cd00e39cde50173ba4ec68a1e578fee9279ba64f5221810a9e786533"}, - {file = 
"uvloop-0.19.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:47bf3e9312f63684efe283f7342afb414eea4d3011542155c7e625cd799c3b12"}, - {file = "uvloop-0.19.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:da8435a3bd498419ee8c13c34b89b5005130a476bda1d6ca8cfdde3de35cd650"}, - {file = "uvloop-0.19.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:02506dc23a5d90e04d4f65c7791e65cf44bd91b37f24cfc3ef6cf2aff05dc7ec"}, - {file = "uvloop-0.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2693049be9d36fef81741fddb3f441673ba12a34a704e7b4361efb75cf30befc"}, - {file = "uvloop-0.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7010271303961c6f0fe37731004335401eb9075a12680738731e9c92ddd96ad6"}, - {file = "uvloop-0.19.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:5daa304d2161d2918fa9a17d5635099a2f78ae5b5960e742b2fcfbb7aefaa593"}, - {file = "uvloop-0.19.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:7207272c9520203fea9b93843bb775d03e1cf88a80a936ce760f60bb5add92f3"}, - {file = "uvloop-0.19.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:78ab247f0b5671cc887c31d33f9b3abfb88d2614b84e4303f1a63b46c046c8bd"}, - {file = "uvloop-0.19.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:472d61143059c84947aa8bb74eabbace30d577a03a1805b77933d6bd13ddebbd"}, - {file = "uvloop-0.19.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45bf4c24c19fb8a50902ae37c5de50da81de4922af65baf760f7c0c42e1088be"}, - {file = "uvloop-0.19.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:271718e26b3e17906b28b67314c45d19106112067205119dddbd834c2b7ce797"}, - {file = "uvloop-0.19.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:34175c9fd2a4bc3adc1380e1261f60306344e3407c20a4d684fd5f3be010fa3d"}, - {file = "uvloop-0.19.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e27f100e1ff17f6feeb1f33968bc185bf8ce41ca557deee9d9bbbffeb72030b7"}, - {file = "uvloop-0.19.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13dfdf492af0aa0a0edf66807d2b465607d11c4fa48f4a1fd41cbea5b18e8e8b"}, - {file = "uvloop-0.19.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e3d4e85ac060e2342ff85e90d0c04157acb210b9ce508e784a944f852a40e67"}, - {file = "uvloop-0.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ca4956c9ab567d87d59d49fa3704cf29e37109ad348f2d5223c9bf761a332e7"}, - {file = "uvloop-0.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f467a5fd23b4fc43ed86342641f3936a68ded707f4627622fa3f82a120e18256"}, - {file = "uvloop-0.19.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:492e2c32c2af3f971473bc22f086513cedfc66a130756145a931a90c3958cb17"}, - {file = "uvloop-0.19.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2df95fca285a9f5bfe730e51945ffe2fa71ccbfdde3b0da5772b4ee4f2e770d5"}, - {file = "uvloop-0.19.0.tar.gz", hash = "sha256:0246f4fd1bf2bf702e06b0d45ee91677ee5c31242f39aab4ea6fe0c51aedd0fd"}, -] - -[package.extras] + {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f"}, + {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:196274f2adb9689a289ad7d65700d37df0c0930fd8e4e743fa4834e850d7719d"}, + {file = "uvloop-0.21.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f38b2e090258d051d68a5b14d1da7203a3c3677321cf32a95a6f4db4dd8b6f26"}, + {file = 
"uvloop-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87c43e0f13022b998eb9b973b5e97200c8b90823454d4bc06ab33829e09fb9bb"}, + {file = "uvloop-0.21.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:10d66943def5fcb6e7b37310eb6b5639fd2ccbc38df1177262b0640c3ca68c1f"}, + {file = "uvloop-0.21.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:67dd654b8ca23aed0a8e99010b4c34aca62f4b7fce88f39d452ed7622c94845c"}, + {file = "uvloop-0.21.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c0f3fa6200b3108919f8bdabb9a7f87f20e7097ea3c543754cabc7d717d95cf8"}, + {file = "uvloop-0.21.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0878c2640cf341b269b7e128b1a5fed890adc4455513ca710d77d5e93aa6d6a0"}, + {file = "uvloop-0.21.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9fb766bb57b7388745d8bcc53a359b116b8a04c83a2288069809d2b3466c37e"}, + {file = "uvloop-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a375441696e2eda1c43c44ccb66e04d61ceeffcd76e4929e527b7fa401b90fb"}, + {file = "uvloop-0.21.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:baa0e6291d91649c6ba4ed4b2f982f9fa165b5bbd50a9e203c416a2797bab3c6"}, + {file = "uvloop-0.21.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4509360fcc4c3bd2c70d87573ad472de40c13387f5fda8cb58350a1d7475e58d"}, + {file = "uvloop-0.21.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:359ec2c888397b9e592a889c4d72ba3d6befba8b2bb01743f72fffbde663b59c"}, + {file = "uvloop-0.21.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7089d2dc73179ce5ac255bdf37c236a9f914b264825fdaacaded6990a7fb4c2"}, + {file = "uvloop-0.21.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baa4dcdbd9ae0a372f2167a207cd98c9f9a1ea1188a8a526431eef2f8116cc8d"}, + {file = "uvloop-0.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86975dca1c773a2c9864f4c52c5a55631038e387b47eaf56210f873887b6c8dc"}, + {file = "uvloop-0.21.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:461d9ae6660fbbafedd07559c6a2e57cd553b34b0065b6550685f6653a98c1cb"}, + {file = "uvloop-0.21.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:183aef7c8730e54c9a3ee3227464daed66e37ba13040bb3f350bc2ddc040f22f"}, + {file = "uvloop-0.21.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:bfd55dfcc2a512316e65f16e503e9e450cab148ef11df4e4e679b5e8253a5281"}, + {file = "uvloop-0.21.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:787ae31ad8a2856fc4e7c095341cccc7209bd657d0e71ad0dc2ea83c4a6fa8af"}, + {file = "uvloop-0.21.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ee4d4ef48036ff6e5cfffb09dd192c7a5027153948d85b8da7ff705065bacc6"}, + {file = "uvloop-0.21.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3df876acd7ec037a3d005b3ab85a7e4110422e4d9c1571d4fc89b0fc41b6816"}, + {file = "uvloop-0.21.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd53ecc9a0f3d87ab847503c2e1552b690362e005ab54e8a48ba97da3924c0dc"}, + {file = "uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553"}, + {file = "uvloop-0.21.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:17df489689befc72c39a08359efac29bbee8eee5209650d4b9f34df73d22e414"}, + {file = "uvloop-0.21.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bc09f0ff191e61c2d592a752423c767b4ebb2986daa9ed62908e2b1b9a9ae206"}, + {file = 
"uvloop-0.21.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0ce1b49560b1d2d8a2977e3ba4afb2414fb46b86a1b64056bc4ab929efdafbe"}, + {file = "uvloop-0.21.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e678ad6fe52af2c58d2ae3c73dc85524ba8abe637f134bf3564ed07f555c5e79"}, + {file = "uvloop-0.21.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:460def4412e473896ef179a1671b40c039c7012184b627898eea5072ef6f017a"}, + {file = "uvloop-0.21.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:10da8046cc4a8f12c91a1c39d1dd1585c41162a15caaef165c2174db9ef18bdc"}, + {file = "uvloop-0.21.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c097078b8031190c934ed0ebfee8cc5f9ba9642e6eb88322b9958b649750f72b"}, + {file = "uvloop-0.21.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:46923b0b5ee7fc0020bef24afe7836cb068f5050ca04caf6b487c513dc1a20b2"}, + {file = "uvloop-0.21.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53e420a3afe22cdcf2a0f4846e377d16e718bc70103d7088a4f7623567ba5fb0"}, + {file = "uvloop-0.21.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88cb67cdbc0e483da00af0b2c3cdad4b7c61ceb1ee0f33fe00e09c81e3a6cb75"}, + {file = "uvloop-0.21.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:221f4f2a1f46032b403bf3be628011caf75428ee3cc204a22addf96f586b19fd"}, + {file = "uvloop-0.21.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2d1f581393673ce119355d56da84fe1dd9d2bb8b3d13ce792524e1607139feff"}, + {file = "uvloop-0.21.0.tar.gz", hash = "sha256:3bf12b0fda68447806a7ad847bfa591613177275d35b6724b1ee573faa3704e3"}, +] + +[package.extras] +dev = ["Cython (>=3.0,<4.0)", "setuptools (>=60)"] docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] -test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"] +test = ["aiohttp (>=3.10.5)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"] [[package]] name = "validators" @@ -8770,19 +10107,20 @@ files = [ [[package]] name = "vanna" -version = "0.5.5" +version = "0.7.5" description = "Generate SQL queries from natural language" optional = false python-versions = ">=3.9" files = [ - {file = "vanna-0.5.5-py3-none-any.whl", hash = "sha256:e1a308b7127b9e98c2579c0e4178fc1475d891c498e4a0667cffa10df8891e73"}, - {file = "vanna-0.5.5.tar.gz", hash = "sha256:7d9bf188a635bb75e4f8db15f0e6dbe72a426784779485f087b2df0ce175e664"}, + {file = "vanna-0.7.5-py3-none-any.whl", hash = "sha256:07458c7befa49de517a8760c2d80a13147278b484c515d49a906acc88edcb835"}, + {file = "vanna-0.7.5.tar.gz", hash = "sha256:2fdffc58832898e4fc8e93c45b173424db59a22773b22ca348640161d391eacf"}, ] [package.dependencies] -clickhouse_driver = {version = "*", optional = true, markers = "extra == \"clickhouse\""} +clickhouse_connect = {version = "*", optional = true, markers = "extra == \"clickhouse\""} db-dtypes = {version = "*", optional = true, markers = "extra == \"postgres\""} duckdb = {version = "*", optional = true, markers = "extra == \"duckdb\""} +flasgger = "*" flask = "*" flask-sock = "*" kaleido = "*" @@ -8796,27 +10134,36 @@ sqlparse = "*" tabulate = "*" [package.extras] -all = ["PyMySQL", "anthropic", "chromadb", "db-dtypes", "duckdb", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", 
"google-generativeai", "httpx", "marqo", "mistralai", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "qdrant-client", "snowflake-connector-python", "transformers", "zhipuai"] +all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "boto", "boto3", "botocore", "chromadb", "db-dtypes", "duckdb", "faiss-cpu", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "langchain_core", "langchain_postgres", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "xinference-client", "zhipuai"] anthropic = ["anthropic"] +azuresearch = ["azure-common", "azure-identity", "azure-search-documents", "fastembed"] +bedrock = ["boto3", "botocore"] bigquery = ["google-cloud-bigquery"] chromadb = ["chromadb"] -clickhouse = ["clickhouse_driver"] +clickhouse = ["clickhouse_connect"] duckdb = ["duckdb"] +faiss-cpu = ["faiss-cpu"] +faiss-gpu = ["faiss-gpu"] gemini = ["google-generativeai"] google = ["google-cloud-aiplatform", "google-generativeai"] hf = ["transformers"] marqo = ["marqo"] -mistralai = ["mistralai"] +milvus = ["pymilvus[model]"] +mistralai = ["mistralai (>=1.0.0)"] mysql = ["PyMySQL"] ollama = ["httpx", "ollama"] openai = ["openai"] opensearch = ["opensearch-dsl", "opensearch-py"] +pgvector = ["langchain-postgres (>=0.0.12)"] pinecone = ["fastembed", "pinecone-client"] postgres = ["db-dtypes", "psycopg2-binary"] qdrant = ["fastembed", "qdrant-client"] +qianfan = ["qianfan"] snowflake = ["snowflake-connector-python"] test = ["tox"] vllm = ["vllm"] +weaviate = ["weaviate-client"] +xinference-client = ["xinference-client"] zhipuai = ["zhipuai"] [[package]] @@ -8830,88 +10177,138 @@ files = [ {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"}, ] +[[package]] +name = "volcengine-compat" +version = "1.0.156" +description = "Be Compatible with the Volcengine SDK for Python, The version of package dependencies has been modified. like pycryptodome, pytz." 
+optional = false
+python-versions = "*"
+files = [
+ {file = "volcengine_compat-1.0.156-py3-none-any.whl", hash = "sha256:4abc149a7601ebad8fa2d28fab50c7945145cf74daecb71bca797b0bdc82c5a5"},
+ {file = "volcengine_compat-1.0.156.tar.gz", hash = "sha256:e357d096828e31a202dc6047bbc5bf6fff3f54a98cd35a99ab5f965ea741a267"},
+]
+
+[package.dependencies]
+google = ">=3.0.0"
+protobuf = ">=3.18.3"
+pycryptodome = ">=3.9.9"
+pytz = ">=2020.5"
+requests = ">=2.25.1"
+retry = ">=0.9.2"
+six = ">=1.0"
+
+[[package]]
+name = "volcengine-python-sdk"
+version = "1.0.103"
+description = "Volcengine SDK for Python"
+optional = false
+python-versions = "*"
+files = [
+ {file = "volcengine-python-sdk-1.0.103.tar.gz", hash = "sha256:49fa8572802724972e1cb47a7e692b184b055f41b09099358c1a0fad1d146af5"},
+]
+
+[package.dependencies]
+anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""}
+certifi = ">=2017.4.17"
+httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""}
+pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""}
+python-dateutil = ">=2.1"
+six = ">=1.10"
+urllib3 = ">=1.23"
+
+[package.extras]
+ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"]
+
[[package]]
name = "watchfiles"
-version = "0.22.0"
+version = "0.24.0"
description = "Simple, modern and high performance file watching and code reload in python."
optional = false
python-versions = ">=3.8"
files = [
- {file = "watchfiles-0.22.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:da1e0a8caebf17976e2ffd00fa15f258e14749db5e014660f53114b676e68538"},
- {file = "watchfiles-0.22.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:61af9efa0733dc4ca462347becb82e8ef4945aba5135b1638bfc20fad64d4f0e"},
- {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d9188979a58a096b6f8090e816ccc3f255f137a009dd4bbec628e27696d67c1"},
- {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2bdadf6b90c099ca079d468f976fd50062905d61fae183f769637cb0f68ba59a"},
- {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:067dea90c43bf837d41e72e546196e674f68c23702d3ef80e4e816937b0a3ffd"},
- {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbf8a20266136507abf88b0df2328e6a9a7c7309e8daff124dda3803306a9fdb"},
- {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1235c11510ea557fe21be5d0e354bae2c655a8ee6519c94617fe63e05bca4171"},
- {file = "watchfiles-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2444dc7cb9d8cc5ab88ebe792a8d75709d96eeef47f4c8fccb6df7c7bc5be71"},
- {file = "watchfiles-0.22.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c5af2347d17ab0bd59366db8752d9e037982e259cacb2ba06f2c41c08af02c39"},
- {file = "watchfiles-0.22.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9624a68b96c878c10437199d9a8b7d7e542feddda8d5ecff58fdc8e67b460848"},
- {file = "watchfiles-0.22.0-cp310-none-win32.whl", hash = "sha256:4b9f2a128a32a2c273d63eb1fdbf49ad64852fc38d15b34eaa3f7ca2f0d2b797"},
- {file = "watchfiles-0.22.0-cp310-none-win_amd64.whl", hash = "sha256:2627a91e8110b8de2406d8b2474427c86f5a62bf7d9ab3654f541f319ef22bcb"},
- {file = "watchfiles-0.22.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8c39987a1397a877217be1ac0fb1d8b9f662c6077b90ff3de2c05f235e6a8f96"},
-
{file = "watchfiles-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a927b3034d0672f62fb2ef7ea3c9fc76d063c4b15ea852d1db2dc75fe2c09696"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:052d668a167e9fc345c24203b104c313c86654dd6c0feb4b8a6dfc2462239249"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e45fb0d70dda1623a7045bd00c9e036e6f1f6a85e4ef2c8ae602b1dfadf7550"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c49b76a78c156979759d759339fb62eb0549515acfe4fd18bb151cc07366629c"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4a65474fd2b4c63e2c18ac67a0c6c66b82f4e73e2e4d940f837ed3d2fd9d4da"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1cc0cba54f47c660d9fa3218158b8963c517ed23bd9f45fe463f08262a4adae1"}, - {file = "watchfiles-0.22.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94ebe84a035993bb7668f58a0ebf998174fb723a39e4ef9fce95baabb42b787f"}, - {file = "watchfiles-0.22.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e0f0a874231e2839abbf473256efffe577d6ee2e3bfa5b540479e892e47c172d"}, - {file = "watchfiles-0.22.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:213792c2cd3150b903e6e7884d40660e0bcec4465e00563a5fc03f30ea9c166c"}, - {file = "watchfiles-0.22.0-cp311-none-win32.whl", hash = "sha256:b44b70850f0073b5fcc0b31ede8b4e736860d70e2dbf55701e05d3227a154a67"}, - {file = "watchfiles-0.22.0-cp311-none-win_amd64.whl", hash = "sha256:00f39592cdd124b4ec5ed0b1edfae091567c72c7da1487ae645426d1b0ffcad1"}, - {file = "watchfiles-0.22.0-cp311-none-win_arm64.whl", hash = "sha256:3218a6f908f6a276941422b035b511b6d0d8328edd89a53ae8c65be139073f84"}, - {file = "watchfiles-0.22.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c7b978c384e29d6c7372209cbf421d82286a807bbcdeb315427687f8371c340a"}, - {file = "watchfiles-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd4c06100bce70a20c4b81e599e5886cf504c9532951df65ad1133e508bf20be"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:425440e55cd735386ec7925f64d5dde392e69979d4c8459f6bb4e920210407f2"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:68fe0c4d22332d7ce53ad094622b27e67440dacefbaedd29e0794d26e247280c"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a8a31bfd98f846c3c284ba694c6365620b637debdd36e46e1859c897123aa232"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc2e8fe41f3cac0660197d95216c42910c2b7e9c70d48e6d84e22f577d106fc1"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55b7cc10261c2786c41d9207193a85c1db1b725cf87936df40972aab466179b6"}, - {file = "watchfiles-0.22.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28585744c931576e535860eaf3f2c0ec7deb68e3b9c5a85ca566d69d36d8dd27"}, - {file = "watchfiles-0.22.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:00095dd368f73f8f1c3a7982a9801190cc88a2f3582dd395b289294f8975172b"}, - {file = "watchfiles-0.22.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:52fc9b0dbf54d43301a19b236b4a4614e610605f95e8c3f0f65c3a456ffd7d35"}, - {file = 
"watchfiles-0.22.0-cp312-none-win32.whl", hash = "sha256:581f0a051ba7bafd03e17127735d92f4d286af941dacf94bcf823b101366249e"}, - {file = "watchfiles-0.22.0-cp312-none-win_amd64.whl", hash = "sha256:aec83c3ba24c723eac14225194b862af176d52292d271c98820199110e31141e"}, - {file = "watchfiles-0.22.0-cp312-none-win_arm64.whl", hash = "sha256:c668228833c5619f6618699a2c12be057711b0ea6396aeaece4ded94184304ea"}, - {file = "watchfiles-0.22.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d47e9ef1a94cc7a536039e46738e17cce058ac1593b2eccdede8bf72e45f372a"}, - {file = "watchfiles-0.22.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:28f393c1194b6eaadcdd8f941307fc9bbd7eb567995232c830f6aef38e8a6e88"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd64f3a4db121bc161644c9e10a9acdb836853155a108c2446db2f5ae1778c3d"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2abeb79209630da981f8ebca30a2c84b4c3516a214451bfc5f106723c5f45843"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4cc382083afba7918e32d5ef12321421ef43d685b9a67cc452a6e6e18920890e"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d048ad5d25b363ba1d19f92dcf29023988524bee6f9d952130b316c5802069cb"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:103622865599f8082f03af4214eaff90e2426edff5e8522c8f9e93dc17caee13"}, - {file = "watchfiles-0.22.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3e1f3cf81f1f823e7874ae563457828e940d75573c8fbf0ee66818c8b6a9099"}, - {file = "watchfiles-0.22.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8597b6f9dc410bdafc8bb362dac1cbc9b4684a8310e16b1ff5eee8725d13dcd6"}, - {file = "watchfiles-0.22.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0b04a2cbc30e110303baa6d3ddce8ca3664bc3403be0f0ad513d1843a41c97d1"}, - {file = "watchfiles-0.22.0-cp38-none-win32.whl", hash = "sha256:b610fb5e27825b570554d01cec427b6620ce9bd21ff8ab775fc3a32f28bba63e"}, - {file = "watchfiles-0.22.0-cp38-none-win_amd64.whl", hash = "sha256:fe82d13461418ca5e5a808a9e40f79c1879351fcaeddbede094028e74d836e86"}, - {file = "watchfiles-0.22.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:3973145235a38f73c61474d56ad6199124e7488822f3a4fc97c72009751ae3b0"}, - {file = "watchfiles-0.22.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:280a4afbc607cdfc9571b9904b03a478fc9f08bbeec382d648181c695648202f"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a0d883351a34c01bd53cfa75cd0292e3f7e268bacf2f9e33af4ecede7e21d1d"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9165bcab15f2b6d90eedc5c20a7f8a03156b3773e5fb06a790b54ccecdb73385"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dc1b9b56f051209be458b87edb6856a449ad3f803315d87b2da4c93b43a6fe72"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8dc1fc25a1dedf2dd952909c8e5cb210791e5f2d9bc5e0e8ebc28dd42fed7562"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dc92d2d2706d2b862ce0568b24987eba51e17e14b79a1abcd2edc39e48e743c8"}, - {file = "watchfiles-0.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:97b94e14b88409c58cdf4a8eaf0e67dfd3ece7e9ce7140ea6ff48b0407a593ec"}, - {file = "watchfiles-0.22.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:96eec15e5ea7c0b6eb5bfffe990fc7c6bd833acf7e26704eb18387fb2f5fd087"}, - {file = "watchfiles-0.22.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:28324d6b28bcb8d7c1041648d7b63be07a16db5510bea923fc80b91a2a6cbed6"}, - {file = "watchfiles-0.22.0-cp39-none-win32.whl", hash = "sha256:8c3e3675e6e39dc59b8fe5c914a19d30029e36e9f99468dddffd432d8a7b1c93"}, - {file = "watchfiles-0.22.0-cp39-none-win_amd64.whl", hash = "sha256:25c817ff2a86bc3de3ed2df1703e3d24ce03479b27bb4527c57e722f8554d971"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b810a2c7878cbdecca12feae2c2ae8af59bea016a78bc353c184fa1e09f76b68"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:f7e1f9c5d1160d03b93fc4b68a0aeb82fe25563e12fbcdc8507f8434ab6f823c"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:030bc4e68d14bcad2294ff68c1ed87215fbd9a10d9dea74e7cfe8a17869785ab"}, - {file = "watchfiles-0.22.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace7d060432acde5532e26863e897ee684780337afb775107c0a90ae8dbccfd2"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5834e1f8b71476a26df97d121c0c0ed3549d869124ed2433e02491553cb468c2"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0bc3b2f93a140df6806c8467c7f51ed5e55a931b031b5c2d7ff6132292e803d6"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fdebb655bb1ba0122402352b0a4254812717a017d2dc49372a1d47e24073795"}, - {file = "watchfiles-0.22.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c8e0aa0e8cc2a43561e0184c0513e291ca891db13a269d8d47cb9841ced7c71"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2f350cbaa4bb812314af5dab0eb8d538481e2e2279472890864547f3fe2281ed"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:7a74436c415843af2a769b36bf043b6ccbc0f8d784814ba3d42fc961cdb0a9dc"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00ad0bcd399503a84cc688590cdffbe7a991691314dde5b57b3ed50a41319a31"}, - {file = "watchfiles-0.22.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72a44e9481afc7a5ee3291b09c419abab93b7e9c306c9ef9108cb76728ca58d2"}, - {file = "watchfiles-0.22.0.tar.gz", hash = "sha256:988e981aaab4f3955209e7e28c7794acdb690be1efa7f16f8ea5aba7ffdadacb"}, + {file = "watchfiles-0.24.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:083dc77dbdeef09fa44bb0f4d1df571d2e12d8a8f985dccde71ac3ac9ac067a0"}, + {file = "watchfiles-0.24.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e94e98c7cb94cfa6e071d401ea3342767f28eb5a06a58fafdc0d2a4974f4f35c"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82ae557a8c037c42a6ef26c494d0631cacca040934b101d001100ed93d43f361"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:acbfa31e315a8f14fe33e3542cbcafc55703b8f5dcbb7c1eecd30f141df50db3"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:b74fdffce9dfcf2dc296dec8743e5b0332d15df19ae464f0e249aa871fc1c571"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:449f43f49c8ddca87c6b3980c9284cab6bd1f5c9d9a2b00012adaaccd5e7decd"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4abf4ad269856618f82dee296ac66b0cd1d71450fc3c98532d93798e73399b7a"}, + {file = "watchfiles-0.24.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f895d785eb6164678ff4bb5cc60c5996b3ee6df3edb28dcdeba86a13ea0465e"}, + {file = "watchfiles-0.24.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7ae3e208b31be8ce7f4c2c0034f33406dd24fbce3467f77223d10cd86778471c"}, + {file = "watchfiles-0.24.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2efec17819b0046dde35d13fb8ac7a3ad877af41ae4640f4109d9154ed30a188"}, + {file = "watchfiles-0.24.0-cp310-none-win32.whl", hash = "sha256:6bdcfa3cd6fdbdd1a068a52820f46a815401cbc2cb187dd006cb076675e7b735"}, + {file = "watchfiles-0.24.0-cp310-none-win_amd64.whl", hash = "sha256:54ca90a9ae6597ae6dc00e7ed0a040ef723f84ec517d3e7ce13e63e4bc82fa04"}, + {file = "watchfiles-0.24.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:bdcd5538e27f188dd3c804b4a8d5f52a7fc7f87e7fd6b374b8e36a4ca03db428"}, + {file = "watchfiles-0.24.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2dadf8a8014fde6addfd3c379e6ed1a981c8f0a48292d662e27cabfe4239c83c"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6509ed3f467b79d95fc62a98229f79b1a60d1b93f101e1c61d10c95a46a84f43"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8360f7314a070c30e4c976b183d1d8d1585a4a50c5cb603f431cebcbb4f66327"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:316449aefacf40147a9efaf3bd7c9bdd35aaba9ac5d708bd1eb5763c9a02bef5"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73bde715f940bea845a95247ea3e5eb17769ba1010efdc938ffcb967c634fa61"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3770e260b18e7f4e576edca4c0a639f704088602e0bc921c5c2e721e3acb8d15"}, + {file = "watchfiles-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa0fd7248cf533c259e59dc593a60973a73e881162b1a2f73360547132742823"}, + {file = "watchfiles-0.24.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d7a2e3b7f5703ffbd500dabdefcbc9eafeff4b9444bbdd5d83d79eedf8428fab"}, + {file = "watchfiles-0.24.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d831ee0a50946d24a53821819b2327d5751b0c938b12c0653ea5be7dea9c82ec"}, + {file = "watchfiles-0.24.0-cp311-none-win32.whl", hash = "sha256:49d617df841a63b4445790a254013aea2120357ccacbed00253f9c2b5dc24e2d"}, + {file = "watchfiles-0.24.0-cp311-none-win_amd64.whl", hash = "sha256:d3dcb774e3568477275cc76554b5a565024b8ba3a0322f77c246bc7111c5bb9c"}, + {file = "watchfiles-0.24.0-cp311-none-win_arm64.whl", hash = "sha256:9301c689051a4857d5b10777da23fafb8e8e921bcf3abe6448a058d27fb67633"}, + {file = "watchfiles-0.24.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7211b463695d1e995ca3feb38b69227e46dbd03947172585ecb0588f19b0d87a"}, + {file = "watchfiles-0.24.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b8693502d1967b00f2fb82fc1e744df128ba22f530e15b763c8d82baee15370"}, + {file = 
"watchfiles-0.24.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdab9555053399318b953a1fe1f586e945bc8d635ce9d05e617fd9fe3a4687d6"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34e19e56d68b0dad5cff62273107cf5d9fbaf9d75c46277aa5d803b3ef8a9e9b"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:41face41f036fee09eba33a5b53a73e9a43d5cb2c53dad8e61fa6c9f91b5a51e"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5148c2f1ea043db13ce9b0c28456e18ecc8f14f41325aa624314095b6aa2e9ea"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e4bd963a935aaf40b625c2499f3f4f6bbd0c3776f6d3bc7c853d04824ff1c9f"}, + {file = "watchfiles-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c79d7719d027b7a42817c5d96461a99b6a49979c143839fc37aa5748c322f234"}, + {file = "watchfiles-0.24.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:32aa53a9a63b7f01ed32e316e354e81e9da0e6267435c7243bf8ae0f10b428ef"}, + {file = "watchfiles-0.24.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce72dba6a20e39a0c628258b5c308779b8697f7676c254a845715e2a1039b968"}, + {file = "watchfiles-0.24.0-cp312-none-win32.whl", hash = "sha256:d9018153cf57fc302a2a34cb7564870b859ed9a732d16b41a9b5cb2ebed2d444"}, + {file = "watchfiles-0.24.0-cp312-none-win_amd64.whl", hash = "sha256:551ec3ee2a3ac9cbcf48a4ec76e42c2ef938a7e905a35b42a1267fa4b1645896"}, + {file = "watchfiles-0.24.0-cp312-none-win_arm64.whl", hash = "sha256:b52a65e4ea43c6d149c5f8ddb0bef8d4a1e779b77591a458a893eb416624a418"}, + {file = "watchfiles-0.24.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:3d2e3ab79a1771c530233cadfd277fcc762656d50836c77abb2e5e72b88e3a48"}, + {file = "watchfiles-0.24.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327763da824817b38ad125dcd97595f942d720d32d879f6c4ddf843e3da3fe90"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd82010f8ab451dabe36054a1622870166a67cf3fce894f68895db6f74bbdc94"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d64ba08db72e5dfd5c33be1e1e687d5e4fcce09219e8aee893a4862034081d4e"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1cf1f6dd7825053f3d98f6d33f6464ebdd9ee95acd74ba2c34e183086900a827"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43e3e37c15a8b6fe00c1bce2473cfa8eb3484bbeecf3aefbf259227e487a03df"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88bcd4d0fe1d8ff43675360a72def210ebad3f3f72cabfeac08d825d2639b4ab"}, + {file = "watchfiles-0.24.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:999928c6434372fde16c8f27143d3e97201160b48a614071261701615a2a156f"}, + {file = "watchfiles-0.24.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:30bbd525c3262fd9f4b1865cb8d88e21161366561cd7c9e1194819e0a33ea86b"}, + {file = "watchfiles-0.24.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:edf71b01dec9f766fb285b73930f95f730bb0943500ba0566ae234b5c1618c18"}, + {file = "watchfiles-0.24.0-cp313-none-win32.whl", hash = "sha256:f4c96283fca3ee09fb044f02156d9570d156698bc3734252175a38f0e8975f07"}, + {file = 
"watchfiles-0.24.0-cp313-none-win_amd64.whl", hash = "sha256:a974231b4fdd1bb7f62064a0565a6b107d27d21d9acb50c484d2cdba515b9366"}, + {file = "watchfiles-0.24.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:ee82c98bed9d97cd2f53bdb035e619309a098ea53ce525833e26b93f673bc318"}, + {file = "watchfiles-0.24.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:fd92bbaa2ecdb7864b7600dcdb6f2f1db6e0346ed425fbd01085be04c63f0b05"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f83df90191d67af5a831da3a33dd7628b02a95450e168785586ed51e6d28943c"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fca9433a45f18b7c779d2bae7beeec4f740d28b788b117a48368d95a3233ed83"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b995bfa6bf01a9e09b884077a6d37070464b529d8682d7691c2d3b540d357a0c"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ed9aba6e01ff6f2e8285e5aa4154e2970068fe0fc0998c4380d0e6278222269b"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5171ef898299c657685306d8e1478a45e9303ddcd8ac5fed5bd52ad4ae0b69b"}, + {file = "watchfiles-0.24.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4933a508d2f78099162da473841c652ad0de892719043d3f07cc83b33dfd9d91"}, + {file = "watchfiles-0.24.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:95cf3b95ea665ab03f5a54765fa41abf0529dbaf372c3b83d91ad2cfa695779b"}, + {file = "watchfiles-0.24.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:01def80eb62bd5db99a798d5e1f5f940ca0a05986dcfae21d833af7a46f7ee22"}, + {file = "watchfiles-0.24.0-cp38-none-win32.whl", hash = "sha256:4d28cea3c976499475f5b7a2fec6b3a36208656963c1a856d328aeae056fc5c1"}, + {file = "watchfiles-0.24.0-cp38-none-win_amd64.whl", hash = "sha256:21ab23fdc1208086d99ad3f69c231ba265628014d4aed31d4e8746bd59e88cd1"}, + {file = "watchfiles-0.24.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b665caeeda58625c3946ad7308fbd88a086ee51ccb706307e5b1fa91556ac886"}, + {file = "watchfiles-0.24.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5c51749f3e4e269231510da426ce4a44beb98db2dce9097225c338f815b05d4f"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82b2509f08761f29a0fdad35f7e1638b8ab1adfa2666d41b794090361fb8b855"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a60e2bf9dc6afe7f743e7c9b149d1fdd6dbf35153c78fe3a14ae1a9aee3d98b"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7d9b87c4c55e3ea8881dfcbf6d61ea6775fffed1fedffaa60bd047d3c08c430"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:78470906a6be5199524641f538bd2c56bb809cd4bf29a566a75051610bc982c3"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:07cdef0c84c03375f4e24642ef8d8178e533596b229d32d2bbd69e5128ede02a"}, + {file = "watchfiles-0.24.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d337193bbf3e45171c8025e291530fb7548a93c45253897cd764a6a71c937ed9"}, + {file = "watchfiles-0.24.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ec39698c45b11d9694a1b635a70946a5bad066b593af863460a8e600f0dff1ca"}, + {file = "watchfiles-0.24.0-cp39-cp39-musllinux_1_1_x86_64.whl", 
hash = "sha256:2e28d91ef48eab0afb939fa446d8ebe77e2f7593f5f463fd2bb2b14132f95b6e"}, + {file = "watchfiles-0.24.0-cp39-none-win32.whl", hash = "sha256:7138eff8baa883aeaa074359daabb8b6c1e73ffe69d5accdc907d62e50b1c0da"}, + {file = "watchfiles-0.24.0-cp39-none-win_amd64.whl", hash = "sha256:b3ef2c69c655db63deb96b3c3e587084612f9b1fa983df5e0c3379d41307467f"}, + {file = "watchfiles-0.24.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:632676574429bee8c26be8af52af20e0c718cc7f5f67f3fb658c71928ccd4f7f"}, + {file = "watchfiles-0.24.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:a2a9891723a735d3e2540651184be6fd5b96880c08ffe1a98bae5017e65b544b"}, + {file = "watchfiles-0.24.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7fa2bc0efef3e209a8199fd111b8969fe9db9c711acc46636686331eda7dd4"}, + {file = "watchfiles-0.24.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01550ccf1d0aed6ea375ef259706af76ad009ef5b0203a3a4cce0f6024f9b68a"}, + {file = "watchfiles-0.24.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:96619302d4374de5e2345b2b622dc481257a99431277662c30f606f3e22f42be"}, + {file = "watchfiles-0.24.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:85d5f0c7771dcc7a26c7a27145059b6bb0ce06e4e751ed76cdf123d7039b60b5"}, + {file = "watchfiles-0.24.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:951088d12d339690a92cef2ec5d3cfd957692834c72ffd570ea76a6790222777"}, + {file = "watchfiles-0.24.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49fb58bcaa343fedc6a9e91f90195b20ccb3135447dc9e4e2570c3a39565853e"}, + {file = "watchfiles-0.24.0.tar.gz", hash = "sha256:afb72325b74fa7a428c009c1b8be4b4d7c2afedafb2982827ef2156646df2fe1"}, ] [package.dependencies] @@ -9058,13 +10455,13 @@ files = [ [[package]] name = "werkzeug" -version = "3.0.3" +version = "3.0.6" description = "The comprehensive WSGI web application library." 
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"},
-    {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"},
+    {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"},
+    {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"},
 ]
 
 [package.dependencies]
@@ -9087,6 +10484,20 @@ files = [
 beautifulsoup4 = "*"
 requests = ">=2.0.0,<3.0.0"
 
+[[package]]
+name = "win32-setctime"
+version = "1.1.0"
+description = "A small Python utility to set file creation time on Windows"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
+    {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
+]
+
+[package.extras]
+dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
+
 [[package]]
 name = "wrapt"
 version = "1.16.0"
@@ -9182,13 +10593,13 @@ h11 = ">=0.9.0,<1"
 
 [[package]]
 name = "xinference-client"
-version = "0.13.3"
+version = "0.15.2"
 description = "Client for Xinference"
 optional = false
 python-versions = "*"
 files = [
-    {file = "xinference-client-0.13.3.tar.gz", hash = "sha256:822b722100affdff049c27760be7d62ac92de58c87a40d3361066df446ba648f"},
-    {file = "xinference_client-0.13.3-py3-none-any.whl", hash = "sha256:f0eff3858b1ebcef2129726f82b09259c177e11db466a7ca23def3d4849c419f"},
+    {file = "xinference-client-0.15.2.tar.gz", hash = "sha256:5c2259bb133148d1cc9bd2b8ec6eb8b5bbeba7f11d6252959f4e6cd79baa53ed"},
+    {file = "xinference_client-0.15.2-py3-none-any.whl", hash = "sha256:b6275adab695e75e75a33e21e0ad212488fc2d5a4d0f693d544c0e78469abbe3"},
 ]
 
 [package.dependencies]
@@ -9228,112 +10639,114 @@ files = [
 
 [[package]]
 name = "xmltodict"
-version = "0.13.0"
+version = "0.14.2"
 description = "Makes working with XML feel like you are working with JSON"
 optional = false
-python-versions = ">=3.4"
+python-versions = ">=3.6"
 files = [
-    {file = "xmltodict-0.13.0-py2.py3-none-any.whl", hash = "sha256:aa89e8fd76320154a40d19a0df04a4695fb9dc5ba977cbb68ab3e4eb225e7852"},
-    {file = "xmltodict-0.13.0.tar.gz", hash = "sha256:341595a488e3e01a85a9d8911d8912fd922ede5fecc4dce437eb4b6c8d037e56"},
+    {file = "xmltodict-0.14.2-py2.py3-none-any.whl", hash = "sha256:20cc7d723ed729276e808f26fb6b3599f786cbc37e06c65e192ba77c40f20aac"},
+    {file = "xmltodict-0.14.2.tar.gz", hash = "sha256:201e7c28bb210e374999d1dde6382923ab0ed1a8a5faeece48ab525b7810a553"},
 ]
 
 [[package]]
 name = "yarl"
-version = "1.9.4"
+version = "1.9.11"
 description = "Yet another URL library"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"},
-    {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"},
-    {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"},
-    {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"},
-    {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"},
-    {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"},
-    {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"},
-    {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"},
-    {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"},
-    {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"},
-    {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"},
-    {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"},
-    {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"},
-    {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"},
-    {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"},
-    {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"},
-    {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"},
-    {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"},
-    {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"},
-    {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"},
-    {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"},
-    {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"},
-    {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"},
-    {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"},
-    {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"},
-    {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"},
-    {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"},
-    {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"},
-    {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"},
-    {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"},
-    {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"},
-    {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"},
-    {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"},
-    {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"},
-    {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"},
-    {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"},
-    {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"},
-    {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"},
-    {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"},
-    {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"},
-    {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"},
-    {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"},
-    {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"},
-    {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"},
-    {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"},
-    {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"},
-    {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"},
-    {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"},
-    {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"},
-    {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"},
-    {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"},
-    {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"},
-    {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"},
-    {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"},
-    {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"},
-    {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"},
-    {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"},
-    {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"},
-    {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"},
-    {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"},
-    {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"},
-    {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"},
-    {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"},
-    {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"},
-    {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"},
-    {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"},
-    {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"},
-    {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"},
-    {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"},
-    {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"},
-    {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"},
-    {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"},
-    {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"},
-    {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"},
-    {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"},
-    {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"},
-    {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"},
-    {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"},
-    {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"},
-    {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"},
-    {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"},
-    {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"},
-    {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"},
-    {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"},
-    {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"},
-    {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"},
-    {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"},
-    {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"},
-    {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"},
-    {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"},
+    {file = "yarl-1.9.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:79e08c691deae6fcac2fdde2e0515ac561dd3630d7c8adf7b1e786e22f1e193b"},
+    {file = "yarl-1.9.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:752f4b5cf93268dc73c2ae994cc6d684b0dad5118bc87fbd965fd5d6dca20f45"},
+    {file = "yarl-1.9.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:441049d3a449fb8756b0535be72c6a1a532938a33e1cf03523076700a5f87a01"},
+    {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3dfe17b4aed832c627319da22a33f27f282bd32633d6b145c726d519c89fbaf"},
+    {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:67abcb7df27952864440c9c85f1c549a4ad94afe44e2655f77d74b0d25895454"},
+    {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6de3fa29e76fd1518a80e6af4902c44f3b1b4d7fed28eb06913bba4727443de3"},
+    {file = "yarl-1.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fee45b3bd4d8d5786472e056aa1359cc4dc9da68aded95a10cd7929a0ec661fe"},
+    {file = "yarl-1.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c59b23886234abeba62087fd97d10fb6b905d9e36e2f3465d1886ce5c0ca30df"},
+    {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d93c612b2024ac25a3dc01341fd98fdd19c8c5e2011f3dcd084b3743cba8d756"},
+    {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4d368e3b9ecd50fa22017a20c49e356471af6ae91c4d788c6e9297e25ddf5a62"},
+    {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5b593acd45cdd4cf6664d342ceacedf25cd95263b83b964fddd6c78930ea5211"},
+    {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:224f8186c220ff00079e64bf193909829144d4e5174bb58665ef0da8bf6955c4"},
+    {file = "yarl-1.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:91c478741d7563a12162f7a2db96c0d23d93b0521563f1f1f0ece46ea1702d33"},
+    {file = "yarl-1.9.11-cp310-cp310-win32.whl", hash = "sha256:1cdb8f5bb0534986776a43df84031da7ff04ac0cf87cb22ae8a6368231949c40"},
+    {file = "yarl-1.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:498439af143b43a2b2314451ffd0295410aa0dcbdac5ee18fc8633da4670b605"},
+    {file = "yarl-1.9.11-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9e290de5db4fd4859b4ed57cddfe793fcb218504e65781854a8ac283ab8d5518"},
+    {file = "yarl-1.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e5f50a2e26cc2b89186f04c97e0ec0ba107ae41f1262ad16832d46849864f914"},
+    {file = "yarl-1.9.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b4a0e724a28d7447e4d549c8f40779f90e20147e94bf949d490402eee09845c6"},
+    {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85333d38a4fa5997fa2ff6fd169be66626d814b34fa35ec669e8c914ca50a097"},
+    {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6ff184002ee72e4b247240e35d5dce4c2d9a0e81fdbef715dde79ab4718aa541"},
+    {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:675004040f847c0284827f44a1fa92d8baf425632cc93e7e0aa38408774b07c1"},
+    {file = "yarl-1.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b30703a7ade2b53f02e09a30685b70cd54f65ed314a8d9af08670c9a5391af1b"},
+    {file = "yarl-1.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7230007ab67d43cf19200ec15bc6b654e6b85c402f545a6fc565d254d34ff754"},
+    {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8c2cf0c7ad745e1c6530fe6521dfb19ca43338239dfcc7da165d0ef2332c0882"},
+    {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4567cc08f479ad80fb07ed0c9e1bcb363a4f6e3483a490a39d57d1419bf1c4c7"},
+    {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:95adc179a02949c4560ef40f8f650a008380766eb253d74232eb9c024747c111"},
+    {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:755ae9cff06c429632d750aa8206f08df2e3d422ca67be79567aadbe74ae64cc"},
+    {file = "yarl-1.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:94f71d54c5faf715e92c8434b4a0b968c4d1043469954d228fc031d51086f143"},
+    {file = "yarl-1.9.11-cp311-cp311-win32.whl", hash = "sha256:4ae079573efeaa54e5978ce86b77f4175cd32f42afcaf9bfb8a0677e91f84e4e"},
+    {file = "yarl-1.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:9fae7ec5c9a4fe22abb995804e6ce87067dfaf7e940272b79328ce37c8f22097"},
+    {file = "yarl-1.9.11-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:614fa50fd0db41b79f426939a413d216cdc7bab8d8c8a25844798d286a999c5a"},
+    {file = "yarl-1.9.11-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ff64f575d71eacb5a4d6f0696bfe991993d979423ea2241f23ab19ff63f0f9d1"},
+    {file = "yarl-1.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5c23f6dc3d7126b4c64b80aa186ac2bb65ab104a8372c4454e462fb074197bc6"},
+    {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8f847cc092c2b85d22e527f91ea83a6cf51533e727e2461557a47a859f96734"},
+    {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63a5dc2866791236779d99d7a422611d22bb3a3d50935bafa4e017ea13e51469"},
+    {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c335342d482e66254ae94b1231b1532790afb754f89e2e0c646f7f19d09740aa"},
+    {file = "yarl-1.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4a8c3dedd081cca134a21179aebe58b6e426e8d1e0202da9d1cafa56e01af3c"},
+    {file = "yarl-1.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:504d19320c92532cabc3495fb7ed6bb599f3c2bfb45fed432049bf4693dbd6d0"},
+    {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b2a8e5eb18181060197e3d5db7e78f818432725c0759bc1e5a9d603d9246389"},
+    {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:f568d70b7187f4002b6b500c0996c37674a25ce44b20716faebe5fdb8bd356e7"},
+    {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:735b285ea46ca7e86ad261a462a071d0968aade44e1a3ea2b7d4f3d63b5aab12"},
+    {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:2d1c81c3b92bef0c1c180048e43a5a85754a61b4f69d6f84df8e4bd615bef25d"},
+    {file = "yarl-1.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8d6e1c1562b53bd26efd38e886fc13863b8d904d559426777990171020c478a9"},
+    {file = "yarl-1.9.11-cp312-cp312-win32.whl", hash = "sha256:aeba4aaa59cb709edb824fa88a27cbbff4e0095aaf77212b652989276c493c00"},
+    {file = "yarl-1.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:569309a3efb8369ff5d32edb2a0520ebaf810c3059f11d34477418c90aa878fd"},
+    {file = "yarl-1.9.11-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4915818ac850c3b0413e953af34398775b7a337babe1e4d15f68c8f5c4872553"},
+    {file = "yarl-1.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ef9610b2f5a73707d4d8bac040f0115ca848e510e3b1f45ca53e97f609b54130"},
+    {file = "yarl-1.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:47c0a3dc8076a8dd159de10628dea04215bc7ddaa46c5775bf96066a0a18f82b"},
+    {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:545f2fbfa0c723b446e9298b5beba0999ff82ce2c126110759e8dac29b5deaf4"},
+    {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9137975a4ccc163ad5d7a75aad966e6e4e95dedee08d7995eab896a639a0bce2"},
+    {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b0c70c451d2a86f8408abced5b7498423e2487543acf6fcf618b03f6e669b0a"},
+    {file = "yarl-1.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce2bd986b1e44528677c237b74d59f215c8bfcdf2d69442aa10f62fd6ab2951c"},
+    {file = "yarl-1.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8d7b717f77846a9631046899c6cc730ea469c0e2fb252ccff1cc119950dbc296"},
+    {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3a26a24bbd19241283d601173cea1e5b93dec361a223394e18a1e8e5b0ef20bd"},
+    {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c189bf01af155ac9882e128d9f3b3ad68a1f2c2f51404afad7201305df4e12b1"},
+    {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:0cbcc2c54084b2bda4109415631db017cf2960f74f9e8fd1698e1400e4f8aae2"},
+    {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:30f201bc65941a4aa59c1236783efe89049ec5549dafc8cd2b63cc179d3767b0"},
+    {file = "yarl-1.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:922ba3b74f0958a0b5b9c14ff1ef12714a381760c08018f2b9827632783a590c"},
+    {file = "yarl-1.9.11-cp313-cp313-win32.whl", hash = "sha256:17107b4b8c43e66befdcbe543fff2f9c93f7a3a9f8e3a9c9ac42bffeba0e8828"},
+    {file = "yarl-1.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:0324506afab4f2e176a93cb08b8abcb8b009e1f324e6cbced999a8f5dd9ddb76"},
+    {file = "yarl-1.9.11-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4e4f820fde9437bb47297194f43d29086433e6467fa28fe9876366ad357bd7bb"},
+    {file = "yarl-1.9.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:dfa9b9d5c9c0dbe69670f5695264452f5e40947590ec3a38cfddc9640ae8ff89"},
+    {file = "yarl-1.9.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e700eb26635ce665c018c8cfea058baff9b843ed0cc77aa61849d807bb82a64c"},
+    {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c305c1bdf10869b5e51facf50bd5b15892884aeae81962ae4ba061fc11217103"},
+    {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5b7b307140231ea4f7aad5b69355aba2a67f2d7bc34271cffa3c9c324d35b27"},
+    {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a744bdeda6c86cf3025c94eb0e01ccabe949cf385cd75b6576a3ac9669404b68"},
+    {file = "yarl-1.9.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e8ed183c7a8f75e40068333fc185566472a8f6c77a750cf7541e11810576ea5"},
+    {file = "yarl-1.9.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1db9a4384694b5d20bdd9cb53f033b0831ac816416ab176c8d0997835015d22"},
+    {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:70194da6e99713250aa3f335a7fa246b36adf53672a2bcd0ddaa375d04e53dc0"},
+    {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ddad5cfcda729e22422bb1c85520bdf2770ce6d975600573ac9017fe882f4b7e"},
+    {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:ca35996e0a4bed28fa0640d9512d37952f6b50dea583bcc167d4f0b1e112ac7f"},
+    {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:61ec0e80970b21a8f3c4b97fa6c6d181c6c6a135dbc7b4a601a78add3feeb209"},
+    {file = "yarl-1.9.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9636e4519f6c7558fdccf8f91e6e3b98df2340dc505c4cc3286986d33f2096c2"},
+    {file = "yarl-1.9.11-cp38-cp38-win32.whl", hash = "sha256:58081cea14b8feda57c7ce447520e9d0a96c4d010cce54373d789c13242d7083"},
+    {file = "yarl-1.9.11-cp38-cp38-win_amd64.whl", hash = "sha256:7d2dee7d6485807c0f64dd5eab9262b7c0b34f760e502243dd83ec09d647d5e1"},
+    {file = "yarl-1.9.11-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d65ad67f981e93ea11f87815f67d086c4f33da4800cf2106d650dd8a0b79dda4"},
+    {file = "yarl-1.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:752c0d33b4aacdb147871d0754b88f53922c6dc2aff033096516b3d5f0c02a0f"},
+    {file = "yarl-1.9.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:54cc24be98d7f4ff355ca2e725a577e19909788c0db6beead67a0dda70bd3f82"},
+    {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c82126817492bb2ebc946e74af1ffa10aacaca81bee360858477f96124be39a"},
+    {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8503989860d7ac10c85cb5b607fec003a45049cf7a5b4b72451e87893c6bb990"},
+    {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:475e09a67f8b09720192a170ad9021b7abf7827ffd4f3a83826317a705be06b7"},
+    {file = "yarl-1.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afcac5bda602b74ff701e1f683feccd8cce0d5a21dbc68db81bf9bd8fd93ba56"},
+    {file = "yarl-1.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaeffcb84faceb2923a94a8a9aaa972745d3c728ab54dd011530cc30a3d5d0c1"},
+    {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:51a6f770ac86477cd5c553f88a77a06fe1f6f3b643b053fcc7902ab55d6cbe14"},
+    {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3fcd056cb7dff3aea5b1ee1b425b0fbaa2fbf6a1c6003e88caf524f01de5f395"},
+    {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:21e56c30e39a1833e4e3fd0112dde98c2abcbc4c39b077e6105c76bb63d2aa04"},
+    {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:0a205ec6349879f5e75dddfb63e069a24f726df5330b92ce76c4752a436aac01"},
+    {file = "yarl-1.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a5706821e1cf3c70dfea223e4e0958ea354f4e2af9420a1bd45c6b547297fb97"},
+    {file = "yarl-1.9.11-cp39-cp39-win32.whl", hash = "sha256:cc295969f8c2172b5d013c0871dccfec7a0e1186cf961e7ea575d47b4d5cbd32"},
+    {file = "yarl-1.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:55a67dd29367ce7c08a0541bb602ec0a2c10d46c86b94830a1a665f7fd093dfa"},
+    {file = "yarl-1.9.11-py3-none-any.whl", hash = "sha256:c6f6c87665a9e18a635f0545ea541d9640617832af2317d4f5ad389686b4ed3d"},
+    {file = "yarl-1.9.11.tar.gz", hash = "sha256:c7548a90cb72b67652e2cd6ae80e2683ee08fde663104528ac7df12d8ef271d2"},
 ]
 
 [package.dependencies]
@@ -9342,13 +10755,13 @@ multidict = ">=4.0"
 
 [[package]]
 name = "yfinance"
-version = "0.2.41"
+version = "0.2.48"
 description = "Download market data from Yahoo! Finance API"
 optional = false
 python-versions = "*"
 files = [
-    {file = "yfinance-0.2.41-py2.py3-none-any.whl", hash = "sha256:2ed7b453cb8568773eb2dbb4d87cc37ff02e5d133f7723ec3e219ab0b86b56d8"},
-    {file = "yfinance-0.2.41.tar.gz", hash = "sha256:f94409a1ed4d596b9da8d2dbb498faaabfcf593d5870e1412e17669a212bb345"},
+    {file = "yfinance-0.2.48-py2.py3-none-any.whl", hash = "sha256:eda797145faa4536595eb629f869d3616e58ed7e71de36856b19f1abaef71a5b"},
+    {file = "yfinance-0.2.48.tar.gz", hash = "sha256:1434cd8bf22f345fa27ef1ed82bfdd291c1bb5b6fe3067118a94e256aa90c4eb"},
 ]
 
 [package.dependencies]
@@ -9370,35 +10783,40 @@ repair = ["scipy (>=1.6.3)"]
 
 [[package]]
 name = "zhipuai"
-version = "1.0.7"
+version = "2.1.5.20230904"
 description = "A SDK library for accessing big model apis from ZhipuAI"
 optional = false
-python-versions = ">=3.6"
+python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
 files = [
-    {file = "zhipuai-1.0.7-py3-none-any.whl", hash = "sha256:360c01b8c2698f366061452e86d5a36a5ff68a576ea33940da98e4806f232530"},
-    {file = "zhipuai-1.0.7.tar.gz", hash = "sha256:b80f699543d83cce8648acf1ce32bc2725d1c1c443baffa5882abc2cc704d581"},
+    {file = "zhipuai-2.1.5.20230904-py3-none-any.whl", hash = "sha256:8485ca452c2f07fea476fb0666abc8fbbdf1b2e4feeee46a3bb3c1a2b51efccd"},
+    {file = "zhipuai-2.1.5.20230904.tar.gz", hash = "sha256:2c19dd796b12e2f19b93d8f9be6fd01e85d3320737a187ebf3c75a9806a7c2b5"},
 ]
 
 [package.dependencies]
-cachetools = "*"
-dataclasses = "*"
-PyJWT = "*"
-requests = "*"
+cachetools = ">=4.2.2"
+httpx = ">=0.23.0"
+pydantic = ">=1.9.0,<3.0"
+pydantic-core = ">=2.14.6"
+pyjwt = ">=2.8.0,<2.9.0"
 
 [[package]]
 name = "zipp"
-version = "3.19.2"
+version = "3.20.2"
 description = "Backport of pathlib-compatible object wrapper for zip files"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"},
-    {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"},
+    {file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"},
+    {file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"},
 ]
 
 [package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"]
+cover = ["pytest-cov"]
 doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"]
+type = ["pytest-mypy"]
 
 [[package]]
 name = "zope-event"
@@ -9420,54 +10838,57 @@ test = ["zope.testrunner"]
 
 [[package]]
 name = "zope-interface"
-version = "7.0.1"
+version = "7.1.1"
 description = "Interfaces for Python"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "zope.interface-7.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ec4e87e6fdc511a535254daa122c20e11959ce043b4e3425494b237692a34f1c"},
-    {file = "zope.interface-7.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:51d5713e8e38f2d3ec26e0dfdca398ed0c20abda2eb49ffc15a15a23eb8e5f6d"},
-    {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea8d51e5eb29e57d34744369cd08267637aa5a0fefc9b5d33775ab7ff2ebf2e3"},
-    {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:55bbcc74dc0c7ab489c315c28b61d7a1d03cf938cc99cc58092eb065f120c3a5"},
-    {file = "zope.interface-7.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10ebac566dd0cec66f942dc759d46a994a2b3ba7179420f0e2130f88f8a5f400"},
-    {file = "zope.interface-7.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:7039e624bcb820f77cc2ff3d1adcce531932990eee16121077eb51d9c76b6c14"},
-    {file = "zope.interface-7.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03bd5c0db82237bbc47833a8b25f1cc090646e212f86b601903d79d7e6b37031"},
-    {file = "zope.interface-7.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f52050c6a10d4a039ec6f2c58e5b3ade5cc570d16cf9d102711e6b8413c90e6"},
-    {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af0b33f04677b57843d529b9257a475d2865403300b48c67654c40abac2f9f24"},
-    {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:696c2a381fc7876b3056711717dba5eddd07c2c9e5ccd50da54029a1293b6e43"},
-    {file = "zope.interface-7.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f89a420cf5a6f2aa7849dd59e1ff0e477f562d97cf8d6a1ee03461e1eec39887"},
-    {file = "zope.interface-7.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:b59deb0ddc7b431e41d720c00f99d68b52cb9bd1d5605a085dc18f502fe9c47f"},
-    {file = "zope.interface-7.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52f5253cca1b35eaeefa51abd366b87f48f8714097c99b131ba61f3fdbbb58e7"},
-    {file = "zope.interface-7.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:88d108d004e0df25224de77ce349a7e73494ea2cb194031f7c9687e68a88ec9b"},
-    {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c203d82069ba31e1f3bc7ba530b2461ec86366cd4bfc9b95ec6ce58b1b559c34"},
-    {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f3495462bc0438b76536a0e10d765b168ae636092082531b88340dc40dcd118"},
-    {file = "zope.interface-7.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192b7a792e3145ed880ff6b1a206fdb783697cfdb4915083bfca7065ec845e60"},
-    {file = "zope.interface-7.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:400d06c9ec8dbcc96f56e79376297e7be07a315605c9a2208720da263d44d76f"},
-    {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c1dff87b30fd150c61367d0e2cdc49bb55f8b9fd2a303560bbc24b951573ae1"},
-    {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f749ca804648d00eda62fe1098f229b082dfca930d8bad8386e572a6eafa7525"},
-    {file = "zope.interface-7.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ec212037becf6d2f705b7ed4538d56980b1e7bba237df0d8995cbbed29961dc"},
-    {file = "zope.interface-7.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d33cb526efdc235a2531433fc1287fcb80d807d5b401f9b801b78bf22df560dd"},
-    {file = "zope.interface-7.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b419f2144e1762ab845f20316f1df36b15431f2622ebae8a6d5f7e8e712b413c"},
-    {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03f1452d5d1f279184d5bdb663a3dc39902d9320eceb63276240791e849054b6"},
-    {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ba4b3638d014918b918aa90a9c8370bd74a03abf8fcf9deb353b3a461a59a84"},
-    {file = "zope.interface-7.0.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc0615351221926a36a0fbcb2520fb52e0b23e8c22a43754d9cb8f21358c33c0"},
-    {file = "zope.interface-7.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:ce6cbb852fb8f2f9bb7b9cdca44e2e37bce783b5f4c167ff82cb5f5128163c8f"},
-    {file = "zope.interface-7.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5566fd9271c89ad03d81b0831c37d46ae5e2ed211122c998637130159a120cf1"},
-    {file = "zope.interface-7.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:da0cef4d7e3f19c3bd1d71658d6900321af0492fee36ec01b550a10924cffb9c"},
-    {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f32ca483e6ade23c7caaee9d5ee5d550cf4146e9b68d2fb6c68bac183aa41c37"},
-    {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:da21e7eec49252df34d426c2ee9cf0361c923026d37c24728b0fa4cc0599fd03"},
-    {file = "zope.interface-7.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a8195b99e650e6f329ce4e5eb22d448bdfef0406404080812bc96e2a05674cb"},
-    {file = "zope.interface-7.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:19c829d52e921b9fe0b2c0c6a8f9a2508c49678ee1be598f87d143335b6a35dc"},
-    {file = "zope.interface-7.0.1.tar.gz", hash = "sha256:f0f5fda7cbf890371a59ab1d06512da4f2c89a6ea194e595808123c863c38eff"},
+    {file = "zope.interface-7.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6650bd56ef350d37c8baccfd3ee8a0483ed6f8666e641e4b9ae1a1827b79f9e5"},
+    {file = "zope.interface-7.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:84e87eba6b77a3af187bae82d8de1a7c208c2a04ec9f6bd444fd091b811ad92e"},
+    {file = "zope.interface-7.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c4e1b4c06d9abd1037c088dae1566c85f344a3e6ae4350744c3f7f7259d9c67"},
+    {file = "zope.interface-7.1.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cd5e3d910ac87652a09f6e5db8e41bc3b49cf08ddd2d73d30afc644801492cd"},
+    {file = "zope.interface-7.1.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca95594d936ee349620900be5b46c0122a1ff6ce42d7d5cb2cf09dc84071ef16"},
+    {file = "zope.interface-7.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:ad339509dcfbbc99bf8e147db6686249c4032f26586699ec4c82f6e5909c9fe2"},
+    {file = "zope.interface-7.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e59f175e868f856a77c0a77ba001385c377df2104fdbda6b9f99456a01e102a"},
+    {file = "zope.interface-7.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0de23bcb93401994ea00bc5c677ef06d420340ac0a4e9c10d80e047b9ce5af3f"},
+    {file =
"zope.interface-7.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5cdb7e7e5524b76d3ec037c1d81a9e2c7457b240fd4cb0a2476b65c3a5a6c81f"}, + {file = "zope.interface-7.1.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3603ef82a9920bd0bfb505423cb7e937498ad971ad5a6141841e8f76d2fd5446"}, + {file = "zope.interface-7.1.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1d52d052355e0c5c89e0630dd2ff7c0b823fd5f56286a663e92444761b35e25"}, + {file = "zope.interface-7.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:179ad46ece518c9084cb272e4a69d266b659f7f8f48e51706746c2d8a426433e"}, + {file = "zope.interface-7.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e6503534b52bb1720ace9366ee30838a58a3413d3e197512f3338c8f34b5d89d"}, + {file = "zope.interface-7.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f85b290e5b8b11814efb0d004d8ce6c9a483c35c462e8d9bf84abb93e79fa770"}, + {file = "zope.interface-7.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d029fac6a80edae80f79c37e5e3abfa92968fe921886139b3ee470a1b177321a"}, + {file = "zope.interface-7.1.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5836b8fb044c6e75ba34dfaabc602493019eadfa0faf6ff25f4c4c356a71a853"}, + {file = "zope.interface-7.1.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7395f13533318f150ee72adb55b29284b16e73b6d5f02ab21f173b3e83f242b8"}, + {file = "zope.interface-7.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:1d0e23c6b746eb8ce04573cc47bcac60961ac138885d207bd6f57e27a1431ae8"}, + {file = "zope.interface-7.1.1-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:9fad9bd5502221ab179f13ea251cb30eef7cf65023156967f86673aff54b53a0"}, + {file = "zope.interface-7.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:55c373becbd36a44d0c9be1d5271422fdaa8562d158fb44b4192297b3c67096c"}, + {file = "zope.interface-7.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed1df8cc01dd1e3970666a7370b8bfc7457371c58ba88c57bd5bca17ab198053"}, + {file = "zope.interface-7.1.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:99c14f0727c978639139e6cad7a60e82b7720922678d75aacb90cf4ef74a068c"}, + {file = "zope.interface-7.1.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b1eed7670d564f1025d7cda89f99f216c30210e42e95de466135be0b4a499d9"}, + {file = "zope.interface-7.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:3defc925c4b22ac1272d544a49c6ba04c3eefcce3200319ee1be03d9270306dd"}, + {file = "zope.interface-7.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8d0fe45be57b5219aa4b96e846631c04615d5ef068146de5a02ccd15c185321f"}, + {file = "zope.interface-7.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bcbeb44fc16e0078b3b68a95e43f821ae34dcbf976dde6985141838a5f23dd3d"}, + {file = "zope.interface-7.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8e7b05dc6315a193cceaec071cc3cf1c180cea28808ccded0b1283f1c38ba73"}, + {file = "zope.interface-7.1.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d553e02b68c0ea5a226855f02edbc9eefd99f6a8886fa9f9bdf999d77f46585"}, + {file = 
"zope.interface-7.1.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81744a7e61b598ebcf4722ac56a7a4f50502432b5b4dc7eb29075a89cf82d029"}, + {file = "zope.interface-7.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:7720322763aceb5e0a7cadcc38c67b839efe599f0887cbf6c003c55b1458c501"}, + {file = "zope.interface-7.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1a2ed0852c25950cf430067f058f8d98df6288502ac313861d9803fe7691a9b3"}, + {file = "zope.interface-7.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9595e478047ce752b35cfa221d7601a5283ccdaab40422e0dc1d4a334c70f580"}, + {file = "zope.interface-7.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2317e1d4dba68203a5227ea3057f9078ec9376275f9700086b8f0ffc0b358e1b"}, + {file = "zope.interface-7.1.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6821ef9870f32154da873fcde439274f99814ea452dd16b99fa0b66345c4b6b"}, + {file = "zope.interface-7.1.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:190eeec67e023d5aac54d183fa145db0b898664234234ac54643a441da434616"}, + {file = "zope.interface-7.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:d17e7fc814eaab93409b80819fd6d30342844345c27f3bc3c4b43c2425a8d267"}, + {file = "zope.interface-7.1.1.tar.gz", hash = "sha256:4284d664ef0ff7b709836d4de7b13d80873dc5faeffc073abdb280058bfac5e3"}, ] [package.dependencies] setuptools = "*" [package.extras] -docs = ["Sphinx", "repoze.sphinx.autointerface", "sphinx-rtd-theme"] -test = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] -testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] +docs = ["Sphinx", "furo", "repoze.sphinx.autointerface"] +test = ["coverage[toml]", "zope.event", "zope.testing"] +testing = ["coverage[toml]", "zope.event", "zope.testing"] [[package]] name = "zstandard" @@ -9584,4 +11005,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "05dfa6b9bce9ed8ac21caf58eff1596f146080ab2ab6987924b189be673c22cf" +content-hash = "f20bd678044926913dbbc24bd0cf22503a75817aa55f59457ff7822032139b77" diff --git a/api/pyproject.toml b/api/pyproject.toml index 60c1c86d078390..40485a8efacf84 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,7 +6,8 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] -exclude = [ +exclude=[ + "migrations/*", ] line-length = 120 @@ -15,117 +16,111 @@ preview = true select = [ "B", # flake8-bugbear rules "C4", # flake8-comprehensions + "E", # pycodestyle E rules "F", # pyflakes rules + "FURB", # refurb rules "I", # isort rules - "UP", # pyupgrade rules - "B035", # static-key-dict-comprehension - "E101", # mixed-spaces-and-tabs - "E111", # indentation-with-invalid-multiple - "E112", # no-indented-block - "E113", # unexpected-indentation - "E115", # no-indented-block-comment - "E116", # unexpected-indentation-comment - "E117", # over-indented + "N", # pep8-naming + "PT", # flake8-pytest-style rules + "PLC0208", # iteration-over-set + "PLC2801", # unnecessary-dunder-call + "PLC0414", # useless-import-alias + "PLR0402", # manual-from-import + "PLR1711", # useless-return + "PLR1714", # repeated-equality-comparison + "RUF013", # implicit-optional "RUF019", # unnecessary-key-check "RUF100", # unused-noqa "RUF101", # redirected-noqa "S506", # unsafe-yaml-load - "SIM116", # if-else-block-instead-of-dict-lookup - "SIM401", # 
if-else-block-instead-of-dict-get - "SIM910", # dict-get-with-none-default + "SIM", # flake8-simplify rules + "TRY400", # error-instead-of-exception + "UP", # pyupgrade rules "W191", # tab-indentation "W605", # invalid-escape-sequence - "F601", # multi-value-repeated-key-literal - "F602", # multi-value-repeated-key-variable ] ignore = [ - "F403", # undefined-local-with-import-star - "F405", # undefined-local-with-import-star-usage + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "E731", # lambda-assignment "F821", # undefined-name "F841", # unused-variable + "FURB113", # repeated-append + "FURB152", # math-constant "UP007", # non-pep604-annotation "UP032", # f-string "B005", # strip-with-multi-characters "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg -# "B901", # return-in-generator "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # enumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "SIM300", # yoda-conditions ] [tool.ruff.lint.per-file-ignores] "app.py" = [ - "F401", # unused-import - "F811", # redefined-while-unused ] "__init__.py" = [ "F401", # unused-import "F811", # redefined-while-unused ] +"configs/*" = [ + "N802", # invalid-function-name +] +"libs/gmpy2_pkcs10aep_cipher.py" = [ + "N803", # invalid-argument-name +] "tests/*" = [ - "F401", # unused-import "F811", # redefined-while-unused + "F401", # unused-import +] + +[tool.ruff.lint.pyflakes] +extend-generics=[ + "_pytest.monkeypatch", + "tests.integration_tests", ] [tool.ruff.format] exclude = [ - "core/**/*.py", - "controllers/**/*.py", - "models/**/*.py", - "migrations/**/*", - "services/**/*.py", - "tasks/**/*.py", - "tests/**/*.py", - "configs/**/*.py", ] -[tool.pytest_env] -OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii" -UPSTAGE_API_KEY = "up-aaaaaaaaaaaaaaaaaaaa" -AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com" -AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94" -ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz" -CHATGLM_API_BASE = "http://a.abc.com:11451" -XINFERENCE_SERVER_URL = "http://a.abc.com:11451" -XINFERENCE_GENERATION_MODEL_UID = "generate" -XINFERENCE_CHAT_MODEL_UID = "chat" -XINFERENCE_EMBEDDINGS_MODEL_UID = "embedding" -XINFERENCE_RERANK_MODEL_UID = "rerank" -GOOGLE_API_KEY = "abcdefghijklmnopqrstuvwxyz" -HUGGINGFACE_API_KEY = "hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu" -HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = "a" -HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = "b" -HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = "c" -MOCK_SWITCH = "true" -CODE_MAX_STRING_LENGTH = "80000" -CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194" -CODE_EXECUTION_API_KEY = "dify-sandbox" -FIRECRAWL_API_KEY = "fc-" -TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451" -TEI_RERANK_SERVER_URL = "http://a.abc.com:11451" - [tool.poetry] name = "dify-api" package-mode = false ############################################################ -# Main dependencies +# [ Main ] Dependency group
############################################################ [tool.poetry.dependencies] anthropic = "~0.23.1" authlib = "1.3.1" +azure-ai-inference = "~1.0.0b3" +azure-ai-ml = "~1.20.0" azure-identity = "1.16.1" -azure-storage-blob = "12.13.0" beautifulsoup4 = "4.12.2" -boto3 = "1.34.148" +boto3 = "1.35.17" bs4 = "~0.0.1" cachetools = "~5.3.0" -celery = "~5.3.6" +celery = "~5.4.0" chardet = "~5.1.0" cohere = "~5.2.4" -cos-python-sdk-v5 = "1.9.30" dashscope = { version = "~1.17.0", extras = ["tokenizer"] } flask = "~3.0.1" flask-compress = "~1.14" @@ -133,35 +128,36 @@ flask-cors = "~4.0.0" flask-login = "~0.6.3" flask-migrate = "~4.0.5" flask-restful = "~0.3.10" -Flask-SQLAlchemy = "~3.1.1" +flask-sqlalchemy = "~3.1.1" gevent = "~23.9.1" gmpy2 = "~2.2.1" -google-ai-generativelanguage = "0.6.1" +google-ai-generativelanguage = "0.6.9" google-api-core = "2.18.0" google-api-python-client = "2.90.0" google-auth = "2.29.0" google-auth-httplib2 = "0.2.0" google-cloud-aiplatform = "1.49.0" -google-cloud-storage = "2.16.0" -google-generativeai = "0.5.0" +google-generativeai = "0.8.1" googleapis-common-protos = "1.63.0" gunicorn = "~22.0.0" httpx = { version = "~0.27.0", extras = ["socks"] } huggingface-hub = "~0.16.4" jieba = "0.42.1" -langfuse = "^2.36.1" -langsmith = "^0.1.77" +langfuse = "~2.51.3" +langsmith = "~0.1.77" mailchimp-transactional = "~1.0.50" markdown = "~3.5.1" -novita-client = "^0.5.6" +nomic = "~3.1.2" +novita-client = "~0.5.7" numpy = "~1.26.4" -openai = "~1.29.0" -oss2 = "2.18.5" +oci = "~2.135.1" +openai = "~1.52.0" +openpyxl = "~3.1.5" pandas = { version = "~2.2.2", extras = ["performance", "excel"] } psycopg2-binary = "~2.9.6" pycryptodome = "3.19.1" -pydantic = "~2.8.2" -pydantic-settings = "~2.3.4" +pydantic = "~2.9.2" +pydantic-settings = "~2.6.0" pydantic_extra_types = "~2.9.0" pyjwt = "~2.8.0" pypdfium2 = "~4.17.0" @@ -173,84 +169,113 @@ readabilipy = "0.2.0" redis = { version = "~5.0.3", extras = ["hiredis"] } replicate = "~0.22.0" resend = "~0.7.0" -safetensors = "~0.4.3" -scikit-learn = "^1.5.1" +sagemaker = "~2.231.0" +scikit-learn = "~1.5.1" sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sqlalchemy = "~2.0.29" +starlette = "0.41.0" tencentcloud-sdk-python-hunyuan = "~3.0.1158" -tiktoken = "~0.7.0" +tiktoken = "~0.8.0" tokenizers = "~0.15.0" transformers = "~4.35.0" -unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } +unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] } +validators = "0.21.0" +volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} websocket-client = "~1.7.0" werkzeug = "~3.0.1" -xinference-client = "0.13.3" +xinference-client = "0.15.2" yarl = "~1.9.4" -zhipuai = "1.0.7" -rank-bm25 = "~0.2.2" -openpyxl = "^3.1.5" -kaleido = "0.2.1" -elasticsearch = "8.14.0" +zhipuai = "~2.1.5" +# Before adding a new dependency, consider placing it in alphabetical order (a-z) and in a suitable group.
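The note above asks contributors to keep each Poetry dependency group alphabetized. A minimal sketch of a check that could enforce this locally, assuming Python 3.11+ (for the stdlib tomllib) and the api/pyproject.toml path used in this repository; the script name and messages are illustrative, not part of the project:

# check_dep_order.py - report Poetry dependencies that break alphabetical order.
# Sketch only: assumes Python 3.11+ (tomllib) and the api/pyproject.toml layout.
import sys
import tomllib

def out_of_order(names: list[str]) -> list[tuple[str, str]]:
    """Return adjacent (earlier, later) pairs where the earlier name sorts after the later one."""
    keys = [name.lower() for name in names]
    return [(names[i], names[i + 1]) for i in range(len(names) - 1) if keys[i] > keys[i + 1]]

with open("api/pyproject.toml", "rb") as f:
    poetry = tomllib.load(f)["tool"]["poetry"]

# tomllib preserves document order, so dict order reflects declaration order.
groups = {"main": poetry.get("dependencies", {})}
for name, group in poetry.get("group", {}).items():
    groups[name] = group.get("dependencies", {})

failed = False
for group_name, deps in groups.items():
    for earlier, later in out_of_order([d for d in deps if d != "python"]):
        print(f"[{group_name}] {earlier} is listed before {later}")
        failed = True
sys.exit(1 if failed else 0)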
############################################################ -# Tool dependencies required by tool implementations +# [ Indirect ] dependency group +# Transitive dependencies with pinned versions, +# required by the main implementations ############################################################ +[tool.poetry.group.indirect.dependencies] +kaleido = "0.2.1" +rank-bm25 = "~0.2.2" +safetensors = "~0.4.3" -[tool.poetry.group.tool.dependencies] +############################################################ +# [ Tools ] dependency group +############################################################ +[tool.poetry.group.tools.dependencies] arxiv = "2.1.0" +cloudscraper = "1.2.71" +duckduckgo-search = "~6.3.0" +jsonpath-ng = "1.6.1" matplotlib = "~3.8.2" +mplfonts = "~0.0.8" newspaper3k = "0.2.8" -duckduckgo-search = "^6.2.6" -jsonpath-ng = "1.6.1" +nltk = "3.9.1" numexpr = "~2.9.0" -opensearch-py = "2.4.0" +pydub = "~0.25.1" qrcode = "~7.4.2" twilio = "~9.0.4" -vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } +vanna = { version = "0.7.5", extras = ["postgres", "mysql", "clickhouse", "duckdb", "oracle"] } wikipedia = "1.4.0" yfinance = "~0.2.40" -cloudscraper = "1.2.71" ############################################################ -# VDB dependencies required by vector store clients +# [ Storage ] dependency group +# Required for storage clients ############################################################ +[tool.poetry.group.storage.dependencies] +azure-storage-blob = "12.13.0" +bce-python-sdk = "~0.9.23" +cos-python-sdk-v5 = "1.9.30" +esdk-obs-python = "3.24.6.1" +google-cloud-storage = "2.16.0" +oss2 = "2.18.5" +supabase = "~2.8.1" +tos = "~2.7.1" +############################################################ +# [ VDB ] dependency group +# Required by vector store clients +############################################################ [tool.poetry.group.vdb.dependencies] +alibabacloud_gpdb20160503 = "~3.8.0" +alibabacloud_tea_openapi = "~0.3.9" chromadb = "0.5.1" +clickhouse-connect = "~0.7.16" +couchbase = "~4.3.0" +elasticsearch = "8.14.0" +opensearch-py = "2.4.0" oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" pymilvus = "~2.4.4" -pymysql = "1.1.1" +pymochow = "1.3.1" +pyobvector = "~0.1.6" +qdrant-client = "1.7.3" tcvectordb = "1.3.2" tidb-vector = "0.0.9" -qdrant-client = "1.7.3" +upstash-vector = "0.6.0" +volcengine-compat = "~1.0.156" weaviate-client = "~3.21.0" -alibabacloud_gpdb20160503 = "~3.8.0" -alibabacloud_tea_openapi = "~0.3.9" -clickhouse-connect = "~0.7.16" ############################################################ -# Dev dependencies for running tests +# [ Dev ] dependency group +# Required for development and running tests ############################################################ - [tool.poetry.group.dev] optional = true - [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" -pytest = "~8.1.1" +pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" pytest-mock = "~3.14.0" ############################################################ -# Lint dependencies for code style linting +# [ Lint ] dependency group +# Required for code style linting ############################################################ - [tool.poetry.group.lint] optional = true - [tool.poetry.group.lint.dependencies] -ruff = "~0.5.7" dotenv-linter = "~0.5.0" +ruff = "~0.6.9" diff --git a/api/pytest.ini b/api/pytest.ini new file mode 100644 index 00000000000000..a23a4b3f3d89c5 --- /dev/null +++
b/api/pytest.ini @@ -0,0 +1,30 @@ +[pytest] +env = + ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz + AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com + AZURE_OPENAI_API_KEY = xxxxb1707exxxxxxxxxxaaxxxxxf94 + CHATGLM_API_BASE = http://a.abc.com:11451 + CODE_EXECUTION_API_KEY = dify-sandbox + CODE_EXECUTION_ENDPOINT = http://127.0.0.1:8194 + CODE_MAX_STRING_LENGTH = 80000 + FIRECRAWL_API_KEY = fc- + FIREWORKS_API_KEY = fw_aaaaaaaaaaaaaaaaaaaa + GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz + HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu + HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c + HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b + HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a + MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa + MOCK_SWITCH = true + NOMIC_API_KEY = nk-aaaaaaaaaaaaaaaaaaaa + OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii + TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451 + TEI_RERANK_SERVER_URL = http://a.abc.com:11451 + UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa + VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa + XINFERENCE_CHAT_MODEL_UID = chat + XINFERENCE_EMBEDDINGS_MODEL_UID = embedding + XINFERENCE_GENERATION_MODEL_UID = generate + XINFERENCE_RERANK_MODEL_UID = rerank + XINFERENCE_SERVER_URL = http://a.abc.com:11451 + GITEE_AI_API_KEY = aaaaaaaaaaaaaaaaaaaa diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 67d070682867bb..9efe120b7a57fe 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -14,7 +14,7 @@ @app.celery.task(queue="dataset") def clean_embedding_cache_task(): click.echo(click.style("Start clean embedding cache.", fg="green")) - clean_days = int(dify_config.CLEAN_DAY_SETTING) + clean_days = int(dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING) start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 3d799bfd4ef732..100fd8dfab67ea 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -9,15 +9,19 @@ from configs import dify_config from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db +from extensions.ext_redis import redis_client from models.dataset import Dataset, DatasetQuery, Document +from services.feature_service import FeatureService @app.celery.task(queue="dataset") def clean_unused_datasets_task(): click.echo(click.style("Start clean unused datasets indexes.", fg="green")) - clean_days = dify_config.CLEAN_DAY_SETTING + plan_sandbox_clean_day_setting = dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING + plan_pro_clean_day_setting = dify_config.PLAN_PRO_CLEAN_DAY_SETTING start_at = time.perf_counter() - thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) + plan_sandbox_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_sandbox_clean_day_setting) + plan_pro_clean_day = datetime.datetime.now() - datetime.timedelta(days=plan_pro_clean_day_setting) page = 1 while True: try: @@ -28,7 +32,7 @@ def clean_unused_datasets_task(): Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, - Document.updated_at > thirty_days_ago, + Document.updated_at > plan_sandbox_clean_day, ) .group_by(Document.dataset_id) .subquery() @@ -41,7 +45,7 @@ 
def clean_unused_datasets_task(): Document.indexing_status == "completed", Document.enabled == True, Document.archived == False, - Document.updated_at < thirty_days_ago, + Document.updated_at < plan_sandbox_clean_day, ) .group_by(Document.dataset_id) .subquery() @@ -53,7 +57,7 @@ def clean_unused_datasets_task(): .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( - Dataset.created_at < thirty_days_ago, + Dataset.created_at < plan_sandbox_clean_day, func.coalesce(document_subquery_new.c.document_count, 0) == 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0, ) @@ -69,7 +73,7 @@ def clean_unused_datasets_task(): for dataset in datasets: dataset_query = ( db.session.query(DatasetQuery) - .filter(DatasetQuery.created_at > thirty_days_ago, DatasetQuery.dataset_id == dataset.id) + .filter(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id) .all() ) if not dataset_query or len(dataset_query) == 0: @@ -88,5 +92,84 @@ def clean_unused_datasets_task(): click.echo( click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") ) + page = 1 + while True: + try: + # Subquery for counting new documents + document_subquery_new = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at > plan_pro_clean_day, + ) + .group_by(Document.dataset_id) + .subquery() + ) + + # Subquery for counting old documents + document_subquery_old = ( + db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + .filter( + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + Document.updated_at < plan_pro_clean_day, + ) + .group_by(Document.dataset_id) + .subquery() + ) + + # Main query with join and filter + datasets = ( + db.session.query(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) + .filter( + Dataset.created_at < plan_pro_clean_day, + func.coalesce(document_subquery_new.c.document_count, 0) == 0, + func.coalesce(document_subquery_old.c.document_count, 0) > 0, + ) + .order_by(Dataset.created_at.desc()) + .paginate(page=page, per_page=50) + ) + + except NotFound: + break + if datasets.items is None or len(datasets.items) == 0: + break + page += 1 + for dataset in datasets: + dataset_query = ( + db.session.query(DatasetQuery) + .filter(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id) + .all() + ) + if not dataset_query or len(dataset_query) == 0: + try: + features_cache_key = f"features:{dataset.tenant_id}" + plan = redis_client.get(features_cache_key) + if plan is None: + features = FeatureService.get_features(dataset.tenant_id) + redis_client.setex(features_cache_key, 600, features.billing.subscription.plan) + plan = features.billing.subscription.plan + if plan == "sandbox": + # remove index + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + index_processor.clean(dataset, None) + + # update document + update_params = {Document.enabled: False} + + Document.query.filter_by(dataset_id=dataset.id).update(update_params) + db.session.commit() + click.echo( + click.style("Cleaned 
unused dataset {} from db success!".format(dataset.id), fg="green") + ) + except Exception as e: + click.echo( + click.style("clean dataset index error: {} {}".format(e.__class__.__name__, str(e)), fg="red") + ) end_at = time.perf_counter() click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py new file mode 100644 index 00000000000000..a20b500308a4d6 --- /dev/null +++ b/api/schedule/create_tidb_serverless_task.py @@ -0,0 +1,58 @@ +import time + +import click + +import app +from configs import dify_config +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from extensions.ext_database import db +from models.dataset import TidbAuthBinding + + +@app.celery.task(queue="dataset") +def create_tidb_serverless_task(): + click.echo(click.style("Start create tidb serverless task.", fg="green")) + if not dify_config.CREATE_TIDB_SERVICE_JOB_ENABLED: + return + tidb_serverless_number = dify_config.TIDB_SERVERLESS_NUMBER + start_at = time.perf_counter() + while True: + try: + # check the number of idle tidb serverless + idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() + if idle_tidb_serverless_number >= tidb_serverless_number: + break + # create tidb serverless + iterations_per_thread = 20 + create_clusters(iterations_per_thread) + + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) + break + + end_at = time.perf_counter() + click.echo(click.style("Create tidb serverless task success latency: {}".format(end_at - start_at), fg="green")) + + +def create_clusters(batch_size): + try: + new_clusters = TidbService.batch_create_tidb_serverless_cluster( + batch_size, + dify_config.TIDB_PROJECT_ID, + dify_config.TIDB_API_URL, + dify_config.TIDB_IAM_API_URL, + dify_config.TIDB_PUBLIC_KEY, + dify_config.TIDB_PRIVATE_KEY, + dify_config.TIDB_REGION, + ) + for new_cluster in new_clusters: + tidb_auth_binding = TidbAuthBinding( + cluster_id=new_cluster["cluster_id"], + cluster_name=new_cluster["cluster_name"], + account=new_cluster["account"], + password=new_cluster["password"], + ) + db.session.add(tidb_auth_binding) + db.session.commit() + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py new file mode 100644 index 00000000000000..07eca3173b3ce5 --- /dev/null +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -0,0 +1,51 @@ +import time + +import click + +import app +from configs import dify_config +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from models.dataset import TidbAuthBinding + + +@app.celery.task(queue="dataset") +def update_tidb_serverless_status_task(): + click.echo(click.style("Update tidb serverless status task.", fg="green")) + start_at = time.perf_counter() + while True: + try: + # check the number of idle tidb serverless + tidb_serverless_list = TidbAuthBinding.query.filter( + TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" + ).all() + if len(tidb_serverless_list) == 0: + break + # update tidb serverless status + iterations_per_thread = 20 + update_clusters(tidb_serverless_list) + + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) + break + + end_at = time.perf_counter() + click.echo( + click.style("Update tidb 
serverless status task success latency: {}".format(end_at - start_at), fg="green") + ) + + +def update_clusters(tidb_serverless_list: list[TidbAuthBinding]): + try: + # batch 20 + for i in range(0, len(tidb_serverless_list), 20): + items = tidb_serverless_list[i : i + 20] + TidbService.batch_update_tidb_serverless_cluster_status( + items, + dify_config.TIDB_PROJECT_ID, + dify_config.TIDB_API_URL, + dify_config.TIDB_IAM_API_URL, + dify_config.TIDB_PUBLIC_KEY, + dify_config.TIDB_PRIVATE_KEY, + ) + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) diff --git a/api/services/__init__.py b/api/services/__init__.py index 6891436314b299..5163862cc12781 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -1,3 +1,3 @@ from . import errors -__all__ = ['errors'] +__all__ = ["errors"] diff --git a/api/services/account_service.py b/api/services/account_service.py index d73cec2697f975..963a05594845e4 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1,28 +1,43 @@ import base64 +import json import logging +import random import secrets import uuid from datetime import datetime, timedelta, timezone from hashlib import sha256 from typing import Any, Optional +from pydantic import BaseModel from sqlalchemy import func from werkzeug.exceptions import Unauthorized from configs import dify_config from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created +from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.helper import RateLimiter, TokenManager from libs.passport import PassportService from libs.password import compare_password, hash_password, valid_password from libs.rsa import generate_key_pair -from models.account import * +from models.account import ( + Account, + AccountIntegrate, + AccountStatus, + Tenant, + TenantAccountJoin, + TenantAccountJoinRole, + TenantAccountRole, + TenantStatus, +) from models.model import DifySetup from services.errors.account import ( AccountAlreadyInTenantError, AccountLoginError, + AccountNotFoundError, AccountNotLinkTenantError, + AccountPasswordError, AccountRegisterError, CannotOperateSelfError, CurrentPasswordIncorrectError, @@ -30,21 +45,52 @@ LinkAccountIntegrateError, MemberNotInTenantError, NoPermissionError, - RateLimitExceededError, RoleAlreadyAssignedError, - TenantNotFound, + TenantNotFoundError, ) +from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.feature_service import FeatureService +from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task -class AccountService: +class TokenPair(BaseModel): + access_token: str + refresh_token: str + + +REFRESH_TOKEN_PREFIX = "refresh_token:" +ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" +REFRESH_TOKEN_EXPIRY = timedelta(days=30) - reset_password_rate_limiter = RateLimiter( - prefix="reset_password_rate_limit", - max_attempts=5, - time_window=60 * 60 + +class AccountService: + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) + email_code_login_rate_limiter = RateLimiter( + prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 ) + LOGIN_MAX_ERROR_LIMITS = 5 + + @staticmethod + def _get_refresh_token_key(refresh_token: str) -> str: + return 
f"{REFRESH_TOKEN_PREFIX}{refresh_token}" + + @staticmethod + def _get_account_refresh_token_key(account_id: str) -> str: + return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}" + + @staticmethod + def _store_refresh_token(refresh_token: str, account_id: str) -> None: + redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id) + redis_client.setex( + AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token + ) + + @staticmethod + def _delete_refresh_token(refresh_token: str, account_id: str) -> None: + redis_client.delete(AccountService._get_refresh_token_key(refresh_token)) + redis_client.delete(AccountService._get_account_refresh_token_key(account_id)) @staticmethod def load_user(user_id: str) -> None | Account: @@ -52,15 +98,16 @@ def load_user(user_id: str) -> None | Account: if not account: return None - if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: - raise Unauthorized("Account is banned or closed.") + if account.status == AccountStatus.BANNED.value: + raise Unauthorized("Account is banned.") - current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if not available_ta: return None @@ -74,37 +121,49 @@ def load_user(user_id: str) -> None | Account: return account - @staticmethod - def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): + def get_account_jwt_token(account: Account) -> str: + exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) + exp = int(exp_dt.timestamp()) payload = { "user_id": account.id, - "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, + "exp": exp, "iss": dify_config.EDITION, - "sub": 'Console API Passport', + "sub": "Console API Passport", } token = PassportService().issue(payload) return token @staticmethod - def authenticate(email: str, password: str) -> Account: + def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: """authenticate account with email and password""" account = Account.query.filter_by(email=email).first() if not account: - raise AccountLoginError('Invalid email or password.') + raise AccountNotFoundError() + + if account.status == AccountStatus.BANNED.value: + raise AccountLoginError("Account is banned.") + + if password and invite_token and account.password is None: + # if invite_token is valid, set password and password_salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + password_hashed = hash_password(password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - raise AccountLoginError('Account is banned or closed.') + if account.password is None or not compare_password(password, account.password, account.password_salt): + raise AccountPasswordError("Invalid email or password.") if 
account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) - db.session.commit() - if account.password is None or not compare_password(password, account.password, account.password_salt): - raise AccountLoginError('Invalid email or password.') + db.session.commit() + return account @staticmethod @@ -129,12 +188,19 @@ def update_account_password(account, password, new_password): return account @staticmethod - def create_account(email: str, - name: str, - interface_language: str, - password: Optional[str] = None, - interface_theme: str = 'light') -> Account: + def create_account( + email: str, + name: str, + interface_language: str, + password: Optional[str] = None, + interface_theme: str = "light", + is_setup: Optional[bool] = False, + ) -> Account: """create account""" + if not FeatureService.get_system_features().is_allow_register and not is_setup: + from controllers.console.error import NotAllowedRegister + + raise NotAllowedRegister() account = Account() account.email = email account.name = name @@ -155,19 +221,33 @@ def create_account(email: str, account.interface_theme = interface_theme # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, 'UTC') + account.timezone = language_timezone_mapping.get(interface_language, "UTC") db.session.add(account) db.session.commit() return account + @staticmethod + def create_account_and_tenant( + email: str, name: str, interface_language: str, password: Optional[str] = None + ) -> Account: + """create account""" + account = AccountService.create_account( + email=email, name=name, interface_language=interface_language, password=password + ) + + TenantService.create_owner_tenant_if_not_exist(account=account) + + return account + @staticmethod def link_account_integrate(provider: str, open_id: str, account: Account) -> None: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id, - provider=provider).first() + account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( + account_id=account.id, provider=provider + ).first() if account_integrate: # If it exists, update the record @@ -176,15 +256,16 @@ def link_account_integrate(provider: str, open_id: str, account: Account) -> Non account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) else: # If it does not exist, create a new record - account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id, - encrypted_token="") + account_integrate = AccountIntegrate( + account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" + ) db.session.add(account_integrate) db.session.commit() - logging.info(f'Account {account.id} linked {provider} account {open_id}.') + logging.info(f"Account {account.id} linked {provider} account {open_id}.") except Exception as e: - logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}') - raise LinkAccountIntegrateError('Failed to link account.') from e + logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") + raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod def close_account(account: Account) -> None: @@ -205,7 +286,7 @@ def update_account(account, **kwargs): return account 
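Taken together, the account_service.py hunks above and below replace the old 30-day JWT session with a short-lived access token (its expiry drawn from ACCESS_TOKEN_EXPIRE_MINUTES) paired with an opaque refresh token that lives in Redis for REFRESH_TOKEN_EXPIRY under two mirrored keys and is rotated on every refresh. A condensed, self-contained sketch of that pattern, assuming a local Redis and PyJWT; SECRET_KEY, the HS256 choice, the 60-minute window, and token_hex(64) are illustrative placeholders rather than the project's actual values:

# Sketch of the access/refresh pair pattern introduced in this patch.
# Assumes a local Redis and PyJWT; key prefixes mirror the constants above.
import secrets
from datetime import datetime, timedelta, timezone

import jwt  # PyJWT
import redis

SECRET_KEY = "change-me"  # placeholder; the real service signs with its own key material
REFRESH_TOKEN_EXPIRY = timedelta(days=30)
r = redis.Redis()

def issue_pair(account_id: str) -> tuple[str, str]:
    # Short-lived signed access token; longevity comes from the refresh token instead.
    exp = int((datetime.now(timezone.utc) + timedelta(minutes=60)).timestamp())
    access = jwt.encode({"user_id": account_id, "exp": exp}, SECRET_KEY, algorithm="HS256")
    # Opaque refresh token under two mirrored keys: token -> account for verification,
    # account -> token so logout can revoke it. (token_hex(64) is an assumed generator.)
    refresh = secrets.token_hex(64)
    r.setex(f"refresh_token:{refresh}", REFRESH_TOKEN_EXPIRY, account_id)
    r.setex(f"account_refresh_token:{account_id}", REFRESH_TOKEN_EXPIRY, refresh)
    return access, refresh

def rotate(refresh: str) -> tuple[str, str]:
    account_id = r.get(f"refresh_token:{refresh}")
    if account_id is None:
        raise ValueError("Invalid refresh token")
    account_id = account_id.decode()
    # Single-use refresh tokens: delete the old pair before issuing a new one.
    r.delete(f"refresh_token:{refresh}", f"account_refresh_token:{account_id}")
    return issue_pair(account_id)

Because the refresh token is single-use, a stolen copy stops working as soon as either party rotates the pair, at the cost of one extra Redis round-trip per refresh.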
@staticmethod - def update_last_login(account: Account, *, ip_address: str) -> None: + def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) account.last_login_ip = ip_address @@ -213,45 +294,190 @@ def update_last_login(account: Account, *, ip_address: str) -> None: db.session.commit() @staticmethod - def login(account: Account, *, ip_address: Optional[str] = None): + def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: if ip_address: - AccountService.update_last_login(account, ip_address=ip_address) - exp = timedelta(days=30) - token = AccountService.get_account_jwt_token(account, exp=exp) - redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds())) - return token + AccountService.update_login_info(account=account, ip_address=ip_address) + + if account.status == AccountStatus.PENDING.value: + account.status = AccountStatus.ACTIVE.value + db.session.commit() + + access_token = AccountService.get_account_jwt_token(account=account) + refresh_token = _generate_refresh_token() + + AccountService._store_refresh_token(refresh_token, account.id) + + return TokenPair(access_token=access_token, refresh_token=refresh_token) @staticmethod - def logout(*, account: Account, token: str): - redis_client.delete(_get_login_cache_key(account_id=account.id, token=token)) + def logout(*, account: Account) -> None: + refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) + if refresh_token: + AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) @staticmethod - def load_logged_in_account(*, account_id: str, token: str): - if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)): - return None + def refresh_token(refresh_token: str) -> TokenPair: + # Verify the refresh token + account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) + if not account_id: + raise ValueError("Invalid refresh token") + + account = AccountService.load_user(account_id.decode("utf-8")) + if not account: + raise ValueError("Invalid account") + + # Generate new access token and refresh token + new_access_token = AccountService.get_account_jwt_token(account) + new_refresh_token = _generate_refresh_token() + + AccountService._delete_refresh_token(refresh_token, account.id) + AccountService._store_refresh_token(new_refresh_token, account.id) + + return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) + + @staticmethod + def load_logged_in_account(*, account_id: str): return AccountService.load_user(account_id) @classmethod - def send_reset_password_email(cls, account): - if cls.reset_password_rate_limiter.is_rate_limited(account.email): - raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. 
Please try again later.") - - token = TokenManager.generate_token(account, 'reset_password') + def send_reset_password_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + ): + account_email = account.email if account else email + + if cls.reset_password_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import PasswordResetRateLimitExceededError + + raise PasswordResetRateLimitExceededError() + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data={"code": code} + ) send_reset_password_mail_task.delay( - language=account.interface_language, - to=account.email, - token=token + language=language, + to=account_email, + code=code, ) - cls.reset_password_rate_limiter.increment_rate_limit(account.email) + cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token @classmethod def revoke_reset_password_token(cls, token: str): - TokenManager.revoke_token(token, 'reset_password') + TokenManager.revoke_token(token, "reset_password") @classmethod def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: - return TokenManager.get_token_data(token, 'reset_password') + return TokenManager.get_token_data(token, "reset_password") + + @classmethod + def send_email_code_login_email( + cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + ): + if cls.email_code_login_rate_limiter.is_rate_limited(email): + from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError + + raise EmailCodeLoginRateLimitExceededError() + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="email_code_login", additional_data={"code": code} + ) + send_email_code_login_mail_task.delay( + language=language, + to=account.email if account else email, + code=code, + ) + cls.email_code_login_rate_limiter.increment_rate_limit(email) + return token + + @classmethod + def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "email_code_login") + + @classmethod + def revoke_email_code_login_token(cls, token: str): + TokenManager.revoke_token(token, "email_code_login") + + @classmethod + def get_user_through_email(cls, email: str): + account = db.session.query(Account).filter(Account.email == email).first() + if not account: + return None + + if account.status == AccountStatus.BANNED.value: + raise Unauthorized("Account is banned.") + + return account + + @staticmethod + def add_login_error_rate_limit(email: str) -> None: + key = f"login_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, 60 * 60 * 24, count) + + @staticmethod + def is_login_error_rate_limit(email: str) -> bool: + key = f"login_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + + count = int(count) + if count > AccountService.LOGIN_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + def reset_login_error_rate_limit(email: str): + key = f"login_error_rate_limit:{email}" + redis_client.delete(key) + + @staticmethod + def is_email_send_ip_limit(ip_address: str): + minute_key = f"email_send_ip_limit_minute:{ip_address}" + freeze_key = 
f"email_send_ip_limit_freeze:{ip_address}" + hour_limit_key = f"email_send_ip_limit_hour:{ip_address}" + + # check ip is frozen + if redis_client.get(freeze_key): + return True + + # check current minute count + current_minute_count = redis_client.get(minute_key) + if current_minute_count is None: + current_minute_count = 0 + current_minute_count = int(current_minute_count) + + # check current hour count + if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE: + hour_limit_count = redis_client.get(hour_limit_key) + if hour_limit_count is None: + hour_limit_count = 0 + hour_limit_count = int(hour_limit_count) + + if hour_limit_count >= 1: + redis_client.setex(freeze_key, 60 * 60, 1) + return True + else: + redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes + + # add hour limit count + redis_client.incr(hour_limit_key) + redis_client.expire(hour_limit_key, 60 * 60) + + return True + + redis_client.setex(minute_key, 60, current_minute_count + 1) + redis_client.expire(minute_key, 60) + + return False def _get_login_cache_key(*, account_id: str, token: str): @@ -259,10 +485,17 @@ def _get_login_cache_key(*, account_id: str, token: str): class TenantService: - @staticmethod - def create_tenant(name: str) -> Tenant: + def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant: """Create tenant""" + if ( + not FeatureService.get_system_features().is_allow_create_workspace + and not is_setup + and not is_from_dashboard + ): + from controllers.console.error import NotAllowedCreateWorkspace + + raise NotAllowedCreateWorkspace() tenant = Tenant(name=name) db.session.add(tenant) @@ -273,76 +506,97 @@ def create_tenant(name: str) -> Tenant: return tenant @staticmethod - def create_owner_tenant_if_not_exist(account: Account): - """Create owner tenant if not exist""" - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + def create_owner_tenant_if_not_exist( + account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False + ): + """Check if user have a workspace or not""" + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if available_ta: return - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + """Create owner tenant if not exist""" + if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: + raise WorkSpaceNotAllowedCreateError() + + if name: + tenant = TenantService.create_tenant(name=name, is_setup=is_setup) + else: + tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant db.session.commit() tenant_was_created.send(tenant) @staticmethod - def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin: + def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: """Create tenant member""" if role == TenantAccountJoinRole.OWNER.value: if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): - logging.error(f'Tenant {tenant.id} has already an owner.') - raise Exception('Tenant already has an owner.') + logging.error(f"Tenant {tenant.id} has already an owner.") + raise Exception("Tenant 
already has an owner.") + + ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + if ta: + ta.role = role + else: + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) + db.session.add(ta) - ta = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=role - ) - db.session.add(ta) db.session.commit() return ta @staticmethod def get_join_tenants(account: Account) -> list[Tenant]: """Get account join tenants""" - return db.session.query(Tenant).join( - TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id - ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() + return ( + db.session.query(Tenant) + .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) + .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + .all() + ) @staticmethod def get_current_tenant_by_account(account: Account): """Get tenant by account and add the role""" tenant = account.current_tenant if not tenant: - raise TenantNotFound("Tenant not found.") + raise TenantNotFoundError("Tenant not found.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: tenant.role = ta.role else: - raise TenantNotFound("Tenant not found for the account.") + raise TenantNotFoundError("Tenant not found for the account.") return tenant @staticmethod - def switch_tenant(account: Account, tenant_id: int = None) -> None: + def switch_tenant(account: Account, tenant_id: Optional[int] = None) -> None: """Switch the current workspace for the account""" # Ensure tenant_id is provided if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( - TenantAccountJoin.account_id == account.id, - TenantAccountJoin.tenant_id == tenant_id, - Tenant.status == TenantStatus.NORMAL, - ).first() + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) + .filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant_id, + Tenant.status == TenantStatus.NORMAL, + ) + .first() + ) if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) + TenantAccountJoin.query.filter( + TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id + ).update({"current": False}) tenant_account_join.current = True # Set the current tenant for the account account.current_tenant_id = tenant_account_join.tenant_id @@ -354,9 +608,7 @@ def get_tenant_members(tenant: Tenant) -> list[Account]: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) ) @@ -375,11 +627,9 @@ def get_dataset_operator_members(tenant: Tenant) -> list[Account]: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) 
.filter(TenantAccountJoin.tenant_id == tenant.id) - .filter(TenantAccountJoin.role == 'dataset_operator') + .filter(TenantAccountJoin.role == "dataset_operator") ) # Initialize an empty list to store the updated accounts @@ -395,20 +645,25 @@ def get_dataset_operator_members(tenant: Tenant) -> list[Account]: def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool: """Check if user has any of the given roles for a tenant""" if not all(isinstance(role, TenantAccountJoinRole) for role in roles): - raise ValueError('all roles must be TenantAccountJoinRole') + raise ValueError("all roles must be TenantAccountJoinRole") - return db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.role.in_([role.value for role in roles]) - ).first() is not None + return ( + db.session.query(TenantAccountJoin) + .filter( + TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) + ) + .first() + is not None + ) @staticmethod def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]: """Get the role of the current account for a given tenant""" - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == account.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .first() + ) return join.role if join else None @staticmethod @@ -420,29 +675,26 @@ def get_tenant_count() -> int: def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None: """Check member permission""" perms = { - 'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], - 'remove': [TenantAccountRole.OWNER], - 'update': [TenantAccountRole.OWNER] + "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], + "remove": [TenantAccountRole.OWNER], + "update": [TenantAccountRole.OWNER], } - if action not in ['add', 'remove', 'update']: + if action not in {"add", "remove", "update"}: raise InvalidActionError("Invalid action.") if member: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=operator.id - ).first() + ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() if not ta_operator or ta_operator.role not in perms[action]: - raise NoPermissionError(f'No permission to {action} member.') + raise NoPermissionError(f"No permission to {action} member.") @staticmethod def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: """Remove member from tenant""" - if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'): + if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"): raise CannotOperateSelfError("Cannot operate self.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() @@ -455,23 +707,17 @@ def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Accoun @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: """Update member role""" - TenantService.check_member_permission(tenant, operator, member, 'update') + TenantService.check_member_permission(tenant, operator, member, "update") - 
target_member_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=member.id - ).first() + target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() if target_member_join.role == new_role: raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") - if new_role == 'owner': + if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - role='owner' - ).first() - current_owner_join.role = 'admin' + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join.role = "admin" # Update the role of the target member target_member_join.role = new_role @@ -480,8 +726,8 @@ def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: @staticmethod def dissolve_tenant(tenant: Tenant, operator: Account) -> None: """Dissolve tenant""" - if not TenantService.check_member_permission(tenant, operator, operator, 'remove'): - raise NoPermissionError('No permission to dissolve tenant.') + if not TenantService.check_member_permission(tenant, operator, operator, "remove"): + raise NoPermissionError("No permission to dissolve tenant.") db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.delete(tenant) db.session.commit() @@ -494,10 +740,9 @@ def get_custom_config(tenant_id: str) -> None: class RegisterService: - @classmethod def _get_invitation_token_key(cls, token: str) -> str: - return f'member_invite:token:{token}' + return f"member_invite:token:{token}" @classmethod def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: @@ -516,16 +761,15 @@ def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: name=name, interface_language=languages[0], password=password, + is_setup=True, ) account.last_login_ip = ip_address account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) - TenantService.create_owner_tenant_if_not_exist(account) + TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) - dify_setup = DifySetup( - version=dify_config.CURRENT_VERSION - ) + dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) db.session.add(dify_setup) db.session.commit() except Exception as e: @@ -535,65 +779,73 @@ def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: db.session.query(Tenant).delete() db.session.commit() - logging.exception(f'Setup failed: {e}') - raise ValueError(f'Setup failed: {e}') + logging.exception(f"Setup failed: {e}") + raise ValueError(f"Setup failed: {e}") @classmethod - def register(cls, email, name, - password: Optional[str] = None, - open_id: Optional[str] = None, - provider: Optional[str] = None, - language: Optional[str] = None, - status: Optional[AccountStatus] = None) -> Account: + def register( + cls, + email, + name, + password: Optional[str] = None, + open_id: Optional[str] = None, + provider: Optional[str] = None, + language: Optional[str] = None, + status: Optional[AccountStatus] = None, + is_setup: Optional[bool] = False, + ) -> Account: db.session.begin_nested() """Register account""" try: account = AccountService.create_account( email=email, name=name, - interface_language=language if language else languages[0], - password=password + interface_language=language or languages[0], + password=password, + is_setup=is_setup, ) account.status = 
AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if dify_config.EDITION != 'SELF_HOSTED': - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + if FeatureService.get_system_features().is_allow_create_workspace: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant - tenant_was_created.send(tenant) db.session.commit() + except WorkSpaceNotAllowedCreateError: + db.session.rollback() except Exception as e: db.session.rollback() - logging.error(f'Register failed: {e}') - raise AccountRegisterError(f'Registration failed: {e}') from e + logging.exception(f"Register failed: {e}") + raise AccountRegisterError(f"Registration failed: {e}") from e return account @classmethod - def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: + def invite_new_member( + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() if not account: - TenantService.check_member_permission(tenant, inviter, None, 'add') - name = email.split('@')[0] + TenantService.check_member_permission(tenant, inviter, None, "add") + name = email.split("@")[0] - account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) + account = cls.register( + email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True + ) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) TenantService.switch_tenant(account, tenant.id) else: - TenantService.check_member_permission(tenant, inviter, account, 'add') - ta = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=account.id - ).first() + TenantService.check_member_permission(tenant, inviter, account, "add") + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: TenantService.create_tenant_member(tenant, account, role) @@ -609,7 +861,7 @@ def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str language=account.interface_language, to=email, token=token, - inviter_name=inviter.name if inviter else 'Dify', + inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, ) @@ -619,23 +871,24 @@ def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str def generate_invite_token(cls, tenant: Tenant, account: Account) -> str: token = str(uuid.uuid4()) invitation_data = { - 'account_id': account.id, - 'email': account.email, - 'workspace_id': tenant.id, + "account_id": account.id, + "email": account.email, + "workspace_id": tenant.id, } - expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex( - cls._get_invitation_token_key(token), - expiryHours * 60 * 60, - json.dumps(invitation_data) - ) + expiry_hours = dify_config.INVITE_EXPIRY_HOURS + redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) return token + @classmethod + def is_valid_invite_token(cls, token: str) -> bool: + data = 
redis_client.get(cls._get_invitation_token_key(token)) + return data is not None + @classmethod def revoke_token(cls, workspace_id: str, email: str, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() - cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) + cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) redis_client.delete(cache_key) else: redis_client.delete(cls._get_invitation_token_key(token)) @@ -646,17 +899,21 @@ def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str if not invitation_data: return None - tenant = db.session.query(Tenant).filter( - Tenant.id == invitation_data['workspace_id'], - Tenant.status == 'normal' - ).first() + tenant = ( + db.session.query(Tenant) + .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") + .first() + ) if not tenant: return None - tenant_account = db.session.query(Account, TenantAccountJoin.role).join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first() + tenant_account = ( + db.session.query(Account, TenantAccountJoin.role) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) + .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) + .first() + ) if not tenant_account: return None @@ -665,29 +922,31 @@ def get_invitation_if_token_valid(cls, workspace_id: str, email: str, token: str if not account: return None - if invitation_data['account_id'] != str(account.id): + if invitation_data["account_id"] != str(account.id): return None return { - 'account': account, - 'data': invitation_data, - 'tenant': tenant, + "account": account, + "data": invitation_data, + "tenant": tenant, } @classmethod - def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]: + def _get_invitation_by_token( + cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None + ) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() - cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}' + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" account_id = redis_client.get(cache_key) if not account_id: return None return { - 'account_id': account_id.decode('utf-8'), - 'email': email, - 'workspace_id': workspace_id, + "account_id": account_id.decode("utf-8"), + "email": email, + "workspace_id": workspace_id, } else: data = redis_client.get(cls._get_invitation_token_key(token)) @@ -696,3 +955,8 @@ def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> invitation = json.loads(data) return invitation + + +def _generate_refresh_token(length: int = 64): + token = secrets.token_hex(length) + return token diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 213df262223d8a..d2cd7bea67c5b6 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,4 +1,3 @@ - import copy from core.prompt.prompt_templates.advanced_prompt_templates import ( @@ -17,59 +16,78 @@ class AdvancedPromptTemplateService: - @classmethod def get_prompt(cls, args: dict) -> dict: - app_mode = args['app_mode'] - model_mode = args['model_mode'] - model_name = 
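The invitation helpers above juggle two Redis key shapes: a workspace-scoped key that embeds a SHA-256 hash of the invitee's email (so the raw address never lands in Redis), and a bare `member_invite:token:{token}` key whose value is the invitation payload as JSON. A small sketch of both derivations as `revoke_token` and `_get_invitation_by_token` use them; the payload fields mirror `generate_invite_token`, and the values are placeholders:

```python
import json
from hashlib import sha256


def workspace_invite_key(workspace_id: str, email: str, token: str) -> str:
    # Workspace-scoped key: only a hash of the email is stored.
    # Note the literal ", " inside the key — preserved verbatim from the service.
    email_hash = sha256(email.encode()).hexdigest()
    return f"member_invite_token:{workspace_id}, {email_hash}:{token}"


def invitation_token_key(token: str) -> str:
    # Bare token key, as built by _get_invitation_token_key().
    return f"member_invite:token:{token}"


# Payload cached under the bare key for INVITE_EXPIRY_HOURS:
payload = json.dumps({"account_id": "<id>", "email": "user@example.com", "workspace_id": "<tenant id>"})
```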
args['model_name'] - has_context = args['has_context'] + app_mode = args["app_mode"] + model_mode = args["model_mode"] + model_name = args["model_name"] + has_context = args["has_context"] - if 'baichuan' in model_name.lower(): + if "baichuan" in model_name.lower(): return cls.get_baichuan_prompt(app_mode, model_mode, has_context) else: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: context_prompt = copy.deepcopy(CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - + return cls.get_chat_prompt( + copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] - + if has_context == "true": + prompt_template["completion_prompt_config"]["prompt"]["text"] = ( + context + prompt_template["completion_prompt_config"]["prompt"]["text"] + ) + return prompt_template @classmethod def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] - + if has_context == "true": + prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( + context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] + ) + return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return 
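In the prompt-template service being reformatted here, both getters share one rule: when `has_context` is the string `"true"` (not a bool), the context scaffold is prepended to the template text. A stripped-down illustration follows, with inline stand-ins for `CONTEXT` and the prompt-config constants that actually live in `core.prompt.prompt_templates.advanced_prompt_templates`:

```python
import copy

# Inline stand-ins for CONTEXT and CHAT_APP_COMPLETION_PROMPT_CONFIG.
CONTEXT = "Use the following context as your learned knowledge:\n{{#context#}}\n"
PROMPT_CONFIG = {"completion_prompt_config": {"prompt": {"text": "{{query}}"}}}


def get_completion_prompt(prompt_template: dict, has_context: str, context: str) -> dict:
    # has_context arrives as the string "true"/"false", not a boolean.
    if has_context == "true":
        prompt_template["completion_prompt_config"]["prompt"]["text"] = (
            context + prompt_template["completion_prompt_config"]["prompt"]["text"]
        )
    return prompt_template


print(get_completion_prompt(copy.deepcopy(PROMPT_CONFIG), "true", CONTEXT))
```

The `copy.deepcopy` calls in the service exist for the same reason as here: without them, the module-level template dicts would be mutated in place on first use.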
cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) \ No newline at end of file + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index ba5fd93326a96f..c8819535f11a39 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -10,59 +10,65 @@ class AgentService: @classmethod - def get_agent_logs(cls, app_model: App, - conversation_id: str, - message_id: str) -> dict: + def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict: """ Service to get agent logs """ - conversation: Conversation = db.session.query(Conversation).filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - ).first() + conversation: Conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + ) + .first() + ) if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = db.session.query(Message).filter( - Message.id == message_id, - Message.conversation_id == conversation_id, - ).first() + message: Message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.conversation_id == conversation_id, + ) + .first() + ) if not message: raise ValueError(f"Message not found: {message_id}") - + agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if conversation.from_end_user_id: # only select name field - executor = db.session.query(EndUser, EndUser.name).filter( - EndUser.id == conversation.from_end_user_id - ).first() + executor = ( + db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first() + ) else: - executor = db.session.query(Account, Account.name).filter( - Account.id == conversation.from_account_id - ).first() - + executor = ( + db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first() + ) + if executor: executor = executor.name else: - executor = 'Unknown' + executor = "Unknown" timezone = pytz.timezone(current_user.timezone) result = { - 'meta': { - 'status': 'success', - 'executor': executor, - 'start_time': message.created_at.astimezone(timezone).isoformat(), - 'elapsed_time': message.provider_response_latency, - 'total_tokens': message.answer_tokens + message.message_tokens, - 'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'), - 'iterations': len(agent_thoughts), + "meta": { + "status": "success", + "executor": executor, + "start_time": message.created_at.astimezone(timezone).isoformat(), + "elapsed_time": message.provider_response_latency, + "total_tokens": message.answer_tokens + message.message_tokens, + "agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"), + "iterations": len(agent_thoughts), }, - 'iterations': [], - 'files': message.files, + "iterations": [], + "files": message.message_files, } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) @@ -86,12 +92,12 @@ def 
find_agent_tool(tool_name: str): tool_input = tool_inputs.get(tool_name, {}) tool_output = tool_outputs.get(tool_name, {}) tool_meta_data = tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - if tool_config.get('tool_provider_type', '') != 'dataset-retrieval': + tool_config = tool_meta_data.get("tool_config", {}) + if tool_config.get("tool_provider_type", "") != "dataset-retrieval": tool_icon = ToolManager.get_tool_icon( tenant_id=app_model.tenant_id, - provider_type=tool_config.get('tool_provider_type', ''), - provider_id=tool_config.get('tool_provider', ''), + provider_type=tool_config.get("tool_provider_type", ""), + provider_id=tool_config.get("tool_provider", ""), ) if not tool_icon: tool_entity = find_agent_tool(tool_name) @@ -102,30 +108,34 @@ def find_agent_tool(tool_name: str): provider_id=tool_entity.provider_id, ) else: - tool_icon = '' - - tool_calls.append({ - 'status': 'success' if not tool_meta_data.get('error') else 'error', - 'error': tool_meta_data.get('error'), - 'time_cost': tool_meta_data.get('time_cost', 0), - 'tool_name': tool_name, - 'tool_label': tool_label, - 'tool_input': tool_input, - 'tool_output': tool_output, - 'tool_parameters': tool_meta_data.get('tool_parameters', {}), - 'tool_icon': tool_icon, - }) - - result['iterations'].append({ - 'tokens': agent_thought.tokens, - 'tool_calls': tool_calls, - 'tool_raw': { - 'inputs': agent_thought.tool_input, - 'outputs': agent_thought.observation, - }, - 'thought': agent_thought.thought, - 'created_at': agent_thought.created_at.isoformat(), - 'files': agent_thought.files, - }) - - return result \ No newline at end of file + tool_icon = "" + + tool_calls.append( + { + "status": "success" if not tool_meta_data.get("error") else "error", + "error": tool_meta_data.get("error"), + "time_cost": tool_meta_data.get("time_cost", 0), + "tool_name": tool_name, + "tool_label": tool_label, + "tool_input": tool_input, + "tool_output": tool_output, + "tool_parameters": tool_meta_data.get("tool_parameters", {}), + "tool_icon": tool_icon, + } + ) + + result["iterations"].append( + { + "tokens": agent_thought.tokens, + "tool_calls": tool_calls, + "tool_raw": { + "inputs": agent_thought.tool_input, + "outputs": agent_thought.observation, + }, + "thought": agent_thought.thought, + "created_at": agent_thought.created_at.isoformat(), + "files": agent_thought.files, + } + ) + + return result diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index addcde44ed9a4a..915d37ec032549 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -23,21 +23,18 @@ class AppAnnotationService: @classmethod def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - if args.get('message_id'): - message_id = str(args['message_id']) + if args.get("message_id"): + message_id = str(args["message_id"]) # get message info - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() if not message: raise NotFound("Message Not 
Exists.") @@ -45,159 +42,166 @@ def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> Messa annotation = message.annotation # save the message annotation if annotation: - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] else: annotation = MessageAnnotation( app_id=app.id, conversation_id=message.conversation_id, message_id=message.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + content=args["answer"], + question=args["question"], + account_id=current_user.id, ) else: annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(enable_app_annotation_job_key, 'waiting') - enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id, - args['score_threshold'], - args['embedding_provider_name'], args['embedding_model_name']) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + redis_client.setnx(enable_app_annotation_job_key, "waiting") + enable_annotation_reply_task.delay( + str(job_id), + app_id, + current_user.id, + current_user.current_tenant_id, + args["score_threshold"], + args["embedding_provider_name"], + args["embedding_model_name"], + ) + return {"job_id": job_id, "job_status": "waiting"} @classmethod def disable_app_annotation(cls, app_id: str) -> dict: - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_job_key = 
"disable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(disable_app_annotation_job_key, 'waiting') + redis_client.setnx(disable_app_annotation_job_key, "waiting") disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") if keyword: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .filter( - or_( - MessageAnnotation.question.ilike('%{}%'.format(keyword)), - MessageAnnotation.content.ilike('%{}%'.format(keyword)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .filter( + or_( + MessageAnnotation.question.ilike("%{}%".format(keyword)), + MessageAnnotation.content.ilike("%{}%".format(keyword)), + ) ) + .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) else: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotations.items, annotations.total @classmethod def export_annotation_list_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()).all()) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .all() + ) return annotations @classmethod def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, 
content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -207,30 +211,34 @@ def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: if not annotation: raise NotFound("Annotation not found") - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] db.session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - update_annotation_to_index_task.delay(annotation.id, annotation.question, - current_user.current_tenant_id, - app_id, app_annotation_setting.collection_binding_id) + update_annotation_to_index_task.delay( + annotation.id, + annotation.question, + current_user.current_tenant_id, + app_id, + app_annotation_setting.collection_binding_id, + ) return annotation @classmethod def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -242,33 +250,34 @@ def delete_app_annotation(cls, app_id: str, annotation_id: str): db.session.delete(annotation) - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.annotation_id == annotation_id) - .all() - ) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter(AppAnnotationHitHistory.annotation_id == annotation_id) + .all() + ) if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: db.session.delete(annotation_hit_history) db.session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + 
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - delete_annotation_index_task.delay(annotation.id, app_id, - current_user.current_tenant_id, - app_annotation_setting.collection_binding_id) + delete_annotation_index_task.delay( + annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + ) @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -278,10 +287,7 @@ def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - content = { - 'question': row[0], - 'answer': row[1] - } + content = {"question": row[0], "answer": row[1]} result.append(content) if len(result) == 0: raise ValueError("The CSV file is empty.") @@ -293,28 +299,24 @@ def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: raise ValueError("The number of annotations exceeds the limit of your subscription.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_import_annotations_task.delay(str(job_id), result, app_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_import_annotations_task.delay( + str(job_id), result, app_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return { - 'error_msg': str(e) - } - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"error_msg": str(e)} + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -324,12 +326,15 @@ def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, lim if not annotation: raise NotFound("Annotation not found") - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.app_id == app_id, - AppAnnotationHitHistory.annotation_id == annotation_id, - ) - .order_by(AppAnnotationHitHistory.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.app_id == app_id, + AppAnnotationHitHistory.annotation_id == annotation_id, + ) + .order_by(AppAnnotationHitHistory.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotation_hit_histories.items, annotation_hit_histories.total @classmethod @@ -341,15 +346,21 @@ def get_annotation_by_id(cls, 
annotation_id: str) -> MessageAnnotation | None: return annotation @classmethod - def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str, - annotation_content: str, query: str, user_id: str, - message_id: str, from_source: str, score: float): + def add_annotation_history( + cls, + annotation_id: str, + app_id: str, + annotation_question: str, + annotation_content: str, + query: str, + user_id: str, + message_id: str, + from_source: str, + score: float, + ): # add hit count to annotation - db.session.query(MessageAnnotation).filter( - MessageAnnotation.id == annotation_id - ).update( - {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, - synchronize_session=False + db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( + {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False ) annotation_hit_history = AppAnnotationHitHistory( @@ -361,7 +372,7 @@ def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_ques score=score, message_id=message_id, annotation_question=annotation_question, - annotation_content=annotation_content + annotation_content=annotation_content, ) db.session.add(annotation_hit_history) db.session.commit() @@ -369,17 +380,18 @@ def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_ques @classmethod def get_app_annotation_setting_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -388,32 +400,34 @@ def get_app_annotation_setting_by_app_id(cls, app_id: str): "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } - return { - "enabled": False - } + return {"enabled": False} @classmethod def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id, - AppAnnotationSetting.id == annotation_setting_id, - ).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting) + .filter( + AppAnnotationSetting.app_id == app_id, + AppAnnotationSetting.id == annotation_setting_id, + ) + .first() + ) if not annotation_setting: raise NotFound("App annotation not found") - annotation_setting.score_threshold = args['score_threshold'] + 
annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(annotation_setting) @@ -427,6 +441,6 @@ def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 8441bbedb344bc..601d67d2fba4e3 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -5,13 +5,14 @@ class APIBasedExtensionService: - @staticmethod def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: - extension_list = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .order_by(APIBasedExtension.created_at.desc()) \ - .all() + extension_list = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .order_by(APIBasedExtension.created_at.desc()) + .all() + ) for extension in extension_list: extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) @@ -35,10 +36,12 @@ def delete(extension_data: APIBasedExtension) -> None: @staticmethod def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .filter_by(id=api_based_extension_id) \ + extension = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .filter_by(id=api_based_extension_id) .first() + ) if not extension: raise ValueError("API based extension is not found") @@ -55,20 +58,24 @@ def _validation(cls, extension_data: APIBasedExtension) -> None: if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ - .filter(APIBasedExtension.id != extension_data.id) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) + .filter(APIBasedExtension.id != extension_data.id) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") @@ -92,7 +99,7 @@ def _ping_connection(extension_data: APIBasedExtension) -> None: try: client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) resp = client.request(point=APIBasedExtensionPoint.PING, params={}) - if resp.get('result') != 'pong': + if resp.get("result") != "pong": raise ValueError(resp) except Exception as e: raise ValueError("connection error: {}".format(e)) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py 
deleted file mode 100644 index bfb160b3e476d9..00000000000000 --- a/api/services/app_dsl_service.py +++ /dev/null @@ -1,417 +0,0 @@ -import logging - -import httpx -import yaml # type: ignore - -from core.app.segments import factory -from events.app_event import app_model_config_was_updated, app_was_created -from extensions.ext_database import db -from models.account import Account -from models.model import App, AppMode, AppModelConfig -from models.workflow import Workflow -from services.workflow_service import WorkflowService - -logger = logging.getLogger(__name__) - -current_dsl_version = "0.1.1" -dsl_to_dify_version_mapping: dict[str, str] = { - "0.1.1": "0.6.0", # dsl version -> from dify version -} - - -class AppDslService: - @classmethod - def import_and_create_new_app_from_url(cls, tenant_id: str, url: str, args: dict, account: Account) -> App: - """ - Import app dsl from url and create new app - :param tenant_id: tenant id - :param url: import url - :param args: request args - :param account: Account instance - """ - try: - max_size = 10 * 1024 * 1024 # 10MB - timeout = httpx.Timeout(10.0) - with httpx.stream("GET", url.strip(), follow_redirects=True, timeout=timeout) as response: - response.raise_for_status() - total_size = 0 - content = b"" - for chunk in response.iter_bytes(): - total_size += len(chunk) - if total_size > max_size: - raise ValueError("File size exceeds the limit of 10MB") - content += chunk - except httpx.HTTPStatusError as http_err: - raise ValueError(f"HTTP error occurred: {http_err}") - except httpx.RequestError as req_err: - raise ValueError(f"Request error occurred: {req_err}") - except Exception as e: - raise ValueError(f"Failed to fetch DSL from URL: {e}") - - if not content: - raise ValueError("Empty content from url") - - try: - data = content.decode("utf-8") - except UnicodeDecodeError as e: - raise ValueError(f"Error decoding content: {e}") - - return cls.import_and_create_new_app(tenant_id, data, args, account) - - @classmethod - def import_and_create_new_app(cls, tenant_id: str, data: str, args: dict, account: Account) -> App: - """ - Import app dsl and create new app - :param tenant_id: tenant id - :param data: import data - :param args: request args - :param account: Account instance - """ - try: - import_data = yaml.safe_load(data) - except yaml.YAMLError: - raise ValueError("Invalid YAML format in data argument.") - - # check or repair dsl version - import_data = cls._check_or_fix_dsl(import_data) - - app_data = import_data.get('app') - if not app_data: - raise ValueError("Missing app in data argument") - - # get app basic info - name = args.get("name") if args.get("name") else app_data.get('name') - description = args.get("description") if args.get("description") else app_data.get('description', '') - icon = args.get("icon") if args.get("icon") else app_data.get('icon') - icon_background = args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background') - - # import dsl and create app - app_mode = AppMode.value_of(app_data.get('mode')) - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - app = cls._import_and_create_new_workflow_based_app( - tenant_id=tenant_id, - app_mode=app_mode, - workflow_data=import_data.get('workflow'), - account=account, - name=name, - description=description, - icon=icon, - icon_background=icon_background - ) - elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: - app = cls._import_and_create_new_model_config_based_app( - tenant_id=tenant_id, - 
app_mode=app_mode, - model_config_data=import_data.get('model_config'), - account=account, - name=name, - description=description, - icon=icon, - icon_background=icon_background - ) - else: - raise ValueError("Invalid app mode") - - return app - - @classmethod - def import_and_overwrite_workflow(cls, app_model: App, data: str, account: Account) -> Workflow: - """ - Import app dsl and overwrite workflow - :param app_model: App instance - :param data: import data - :param account: Account instance - """ - try: - import_data = yaml.safe_load(data) - except yaml.YAMLError: - raise ValueError("Invalid YAML format in data argument.") - - # check or repair dsl version - import_data = cls._check_or_fix_dsl(import_data) - - app_data = import_data.get('app') - if not app_data: - raise ValueError("Missing app in data argument") - - # import dsl and overwrite app - app_mode = AppMode.value_of(app_data.get('mode')) - if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - raise ValueError("Only support import workflow in advanced-chat or workflow app.") - - if app_data.get('mode') != app_model.mode: - raise ValueError( - f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") - - return cls._import_and_overwrite_workflow_based_app( - app_model=app_model, - workflow_data=import_data.get('workflow'), - account=account, - ) - - @classmethod - def export_dsl(cls, app_model: App, include_secret:bool = False) -> str: - """ - Export app - :param app_model: App instance - :return: - """ - app_mode = AppMode.value_of(app_model.mode) - - export_data = { - "version": current_dsl_version, - "kind": "app", - "app": { - "name": app_model.name, - "mode": app_model.mode, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "description": app_model.description - } - } - - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret) - else: - cls._append_model_config_export_data(export_data, app_model) - - return yaml.dump(export_data, allow_unicode=True) - - @classmethod - def _check_or_fix_dsl(cls, import_data: dict) -> dict: - """ - Check or fix dsl - - :param import_data: import data - """ - if not import_data.get('version'): - import_data['version'] = "0.1.0" - - if not import_data.get('kind') or import_data.get('kind') != "app": - import_data['kind'] = "app" - - if import_data.get('version') != current_dsl_version: - # Currently only one DSL version, so no difference checks or compatibility fixes will be performed. 
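Mid-deletion note: the old `_check_or_fix_dsl` removed here never rejected a mismatched version — it only defaulted missing fields and emitted the warning that follows. The replacement module later in this diff pins `current_dsl_version = "0.1.3"` and imports `packaging.version`, which makes a real ordering check possible. A plausible sketch of that check (an assumption — the new function body is not shown in this section), using the `DSLVersionNotSupportedError` defined in the new `exc.py`:

```python
from packaging import version

current_dsl_version = "0.1.3"


class DSLVersionNotSupportedError(ValueError):
    """Mirrors the exception added in app_dsl_service/exc.py below."""


def check_or_fix_dsl(import_data: dict) -> dict:
    # Default missing fields, exactly as the deleted implementation did.
    if not import_data.get("version"):
        import_data["version"] = "0.1.0"
    if import_data.get("kind") != "app":
        import_data["kind"] = "app"

    # With packaging.version an ordering comparison becomes possible:
    # a DSL newer than this build understands cannot be silently repaired.
    if version.parse(import_data["version"]) > version.parse(current_dsl_version):
        raise DSLVersionNotSupportedError(
            f"DSL version {import_data['version']} is newer than supported {current_dsl_version}"
        )
    return import_data
```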
- logger.warning(f"DSL version {import_data.get('version')} is not compatible " - f"with current version {current_dsl_version}, related to " - f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.") - - return import_data - - @classmethod - def _import_and_create_new_workflow_based_app(cls, - tenant_id: str, - app_mode: AppMode, - workflow_data: dict, - account: Account, - name: str, - description: str, - icon: str, - icon_background: str) -> App: - """ - Import app dsl and create new workflow based app - - :param tenant_id: tenant id - :param app_mode: app mode - :param workflow_data: workflow data - :param account: Account instance - :param name: app name - :param description: app description - :param icon: app icon - :param icon_background: app icon background - """ - if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") - - app = cls._create_app( - tenant_id=tenant_id, - app_mode=app_mode, - account=account, - name=name, - description=description, - icon=icon, - icon_background=icon_background - ) - - # init draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = workflow_data.get('conversation_variables') or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] - workflow_service = WorkflowService() - draft_workflow = workflow_service.sync_draft_workflow( - app_model=app, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('../core/app/features', {}), - unique_hash=None, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - ) - workflow_service.publish_workflow( - app_model=app, - account=account, - draft_workflow=draft_workflow - ) - - return app - - @classmethod - def _import_and_overwrite_workflow_based_app(cls, - app_model: App, - workflow_data: dict, - account: Account) -> Workflow: - """ - Import app dsl and overwrite workflow based app - - :param app_model: App instance - :param workflow_data: workflow data - :param account: Account instance - """ - if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") - - # fetch draft workflow by app_model - workflow_service = WorkflowService() - current_draft_workflow = workflow_service.get_draft_workflow(app_model=app_model) - if current_draft_workflow: - unique_hash = current_draft_workflow.unique_hash - else: - unique_hash = None - - # sync draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] - environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = workflow_data.get('conversation_variables') or [] - conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] - draft_workflow = workflow_service.sync_draft_workflow( - app_model=app_model, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('features', {}), - unique_hash=unique_hash, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - ) - - return draft_workflow - - @classmethod - def _import_and_create_new_model_config_based_app(cls, - tenant_id: str, - app_mode: AppMode, - 
model_config_data: dict, - account: Account, - name: str, - description: str, - icon: str, - icon_background: str) -> App: - """ - Import app dsl and create new model config based app - - :param tenant_id: tenant id - :param app_mode: app mode - :param model_config_data: model config data - :param account: Account instance - :param name: app name - :param description: app description - :param icon: app icon - :param icon_background: app icon background - """ - if not model_config_data: - raise ValueError("Missing model_config in data argument " - "when app mode is chat, agent-chat or completion") - - app = cls._create_app( - tenant_id=tenant_id, - app_mode=app_mode, - account=account, - name=name, - description=description, - icon=icon, - icon_background=icon_background - ) - - app_model_config = AppModelConfig() - app_model_config = app_model_config.from_model_config_dict(model_config_data) - app_model_config.app_id = app.id - - db.session.add(app_model_config) - db.session.commit() - - app.app_model_config_id = app_model_config.id - - app_model_config_was_updated.send( - app, - app_model_config=app_model_config - ) - - return app - - @classmethod - def _create_app(cls, - tenant_id: str, - app_mode: AppMode, - account: Account, - name: str, - description: str, - icon: str, - icon_background: str) -> App: - """ - Create new app - - :param tenant_id: tenant id - :param app_mode: app mode - :param account: Account instance - :param name: app name - :param description: app description - :param icon: app icon - :param icon_background: app icon background - """ - app = App( - tenant_id=tenant_id, - mode=app_mode.value, - name=name, - description=description, - icon=icon, - icon_background=icon_background, - enable_site=True, - enable_api=True - ) - - db.session.add(app) - db.session.commit() - - app_was_created.send(app, account=account) - - return app - - @classmethod - def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: - """ - Append workflow export data - :param export_data: export data - :param app_model: App instance - """ - workflow_service = WorkflowService() - workflow = workflow_service.get_draft_workflow(app_model) - if not workflow: - raise ValueError("Missing draft workflow configuration, please check.") - - export_data['workflow'] = workflow.to_dict(include_secret=include_secret) - - @classmethod - def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: - """ - Append model config export data - :param export_data: export data - :param app_model: App instance - """ - app_model_config = app_model.app_model_config - if not app_model_config: - raise ValueError("Missing app configuration, please check.") - - export_data['model_config'] = app_model_config.to_dict() diff --git a/api/services/app_dsl_service/__init__.py b/api/services/app_dsl_service/__init__.py new file mode 100644 index 00000000000000..9fc988ffb36266 --- /dev/null +++ b/api/services/app_dsl_service/__init__.py @@ -0,0 +1,3 @@ +from .service import AppDslService + +__all__ = ["AppDslService"] diff --git a/api/services/app_dsl_service/exc.py b/api/services/app_dsl_service/exc.py new file mode 100644 index 00000000000000..6da4b1938f3cf2 --- /dev/null +++ b/api/services/app_dsl_service/exc.py @@ -0,0 +1,34 @@ +class DSLVersionNotSupportedError(ValueError): + """Raised when the imported DSL version is not supported by the current Dify version.""" + + +class InvalidYAMLFormatError(ValueError): + """Raised when the provided YAML format is 
invalid.""" + + +class MissingAppDataError(ValueError): + """Raised when the app data is missing in the provided DSL.""" + + +class InvalidAppModeError(ValueError): + """Raised when the app mode is invalid.""" + + +class MissingWorkflowDataError(ValueError): + """Raised when the workflow data is missing in the provided DSL.""" + + +class MissingModelConfigError(ValueError): + """Raised when the model config data is missing in the provided DSL.""" + + +class FileSizeLimitExceededError(ValueError): + """Raised when the file size exceeds the allowed limit.""" + + +class EmptyContentError(ValueError): + """Raised when the content fetched from the URL is empty.""" + + +class ContentDecodingError(ValueError): + """Raised when there is an error decoding the content.""" diff --git a/api/services/app_dsl_service/service.py b/api/services/app_dsl_service/service.py new file mode 100644 index 00000000000000..e6b0d9a2725b0f --- /dev/null +++ b/api/services/app_dsl_service/service.py @@ -0,0 +1,484 @@ +import logging +from collections.abc import Mapping +from typing import Any + +import yaml +from packaging import version + +from core.helper import ssrf_proxy +from events.app_event import app_model_config_was_updated, app_was_created +from extensions.ext_database import db +from factories import variable_factory +from models.account import Account +from models.model import App, AppMode, AppModelConfig +from models.workflow import Workflow +from services.workflow_service import WorkflowService + +from .exc import ( + ContentDecodingError, + EmptyContentError, + FileSizeLimitExceededError, + InvalidAppModeError, + InvalidYAMLFormatError, + MissingAppDataError, + MissingModelConfigError, + MissingWorkflowDataError, +) + +logger = logging.getLogger(__name__) + +current_dsl_version = "0.1.3" + + +class AppDslService: + @classmethod + def import_and_create_new_app_from_url(cls, tenant_id: str, url: str, args: dict, account: Account) -> App: + """ + Import app dsl from url and create new app + :param tenant_id: tenant id + :param url: import url + :param args: request args + :param account: Account instance + """ + max_size = 10 * 1024 * 1024 # 10MB + response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10)) + response.raise_for_status() + content = response.content + + if len(content) > max_size: + raise FileSizeLimitExceededError("File size exceeds the limit of 10MB") + + if not content: + raise EmptyContentError("Empty content from url") + + try: + data = content.decode("utf-8") + except UnicodeDecodeError as e: + raise ContentDecodingError(f"Error decoding content: {e}") + + return cls.import_and_create_new_app(tenant_id, data, args, account) + + @classmethod + def import_and_create_new_app(cls, tenant_id: str, data: str, args: dict, account: Account) -> App: + """ + Import app dsl and create new app + :param tenant_id: tenant id + :param data: import data + :param args: request args + :param account: Account instance + """ + try: + import_data = yaml.safe_load(data) + except yaml.YAMLError: + raise InvalidYAMLFormatError("Invalid YAML format in data argument.") + + # check or repair dsl version + import_data = _check_or_fix_dsl(import_data) + + app_data = import_data.get("app") + if not app_data: + raise MissingAppDataError("Missing app in data argument") + + # get app basic info + name = args.get("name") or app_data.get("name") + description = args.get("description") or app_data.get("description", "") + icon_type = args.get("icon_type") or app_data.get("icon_type") + icon = 
args.get("icon") or app_data.get("icon") + icon_background = args.get("icon_background") or app_data.get("icon_background") + use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) + + # import dsl and create app + app_mode = AppMode.value_of(app_data.get("mode")) + + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_data = import_data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + + app = cls._import_and_create_new_workflow_based_app( + tenant_id=tenant_id, + app_mode=app_mode, + workflow_data=workflow_data, + account=account, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, + ) + elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: + model_config = import_data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise MissingModelConfigError( + "Missing model_config in data argument when app mode is chat, agent-chat or completion" + ) + + app = cls._import_and_create_new_model_config_based_app( + tenant_id=tenant_id, + app_mode=app_mode, + model_config_data=model_config, + account=account, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, + ) + else: + raise InvalidAppModeError("Invalid app mode") + + return app + + @classmethod + def import_and_overwrite_workflow(cls, app_model: App, data: str, account: Account) -> Workflow: + """ + Import app dsl and overwrite workflow + :param app_model: App instance + :param data: import data + :param account: Account instance + """ + try: + import_data = yaml.safe_load(data) + except yaml.YAMLError: + raise InvalidYAMLFormatError("Invalid YAML format in data argument.") + + # check or repair dsl version + import_data = _check_or_fix_dsl(import_data) + + app_data = import_data.get("app") + if not app_data: + raise MissingAppDataError("Missing app in data argument") + + # import dsl and overwrite app + app_mode = AppMode.value_of(app_data.get("mode")) + if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + raise InvalidAppModeError("Only support import workflow in advanced-chat or workflow app.") + + if app_data.get("mode") != app_model.mode: + raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") + + workflow_data = import_data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + + return cls._import_and_overwrite_workflow_based_app( + app_model=app_model, + workflow_data=workflow_data, + account=account, + ) + + @classmethod + def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: + """ + Export app + :param app_model: App instance + :return: + """ + app_mode = AppMode.value_of(app_model.mode) + + export_data = { + "version": current_dsl_version, + "kind": "app", + "app": { + "name": app_model.name, + "mode": app_model.mode, + "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, + "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, + "description": app_model.description, + "use_icon_as_answer_icon": 
app_model.use_icon_as_answer_icon, + }, + } + + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + cls._append_workflow_export_data( + export_data=export_data, app_model=app_model, include_secret=include_secret + ) + else: + cls._append_model_config_export_data(export_data, app_model) + + return yaml.dump(export_data, allow_unicode=True) + + @classmethod + def _import_and_create_new_workflow_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + workflow_data: Mapping[str, Any], + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + use_icon_as_answer_icon: bool, + ) -> App: + """ + Import app dsl and create new workflow based app + + :param tenant_id: tenant id + :param app_mode: app mode + :param workflow_data: workflow data + :param account: Account instance + :param name: app name + :param description: app description + :param icon_type: app icon type, "emoji" or "image" + :param icon: app icon + :param icon_background: app icon background + :param use_icon_as_answer_icon: use app icon as answer icon + """ + if not workflow_data: + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + + app = cls._create_app( + tenant_id=tenant_id, + app_mode=app_mode, + account=account, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, + ) + + # init draft workflow + environment_variables_list = workflow_data.get("environment_variables") or [] + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables") or [] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] + workflow_service = WorkflowService() + draft_workflow = workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=None, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow) + + return app + + @classmethod + def _import_and_overwrite_workflow_based_app( + cls, app_model: App, workflow_data: Mapping[str, Any], account: Account + ) -> Workflow: + """ + Import app dsl and overwrite workflow based app + + :param app_model: App instance + :param workflow_data: workflow data + :param account: Account instance + """ + if not workflow_data: + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + + # fetch draft workflow by app_model + workflow_service = WorkflowService() + current_draft_workflow = workflow_service.get_draft_workflow(app_model=app_model) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + + # sync draft workflow + environment_variables_list = workflow_data.get("environment_variables") or [] + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables") or [] + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] 
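For reference, the new public surface of `AppDslService` can be exercised as a round trip; a minimal sketch, assuming an existing workflow app and its owning account are already fetched from the database (`app` and `account` are placeholder names, not part of this patch):

```python
# Illustrative sketch only. Assumes `app` (an App row whose mode is
# advanced-chat or workflow) and `account` (an Account) are loaded elsewhere.
from services.app_dsl_service import AppDslService

dsl_text = AppDslService.export_dsl(app, include_secret=False)  # YAML string
draft = AppDslService.import_and_overwrite_workflow(app_model=app, data=dsl_text, account=account)
# `draft` is the freshly synced draft Workflow; its unique_hash feeds the
# next sync_draft_workflow call.
```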
+ draft_workflow = workflow_service.sync_draft_workflow( + app_model=app_model, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + + return draft_workflow + + @classmethod + def _import_and_create_new_model_config_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + model_config_data: Mapping[str, Any], + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + use_icon_as_answer_icon: bool, + ) -> App: + """ + Import app dsl and create new model config based app + + :param tenant_id: tenant id + :param app_mode: app mode + :param model_config_data: model config data + :param account: Account instance + :param name: app name + :param description: app description + :param icon: app icon + :param icon_background: app icon background + """ + if not model_config_data: + raise MissingModelConfigError( + "Missing model_config in data argument when app mode is chat, agent-chat or completion" + ) + + app = cls._create_app( + tenant_id=tenant_id, + app_mode=app_mode, + account=account, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + use_icon_as_answer_icon=use_icon_as_answer_icon, + ) + + app_model_config = AppModelConfig() + app_model_config = app_model_config.from_model_config_dict(model_config_data) + app_model_config.app_id = app.id + app_model_config.created_by = account.id + app_model_config.updated_by = account.id + + db.session.add(app_model_config) + db.session.commit() + + app.app_model_config_id = app_model_config.id + + app_model_config_was_updated.send(app, app_model_config=app_model_config) + + return app + + @classmethod + def _create_app( + cls, + tenant_id: str, + app_mode: AppMode, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + use_icon_as_answer_icon: bool, + ) -> App: + """ + Create new app + + :param tenant_id: tenant id + :param app_mode: app mode + :param account: Account instance + :param name: app name + :param description: app description + :param icon_type: app icon type, "emoji" or "image" + :param icon: app icon + :param icon_background: app icon background + :param use_icon_as_answer_icon: use app icon as answer icon + """ + app = App( + tenant_id=tenant_id, + mode=app_mode.value, + name=name, + description=description, + icon_type=icon_type, + icon=icon, + icon_background=icon_background, + enable_site=True, + enable_api=True, + use_icon_as_answer_icon=use_icon_as_answer_icon, + created_by=account.id, + updated_by=account.id, + ) + + db.session.add(app) + db.session.commit() + + app_was_created.send(app, account=account) + + return app + + @classmethod + def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: + """ + Append workflow export data + :param export_data: export data + :param app_model: App instance + """ + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model) + if not workflow: + raise ValueError("Missing draft workflow configuration, please check.") + + export_data["workflow"] = workflow.to_dict(include_secret=include_secret) + + @classmethod + def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: + """ + Append model config export data + :param export_data: export data 
+ :param app_model: App instance + """ + app_model_config = app_model.app_model_config + if not app_model_config: + raise ValueError("Missing app configuration, please check.") + + export_data["model_config"] = app_model_config.to_dict() + + +def _check_or_fix_dsl(import_data: dict[str, Any]) -> Mapping[str, Any]: + """ + Check or fix dsl + + :param import_data: import data + :raises DSLVersionNotSupportedError: if the imported DSL version is newer than the current version + """ + if not import_data.get("version"): + import_data["version"] = "0.1.0" + + if not import_data.get("kind") or import_data.get("kind") != "app": + import_data["kind"] = "app" + + imported_version = import_data.get("version") + if imported_version != current_dsl_version: + if imported_version and version.parse(imported_version) > version.parse(current_dsl_version): + errmsg = ( + f"The imported DSL version {imported_version} is newer than " + f"the current supported version {current_dsl_version}. " + f"Please upgrade your Dify instance to import this configuration." + ) + logger.warning(errmsg) + # raise DSLVersionNotSupportedError(errmsg) + else: + logger.warning( + f"DSL version {imported_version} is older than " + f"the current version {current_dsl_version}. " + f"This may cause compatibility issues." + ) + + return import_data diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index cff4ba8af9dd46..83a9a16904186a 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -1,6 +1,8 @@ -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Union +from openai._exceptions import RateLimitError + from configs import dify_config from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator @@ -10,18 +12,21 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting import RateLimit from models.model import Account, App, AppMode, EndUser +from models.workflow import Workflow +from services.errors.llm import InvokeRateLimitError from services.workflow_service import WorkflowService class AppGenerateService: - @classmethod - def generate(cls, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - streaming: bool = True, - ): + def generate( + cls, + app_model: App, + user: Union[Account, EndUser], + args: Mapping[str, Any], + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ App Content Generate :param app_model: app model @@ -37,51 +42,56 @@ def generate(cls, app_model: App, try: request_id = rate_limit.enter(request_id) if app_model.mode == AppMode.COMPLETION.value: - return rate_limit.generate(CompletionAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + CompletionAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: - return rate_limit.generate(AgentChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AgentChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + 
request_id, + ) elif app_model.mode == AppMode.CHAT.value: - return rate_limit.generate(ChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + ChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(AdvancedChatAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AdvancedChatAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(WorkflowAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") + except RateLimitError as e: + raise InvokeRateLimitError(str(e)) finally: if not streaming: rate_limit.exit(request_id) @@ -94,38 +104,29 @@ def _get_max_active_requests(app_model: App) -> int: return max_active_requests @classmethod - def generate_single_iteration(cls, app_model: App, - user: Union[Account, EndUser], - node_id: str, - args: Any, - streaming: bool = True): + def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True): if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], - message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ - -> Union[dict, Generator]: + def generate_more_like_this( + cls, + app_model: App, + user: Union[Account, EndUser], + message_id: str, + invoke_from: InvokeFrom, + streaming: bool = True, + ) -> Union[dict, Generator]: """ Generate more like this :param app_model: app model @@ -136,15 +137,11 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], :return: """ return CompletionAppGenerator().generate_more_like_this( - app_model=app_model, - 
message_id=message_id, - user=user, - invoke_from=invoke_from, - stream=streaming + app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming ) @classmethod - def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any: + def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow: """ Get workflow :param app_model: app model @@ -157,12 +154,12 @@ def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any: workflow = workflow_service.get_draft_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") else: # fetch published workflow by app_model workflow = workflow_service.get_published_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not published') + raise ValueError("Workflow not published") return workflow diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index c84f6fbf454daf..a1ad2710534a84 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -5,7 +5,6 @@ class AppModelConfigService: - @classmethod def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: if app_mode == AppMode.CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index e433bb59bbe994..ac45d623e84bc9 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -33,27 +33,22 @@ def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: :param args: request args :return: """ - filters = [ - App.tenant_id == tenant_id, - App.is_universal == False - ] + filters = [App.tenant_id == tenant_id, App.is_universal == False] - if args['mode'] == 'workflow': + if args["mode"] == "workflow": filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) - elif args['mode'] == 'chat': + elif args["mode"] == "chat": filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent-chat': + elif args["mode"] == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args['mode'] == 'channel': + elif args["mode"] == "channel": filters.append(App.mode == AppMode.CHANNEL.value) - if args.get('name'): - name = args['name'][:30] - filters.append(App.name.ilike(f'%{name}%')) - if args.get('tag_ids'): - target_ids = TagService.get_target_ids_by_tag_ids('app', - tenant_id, - args['tag_ids']) + if args.get("name"): + name = args["name"][:30] + filters.append(App.name.ilike(f"%{name}%")) + if args.get("tag_ids"): + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) if target_ids: filters.append(App.id.in_(target_ids)) else: @@ -61,9 +56,9 @@ def get_paginate_apps(self, tenant_id: str, args: dict) -> Pagination | None: app_models = db.paginate( db.select(App).where(*filters).order_by(App.created_at.desc()), - page=args['page'], - per_page=args['limit'], - error_out=False + page=args["page"], + per_page=args["limit"], + error_out=False, ) return app_models @@ -75,21 +70,20 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: :param args: request args :param account: Account instance """ - app_mode = AppMode.value_of(args['mode']) + app_mode = AppMode.value_of(args["mode"]) app_template = default_app_templates[app_mode] # get model config - default_model_config = app_template.get('model_config') + default_model_config = 
app_template.get("model_config") default_model_config = default_model_config.copy() if default_model_config else None - if default_model_config and 'model' in default_model_config: + if default_model_config and "model" in default_model_config: # get model provider model_manager = ModelManager() # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, - model_type=ModelType.LLM + tenant_id=account.current_tenant_id, model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -98,32 +92,43 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: model_instance = None if model_instance: - if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']: - default_model_dict = default_model_config['model'] + if ( + model_instance.model == default_model_config["model"]["name"] + and model_instance.provider == default_model_config["model"]["provider"] + ): + default_model_dict = default_model_config["model"] else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} + "provider": model_instance.provider, + "name": model_instance.model, + "mode": model_schema.model_properties.get(ModelPropertyKey.MODE), + "completion_params": {}, } else: - default_model_dict = default_model_config['model'] - - default_model_config['model'] = json.dumps(default_model_dict) - - app = App(**app_template['app']) - app.name = args['name'] - app.description = args.get('description', '') - app.mode = args['mode'] - app.icon = args['icon'] - app.icon_background = args['icon_background'] + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id, model_type=ModelType.LLM + ) + default_model_config["model"]["provider"] = provider + default_model_config["model"]["name"] = model + default_model_dict = default_model_config["model"] + + default_model_config["model"] = json.dumps(default_model_dict) + + app = App(**app_template["app"]) + app.name = args["name"] + app.description = args.get("description", "") + app.mode = args["mode"] + app.icon_type = args.get("icon_type", "emoji") + app.icon = args["icon"] + app.icon_background = args["icon_background"] app.tenant_id = tenant_id - app.api_rph = args.get('api_rph', 0) - app.api_rpm = args.get('api_rpm', 0) + app.api_rph = args.get("api_rph", 0) + app.api_rpm = args.get("api_rpm", 0) + app.created_by = account.id + app.updated_by = account.id db.session.add(app) db.session.flush() @@ -131,6 +136,8 @@ def create_app(self, tenant_id: str, args: dict, account: Account) -> App: if default_model_config: app_model_config = AppModelConfig(**default_model_config) app_model_config.app_id = app.id + app_model_config.created_by = account.id + app_model_config.updated_by = account.id db.session.add(app_model_config) db.session.flush() @@ -151,7 +158,7 @@ def get_app(self, app: App) -> App: model_config: AppModelConfig = app.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not 
isinstance(tool, dict) or len(tool.keys()) <= 3: continue agent_tool_entity = AgentToolEntity(**tool) @@ -167,7 +174,7 @@ def get_app(self, app: App) -> App: tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app.id}' + identity_id=f"AGENT.{app.id}", ) # get decrypted parameters @@ -178,7 +185,7 @@ def get_app(self, app: App) -> App: masked_parameter = {} # override tool parameters - tool['tool_parameters'] = masked_parameter + tool["tool_parameters"] = masked_parameter except Exception as e: pass @@ -189,13 +196,14 @@ class ModifiedApp(App): """ Modified App class """ + def __init__(self, app): self.__dict__.update(app.__dict__) @property def app_model_config(self): return model_config - + app = ModifiedApp(app) return app @@ -207,11 +215,14 @@ def update_app(self, app: App, args: dict) -> App: :param args: request args :return: App instance """ - app.name = args.get('name') - app.description = args.get('description', '') - app.max_active_requests = args.get('max_active_requests') - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') + app.name = args.get("name") + app.description = args.get("description", "") + app.max_active_requests = args.get("max_active_requests") + app.icon_type = args.get("icon_type", "emoji") + app.icon = args.get("icon") + app.icon_background = args.get("icon_background") + app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -228,6 +239,7 @@ def update_app_name(self, app: App, name: str) -> App: :return: App instance """ app.name = name + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -243,6 +255,7 @@ def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: """ app.icon = icon app.icon_background = icon_background + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -259,6 +272,7 @@ def update_app_site_status(self, app: App, enable_site: bool) -> App: return app app.enable_site = enable_site + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -275,6 +289,7 @@ def update_app_api_status(self, app: App, enable_api: bool) -> App: return app app.enable_api = enable_api + app.updated_by = current_user.id app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -289,10 +304,7 @@ def delete_app(self, app: App) -> None: db.session.commit() # Trigger asynchronous deletion of app and related data - remove_app_and_related_data_task.delay( - tenant_id=app.tenant_id, - app_id=app.id - ) + remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) def get_app_meta(self, app_model: App) -> dict: """ @@ -302,27 +314,27 @@ def get_app_meta(self, app_model: App) -> dict: """ app_mode = AppMode.value_of(app_model.mode) - meta = { - 'tool_icons': {} - } + meta = {"tool_icons": {}} - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: return meta graph = workflow.graph_dict - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) tools = [] for node in nodes: - if node.get('data', {}).get('type') == 'tool': - node_data = 
node.get('data', {}) - tools.append({ - 'provider_type': node_data.get('provider_type'), - 'provider_id': node_data.get('provider_id'), - 'tool_name': node_data.get('tool_name'), - 'tool_parameters': {} - }) + if node.get("data", {}).get("type") == "tool": + node_data = node.get("data", {}) + tools.append( + { + "provider_type": node_data.get("provider_type"), + "provider_id": node_data.get("provider_id"), + "tool_name": node_data.get("tool_name"), + "tool_parameters": {}, + } + ) else: app_model_config: AppModelConfig = app_model.app_model_config @@ -332,30 +344,26 @@ def get_app_meta(self, app_model: App) -> dict: agent_config = app_model_config.agent_mode_dict or {} # get all tools - tools = agent_config.get('tools', []) + tools = agent_config.get("tools", []) - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/") + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get('provider_type') - provider_id = tool.get('provider_id') - tool_name = tool.get('tool_name') - if provider_type == 'builtin': - meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' - elif provider_type == 'api': + provider_type = tool.get("provider_type") + provider_id = tool.get("provider_id") + tool_name = tool.get("tool_name") + if provider_type == "builtin": + meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id - ).first() - meta['tool_icons'][tool_name] = json.loads(provider.icon) + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() + ) + meta["tool_icons"][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { - "background": "#252525", - "content": "\ud83d\ude01" - } + meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 58c950816fdb86..7a0cd5725b2a96 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -17,7 +17,7 @@ FILE_SIZE = 30 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 -ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr'] +ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"] logger = logging.getLogger(__name__) @@ -25,25 +25,25 @@ class AudioService: @classmethod def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") features_dict = workflow.features_dict - if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'): + if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): raise ValueError("Speech to text is not enabled") else: app_model_config: AppModelConfig = app_model.app_model_config - if not app_model_config.speech_to_text_dict['enabled']: + if not app_model_config.speech_to_text_dict["enabled"]: raise ValueError("Speech to text is not enabled") if file is None: raise 
NoAudioUploadedServiceError() extension = file.mimetype - if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]: + if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() file_content = file.read() @@ -55,20 +55,25 @@ def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[st model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.SPEECH2TEXT + tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) if model_instance is None: raise ProviderNotSupportSpeechToTextServiceError() buffer = io.BytesIO(file_content) - buffer.name = 'temp.mp3' + buffer.name = "temp.mp3" return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, app_model: App, text: Optional[str] = None, - voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None): + def transcript_tts( + cls, + app_model: App, + text: Optional[str] = None, + voice: Optional[str] = None, + end_user: Optional[str] = None, + message_id: Optional[str] = None, + ): from collections.abc import Generator from flask import Response, stream_with_context @@ -78,71 +83,62 @@ def transcript_tts(cls, app_model: App, text: Optional[str] = None, def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): with app.app_context(): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("TTS is not enabled") features_dict = workflow.features_dict - if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): + if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"): raise ValueError("TTS is not enabled") - voice = features_dict['text_to_speech'].get('voice') if voice is None else voice + voice = features_dict["text_to_speech"].get("voice") if voice is None else voice else: text_to_speech_dict = app_model.app_model_config.text_to_speech_dict - if not text_to_speech_dict.get('enabled'): + if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get('voice') if voice is None else voice + voice = text_to_speech_dict.get("voice") if voice is None else voice model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.TTS + tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) try: if not voice: voices = model_instance.get_tts_voices() if voices: - voice = voices[0].get('value') + voice = voices[0].get("value") else: raise ValueError("Sorry, no voice available.") return model_instance.invoke_tts( - content_text=text_content.strip(), - user=end_user, - tenant_id=app_model.tenant_id, - voice=voice + content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice ) except Exception as e: raise e if message_id: - message = db.session.query(Message).filter( - Message.id == message_id - ).first() - if message.answer == '' and message.status == 'normal': + message = db.session.query(Message).filter(Message.id == message_id).first() + if message.answer == "" and message.status == "normal": return None else: response = invoke_tts(message.answer, app_model=app_model, voice=voice) 
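The Generator-versus-bytes branching used in `transcript_tts` reduces to a small standalone pattern; a sketch under the assumption that the TTS backend may either stream chunks or return a whole payload (`fake_tts` is a stand-in for `model_instance.invoke_tts`, not the real API):

```python
from collections.abc import Generator

from flask import Flask, Response, stream_with_context

app = Flask(__name__)


def fake_tts() -> Generator[bytes, None, None]:
    # Stand-in for model_instance.invoke_tts(...), which may stream chunks.
    yield b"ID3"        # mp3 header bytes
    yield b"\x00" * 16  # audio payload chunks would follow


@app.route("/tts")
def tts():
    audio = fake_tts()
    if isinstance(audio, Generator):
        # Stream chunks to the client as they are produced, instead of
        # buffering the whole file in memory first.
        return Response(stream_with_context(audio), content_type="audio/mpeg")
    return audio
```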
if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response else: response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TTS - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index ccd0023c44d84d..f91c448fb94a23 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,14 +1,25 @@ - -from services.auth.firecrawl import FirecrawlAuth +from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.auth_type import AuthType class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): - if provider == 'firecrawl': - self.auth = FirecrawlAuth(credentials) - else: - raise ValueError('Invalid provider') + auth_factory = self.get_apikey_auth_factory(provider) + self.auth = auth_factory(credentials) def validate_credentials(self): return self.auth.validate_credentials() + + @staticmethod + def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]: + match provider: + case AuthType.FIRECRAWL: + from services.auth.firecrawl.firecrawl import FirecrawlAuth + + return FirecrawlAuth + case AuthType.JINA: + from services.auth.jina.jina import JinaAuth + + return JinaAuth + case _: + raise ValueError("Invalid provider") diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 43d0fbf98f2df7..e5f4a3ef6e12d3 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -7,39 +7,43 @@ class ApiKeyAuthService: - @staticmethod def get_provider_auth_list(tenant_id: str) -> list: - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).all() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) + .all() + ) return data_source_api_key_bindings @staticmethod def create_provider_auth(tenant_id: str, args: dict): - auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials() + auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key - api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key']) - args['credentials']['config']['api_key'] = api_key + api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) + args["credentials"]["config"]["api_key"] = api_key data_source_api_key_binding = DataSourceApiKeyAuthBinding() data_source_api_key_binding.tenant_id = tenant_id - data_source_api_key_binding.category = args['category'] - 
data_source_api_key_binding.provider = args['provider'] - data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False) + data_source_api_key_binding.category = args["category"] + data_source_api_key_binding.provider = args["provider"] + data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) db.session.add(data_source_api_key_binding) db.session.commit() @staticmethod def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.category == category, - DataSourceApiKeyAuthBinding.provider == provider, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).first() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.category == category, + DataSourceApiKeyAuthBinding.provider == provider, + DataSourceApiKeyAuthBinding.disabled.is_(False), + ) + .first() + ) if not data_source_api_key_bindings: return None credentials = json.loads(data_source_api_key_bindings.credentials) @@ -47,24 +51,24 @@ def get_auth_credentials(tenant_id: str, category: str, provider: str): @staticmethod def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.id == binding_id - ).first() + data_source_api_key_binding = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) + .first() + ) if data_source_api_key_binding: db.session.delete(data_source_api_key_binding) db.session.commit() @classmethod def validate_api_key_auth_args(cls, args): - if 'category' not in args or not args['category']: - raise ValueError('category is required') - if 'provider' not in args or not args['provider']: - raise ValueError('provider is required') - if 'credentials' not in args or not args['credentials']: - raise ValueError('credentials is required') - if not isinstance(args['credentials'], dict): - raise ValueError('credentials must be a dictionary') - if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']: - raise ValueError('auth_type is required') - + if "category" not in args or not args["category"]: + raise ValueError("category is required") + if "provider" not in args or not args["provider"]: + raise ValueError("provider is required") + if "credentials" not in args or not args["credentials"]: + raise ValueError("credentials is required") + if not isinstance(args["credentials"], dict): + raise ValueError("credentials must be a dictionary") + if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]: + raise ValueError("auth_type is required") diff --git a/api/services/auth/auth_type.py b/api/services/auth/auth_type.py new file mode 100644 index 00000000000000..2d6e901447c369 --- /dev/null +++ b/api/services/auth/auth_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class AuthType(str, Enum): + FIRECRAWL = "firecrawl" + JINA = "jinareader" diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py deleted file mode 100644 index 69e3fb43c79dab..00000000000000 --- a/api/services/auth/firecrawl.py +++ /dev/null @@ -1,56 +0,0 @@ 
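The new `match`-based dispatch in `api_key_auth_factory.py` pairs with the `AuthType` enum above; a self-contained sketch of the same pattern, with `_StubAuth` standing in for the real `FirecrawlAuth`/`JinaAuth` classes (everything here is illustrative, not patch code):

```python
from enum import Enum


class AuthType(str, Enum):
    FIRECRAWL = "firecrawl"
    JINA = "jinareader"


class _StubAuth:
    """Stand-in for FirecrawlAuth / JinaAuth; stores credentials only."""

    def __init__(self, credentials: dict):
        self.credentials = credentials

    def validate_credentials(self) -> bool:
        return True


def get_apikey_auth_factory(provider: str) -> type[_StubAuth]:
    # The real factory does a lazy per-provider import in each case arm;
    # a single stub stands in for both providers in this sketch. Because
    # AuthType subclasses str, a plain string compares equal to its members.
    match provider:
        case AuthType.FIRECRAWL | AuthType.JINA:
            return _StubAuth
        case _:
            raise ValueError("Invalid provider")


auth = get_apikey_auth_factory("jinareader")({"auth_type": "bearer"})
print(auth.validate_credentials())  # True
```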
-import json - -import requests - -from services.auth.api_key_auth_base import ApiKeyAuthBase - - -class FirecrawlAuth(ApiKeyAuthBase): - def __init__(self, credentials: dict): - super().__init__(credentials) - auth_type = credentials.get('auth_type') - if auth_type != 'bearer': - raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer') - self.api_key = credentials.get('config').get('api_key', None) - self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev') - - if not self.api_key: - raise ValueError('No API key provided') - - def validate_credentials(self): - headers = self._prepare_headers() - options = { - 'url': 'https://example.com', - 'crawlerOptions': { - 'excludes': [], - 'includes': [], - 'limit': 1 - }, - 'pageOptions': { - 'onlyMainContent': True - } - } - response = self._post_request(f'{self.base_url}/v0/crawl', options, headers) - if response.status_code == 200: - return True - else: - self._handle_error(response) - - def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } - - def _post_request(self, url, data, headers): - return requests.post(url, headers=headers, json=data) - - def _handle_error(self, response): - if response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') - else: - if response.text: - error_message = json.loads(response.text).get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') - raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}') diff --git a/api/services/auth/firecrawl/__init__.py b/api/services/auth/firecrawl/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/auth/firecrawl/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py new file mode 100644 index 00000000000000..afc491398f25f3 --- /dev/null +++ b/api/services/auth/firecrawl/firecrawl.py @@ -0,0 +1,47 @@ +import json + +import requests + +from services.auth.api_key_auth_base import ApiKeyAuthBase + + +class FirecrawlAuth(ApiKeyAuthBase): + def __init__(self, credentials: dict): + super().__init__(credentials) + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") + self.api_key = credentials.get("config").get("api_key", None) + self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") + + if not self.api_key: + raise ValueError("No API key provided") + + def validate_credentials(self): + headers = self._prepare_headers() + options = { + "url": "https://example.com", + "crawlerOptions": {"excludes": [], "includes": [], "limit": 1}, + "pageOptions": {"onlyMainContent": True}, + } + response = self._post_request(f"{self.base_url}/v0/crawl", options, headers) + if response.status_code == 200: + return True + else: + self._handle_error(response) + + def _prepare_headers(self): + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + def _post_request(self, url, data, headers): + return requests.post(url, headers=headers, json=data) + + def _handle_error(self, response): + if response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + 
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + else: + if response.text: + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") diff --git a/api/services/auth/jina.py b/api/services/auth/jina.py new file mode 100644 index 00000000000000..de898a1f94b763 --- /dev/null +++ b/api/services/auth/jina.py @@ -0,0 +1,44 @@ +import json + +import requests + +from services.auth.api_key_auth_base import ApiKeyAuthBase + + +class JinaAuth(ApiKeyAuthBase): + def __init__(self, credentials: dict): + super().__init__(credentials) + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") + self.api_key = credentials.get("config").get("api_key", None) + + if not self.api_key: + raise ValueError("No API key provided") + + def validate_credentials(self): + headers = self._prepare_headers() + options = { + "url": "https://example.com", + } + response = self._post_request("https://r.jina.ai", options, headers) + if response.status_code == 200: + return True + else: + self._handle_error(response) + + def _prepare_headers(self): + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + def _post_request(self, url, data, headers): + return requests.post(url, headers=headers, json=data) + + def _handle_error(self, response): + if response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + else: + if response.text: + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. 
Status code: {response.status_code}") diff --git a/api/services/auth/jina/__init__.py b/api/services/auth/jina/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/auth/jina/jina.py b/api/services/auth/jina/jina.py new file mode 100644 index 00000000000000..de898a1f94b763 --- /dev/null +++ b/api/services/auth/jina/jina.py @@ -0,0 +1,44 @@ +import json + +import requests + +from services.auth.api_key_auth_base import ApiKeyAuthBase + + +class JinaAuth(ApiKeyAuthBase): + def __init__(self, credentials: dict): + super().__init__(credentials) + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Jina Reader auth type must be Bearer") + self.api_key = credentials.get("config").get("api_key", None) + + if not self.api_key: + raise ValueError("No API key provided") + + def validate_credentials(self): + headers = self._prepare_headers() + options = { + "url": "https://example.com", + } + response = self._post_request("https://r.jina.ai", options, headers) + if response.status_code == 200: + return True + else: + self._handle_error(response) + + def _prepare_headers(self): + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + def _post_request(self, url, data, headers): + return requests.post(url, headers=headers, json=data) + + def _handle_error(self, response): + if response.status_code in {402, 409, 500}: + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + else: + if response.text: + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. 
Status code: {response.status_code}") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 539f2712bb9f38..911d2346415ce5 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -7,58 +7,40 @@ class BillingService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def get_info(cls, tenant_id: str): - params = {'tenant_id': tenant_id} + params = {"tenant_id": tenant_id} - billing_info = cls._send_request('GET', '/subscription/info', params=params) + billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info @classmethod - def get_subscription(cls, plan: str, - interval: str, - prefilled_email: str = '', - tenant_id: str = ''): - params = { - 'plan': plan, - 'interval': interval, - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/subscription/payment-link', params=params) + def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): + params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/subscription/payment-link", params=params) @classmethod - def get_model_provider_payment_link(cls, - provider_name: str, - tenant_id: str, - account_id: str, - prefilled_email: str): + def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str): params = { - 'provider_name': provider_name, - 'tenant_id': tenant_id, - 'account_id': account_id, - 'prefilled_email': prefilled_email + "provider_name": provider_name, + "tenant_id": tenant_id, + "account_id": account_id, + "prefilled_email": prefilled_email, } - return cls._send_request('GET', '/model-provider/payment-link', params=params) + return cls._send_request("GET", "/model-provider/payment-link", params=params) @classmethod - def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''): - params = { - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/invoices', params=params) + def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): + params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/invoices", params=params) @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -69,10 +51,11 @@ def _send_request(cls, method, endpoint, json=None, params=None): def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant_id, - TenantAccountJoin.account_id == current_user.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) + .first() + ) if not TenantAccountRole.is_privileged_role(join.role): - raise 
ValueError('Only team owner or team admin can perform this action') + raise ValueError("Only team owner or team admin can perform this action") diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index 7b0d50a835dbf8..f7597b7f1fcd45 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -2,12 +2,15 @@ class CodeBasedExtensionService: - @staticmethod def get_code_based_extension(module: str) -> list[dict]: module_extensions = code_based_extension.module_extensions(module) - return [{ - 'name': module_extension.name, - 'label': module_extension.label, - 'form_schema': module_extension.form_schema - } for module_extension in module_extensions if not module_extension.builtin] + return [ + { + "name": module_extension.name, + "label": module_extension.label, + "form_schema": module_extension.form_schema, + } + for module_extension in module_extensions + if not module_extension.builtin + ] diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 82ee10ee78f095..f9e41988c0d189 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,6 +1,7 @@ +from datetime import datetime, timezone from typing import Optional, Union -from sqlalchemy import or_ +from sqlalchemy import asc, desc, or_ from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -14,21 +15,27 @@ class ConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, - invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + include_ids: Optional[list] = None, + exclude_ids: Optional[list] = None, + sort_by: str = "-updated_at", + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) base_query = db.session.query(Conversation).filter( Conversation.is_deleted == False, Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value) + or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) if include_ids is not None: @@ -37,47 +44,67 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if exclude_ids is not None: base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) - if last_id: - last_conversation = base_query.filter( - Conversation.id == last_id, - ).first() + # define sort fields and directions + sort_field, sort_direction = cls._get_sort_params(sort_by) + if last_id: + last_conversation = base_query.filter(Conversation.id == last_id).first() if not last_conversation: raise LastConversationNotExistsError() - conversations = base_query.filter( - Conversation.created_at < last_conversation.created_at, - Conversation.id != 
last_conversation.id - ).order_by(Conversation.created_at.desc()).limit(limit).all() - else: - conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all() + # build filters based on sorting + filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) + base_query = base_query.filter(filter_condition) + + base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) + + conversations = base_query.limit(limit).all() has_more = False if len(conversations) == limit: - current_page_first_conversation = conversations[-1] - rest_count = base_query.filter( - Conversation.created_at < current_page_first_conversation.created_at, - Conversation.id != current_page_first_conversation.id - ).count() + current_page_last_conversation = conversations[-1] + rest_filter_condition = cls._build_filter_condition( + sort_field, sort_direction, current_page_last_conversation, is_next_page=True + ) + rest_count = base_query.filter(rest_filter_condition).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=conversations, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) @classmethod - def rename(cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool): + def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]: + if sort_by.startswith("-"): + return sort_by[1:], desc + return sort_by, asc + + @classmethod + def _build_filter_condition( + cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False + ): + field_value = getattr(reference_conversation, sort_field) + if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + return getattr(Conversation, sort_field) < field_value + else: + return getattr(Conversation, sort_field) > field_value + + @classmethod + def rename( + cls, + app_model: App, + conversation_id: str, + user: Optional[Union[Account, EndUser]], + name: str, + auto_generate: bool, + ): conversation = cls.get_conversation(app_model, conversation_id, user) if auto_generate: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() return conversation @@ -85,11 +112,12 @@ def rename(cls, app_model: App, conversation_id: str, @classmethod def auto_generate_name(cls, app_model: App, conversation: Conversation): # get conversation first message - message = db.session.query(Message) \ - .filter( - Message.app_id == app_model.id, - Message.conversation_id == conversation.id - ).order_by(Message.created_at.asc()).first() + message = ( + db.session.query(Message) + .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) + .order_by(Message.created_at.asc()) + .first() + ) if not message: raise MessageNotExistsError() @@ -109,15 +137,18 @@ def auto_generate_name(cls, app_model: App, conversation: Conversation): @classmethod def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - conversation = db.session.query(Conversation) \ + conversation = ( + db.session.query(Conversation) .filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), - 
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - Conversation.is_deleted == False - ).first() + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + Conversation.is_deleted == False, + ) + .first() + ) if not conversation: raise ConversationNotExistsError() @@ -129,4 +160,5 @@ def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Accou conversation = cls.get_conversation(app_model, conversation_id, user) conversation.is_deleted = True + conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 12ae0e39a8c6e8..8562dad1d3e9ee 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -4,18 +4,17 @@ import random import time import uuid -from typing import Optional +from typing import Any, Optional from flask_login import current_user from sqlalchemy import func +from werkzeug.exceptions import NotFound from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.keyword.keyword_factory import Keyword -from core.rag.models.document import Document as RAGDocument -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted from extensions.ext_database import db @@ -27,17 +26,21 @@ Dataset, DatasetCollectionBinding, DatasetPermission, + DatasetPermissionEnum, DatasetProcessRule, DatasetQuery, Document, DocumentSegment, + ExternalKnowledgeBindings, ) from models.model import UploadFile from models.source import DataSourceOauthBinding +from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity from services.errors.account import NoPermissionError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError +from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService from services.tag_service import TagService from services.vector_service import VectorService @@ -54,19 +57,13 @@ class DatasetService: - @staticmethod - def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None): - query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by( - Dataset.created_at.desc() - ) + def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None): + query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) if user: # get permitted dataset ids - dataset_permission = DatasetPermission.query.filter_by( - account_id=user.id, - tenant_id=tenant_id - ).all() + dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, 
tenant_id=tenant_id).all() permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -80,111 +77,128 @@ def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, s if permitted_dataset_ids: query = query.filter( db.or_( - Dataset.permission == 'all_team_members', - db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id), - db.and_(Dataset.permission == 'partial_members', Dataset.id.in_(permitted_dataset_ids)) + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), + db.and_( + Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, + Dataset.id.in_(permitted_dataset_ids), + ), ) ) else: query = query.filter( db.or_( - Dataset.permission == 'all_team_members', - db.and_(Dataset.permission == 'only_me', Dataset.created_by == user.id) + Dataset.permission == DatasetPermissionEnum.ALL_TEAM, + db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), ) ) else: # if no user, only show datasets that are shared with all team members - query = query.filter(Dataset.permission == 'all_team_members') + query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.filter(Dataset.name.ilike(f'%{search}%')) + query = query.filter(Dataset.name.ilike(f"%{search}%")) if tag_ids: - target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids) + target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) if target_ids: query = query.filter(Dataset.id.in_(target_ids)) else: return [], 0 - datasets = query.paginate( - page=page, - per_page=per_page, - max_per_page=100, - error_out=False - ) + datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) return datasets.items, datasets.total @staticmethod def get_process_rules(dataset_id): # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). 
\ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict else: - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] - return { - 'mode': mode, - 'rules': rules - } + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] + return {"mode": mode, "rules": rules} @staticmethod def get_datasets_by_ids(ids, tenant_id): - datasets = Dataset.query.filter( - Dataset.id.in_(ids), - Dataset.tenant_id == tenant_id - ).paginate( + datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( page=1, per_page=len(ids), max_per_page=len(ids), error_out=False ) return datasets.items, datasets.total @staticmethod - def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account): + def create_empty_dataset( + tenant_id: str, + name: str, + description: Optional[str], + indexing_technique: Optional[str], + account: Account, + permission: Optional[str] = None, + provider: str = "vendor", + external_knowledge_api_id: Optional[str] = None, + external_knowledge_id: Optional[str] = None, + ): # check if dataset name already exists if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): - raise DatasetNameDuplicateError( - f'Dataset with name {name} already exists.' - ) + raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset = Dataset(name=name, indexing_technique=indexing_technique) # dataset = Dataset(name=name, provider=provider, config=config) + dataset.description = description dataset.created_by = account.id dataset.updated_by = account.id dataset.tenant_id = tenant_id dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.permission = permission or DatasetPermissionEnum.ONLY_ME + dataset.provider = provider db.session.add(dataset) + db.session.flush() + + if provider == "external" and external_knowledge_api_id: + external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id) + if not external_knowledge_api: + raise ValueError("External API template not found.") + external_knowledge_binding = ExternalKnowledgeBindings( + tenant_id=tenant_id, + dataset_id=dataset.id, + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id=external_knowledge_id, + created_by=account.id, + ) + db.session.add(external_knowledge_binding) + db.session.commit() return dataset @staticmethod - def get_dataset(dataset_id): - return Dataset.query.filter_by( - id=dataset_id - ).first() + def get_dataset(dataset_id) -> Dataset: + return Dataset.query.filter_by(id=dataset_id).first() @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() 
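The `ConversationService` rework earlier in this diff swaps the hard-coded `created_at` cursor for sort-aware helpers (`_get_sort_params`, `_build_filter_condition`). A self-contained sketch of that keyset-pagination pattern, run against an in-memory SQLite table; the `Conversation` model below is a stand-in, and the cursor comparison here is driven by the sort direction alone, without the diff's extra `is_next_page` flag:

```python
from datetime import datetime, timedelta

from sqlalchemy import Column, DateTime, Integer, asc, create_engine, desc
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Conversation(Base):  # stand-in for the real model
    __tablename__ = "conversation"
    id = Column(Integer, primary_key=True)
    updated_at = Column(DateTime)


def get_sort_params(sort_by: str):
    # "-updated_at" -> (field, descending); "updated_at" -> (field, ascending)
    return (sort_by[1:], desc) if sort_by.startswith("-") else (sort_by, asc)


def after_cursor(sort_field: str, sort_direction, cursor_row):
    # Rows that follow the cursor row in the chosen order.
    value = getattr(cursor_row, sort_field)
    column = getattr(Conversation, sort_field)
    return column < value if sort_direction is desc else column > value


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    start = datetime(2024, 1, 1)
    session.add_all(Conversation(id=i, updated_at=start + timedelta(minutes=i)) for i in range(1, 11))
    session.commit()

    sort_field, sort_direction = get_sort_params("-updated_at")
    cursor = session.get(Conversation, 8)  # plays the role of last_id
    page = (
        session.query(Conversation)
        .filter(after_cursor(sort_field, sort_direction, cursor))
        .order_by(sort_direction(getattr(Conversation, sort_field)))
        .limit(3)
        .all()
    )
    print([c.id for c in page])  # [7, 6, 5]
```

Under this simplification, `has_more` is simply whether the same filter, re-anchored at the last row of the current page, still matches any rows.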
model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ValueError( @@ -192,115 +206,132 @@ def check_dataset_model_setting(dataset): "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) + raise ValueError(f"The dataset is unavailable, due to: {ex.description}") @staticmethod - def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model:str): + def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model + model=embedding_model, ) except LLMBadRequestError: raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) - + raise ValueError(f"The dataset is unavailable, due to: {ex.description}") @staticmethod def update_dataset(dataset_id, data, user): - data.pop('partial_member_list', None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'} dataset = DatasetService.get_dataset(dataset_id) + DatasetService.check_dataset_permission(dataset, user) - action = None - if dataset.indexing_technique != data['indexing_technique']: - # if update indexing_technique - if data['indexing_technique'] == 'economy': - action = 'remove' - filtered_data['embedding_model'] = None - filtered_data['embedding_model_provider'] = None - filtered_data['collection_binding_id'] = None - elif data['indexing_technique'] == 'high_quality': - action = 'add' - # get embedding model setting - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], - model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] - ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model - ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider."
- ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) + if dataset.provider == "external": + dataset.retrieval_model = data.get("external_retrieval_model", None) + dataset.name = data.get("name", dataset.name) + dataset.description = data.get("description", "") + external_knowledge_id = data.get("external_knowledge_id", None) + dataset.permission = data.get("permission") + db.session.add(dataset) + if not external_knowledge_id: + raise ValueError("External knowledge id is required.") + external_knowledge_api_id = data.get("external_knowledge_api_id", None) + if not external_knowledge_api_id: + raise ValueError("External knowledge api id is required.") + external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id).first() + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + db.session.add(external_knowledge_binding) + db.session.commit() else: - if data['embedding_model_provider'] != dataset.embedding_model_provider or \ - data['embedding_model'] != dataset.embedding_model: - action = 'update' - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], - model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] - ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model - ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." 
- ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) + data.pop("partial_member_list", None) + data.pop("external_knowledge_api_id", None) + data.pop("external_knowledge_id", None) + data.pop("external_retrieval_model", None) + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} + action = None + if dataset.indexing_technique != data["indexing_technique"]: + # if update indexing_technique + if data["indexing_technique"] == "economy": + action = "remove" + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + elif data["indexing_technique"] == "high_quality": + action = "add" + # get embedding model setting + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + else: + if ( + data["embedding_model_provider"] != dataset.embedding_model_provider + or data["embedding_model"] != dataset.embedding_model + ): + action = "update" + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." 
+ ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) - filtered_data['updated_by'] = user.id - filtered_data['updated_at'] = datetime.datetime.now() + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now() - # update Retrieval model - filtered_data['retrieval_model'] = data['retrieval_model'] + # update Retrieval model + filtered_data["retrieval_model"] = data["retrieval_model"] - dataset.query.filter_by(id=dataset_id).update(filtered_data) + dataset.query.filter_by(id=dataset_id).update(filtered_data) - db.session.commit() - if action: - deal_dataset_vector_index_task.delay(dataset_id, action) + db.session.commit() + if action: + deal_dataset_vector_index_task.delay(dataset_id, action) return dataset @staticmethod def delete_dataset(dataset_id, user): - dataset = DatasetService.get_dataset(dataset_id) if dataset is None: @@ -324,72 +355,57 @@ def dataset_use_check(dataset_id) -> bool: @staticmethod def check_dataset_permission(dataset, user): if dataset.tenant_id != user.current_tenant_id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) - if dataset.permission == 'only_me' and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) - if dataset.permission == 'partial_members': - user_permission = DatasetPermission.query.filter_by( - dataset_id=dataset.id, account_id=user.id - ).first() + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == "partial_members": + user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' 
- ) + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): - if dataset.permission == 'only_me': + if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") - elif dataset.permission == 'partial_members': + elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() ): - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): - dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \ - .order_by(db.desc(DatasetQuery.created_at)) \ - .paginate( - page=page, per_page=per_page, max_per_page=100, error_out=False + dataset_queries = ( + DatasetQuery.query.filter_by(dataset_id=dataset_id) + .order_by(db.desc(DatasetQuery.created_at)) + .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) ) return dataset_queries.items, dataset_queries.total @staticmethod def get_related_apps(dataset_id: str): - return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ - .order_by(db.desc(AppDatasetJoin.created_at)).all() + return ( + AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) + .order_by(db.desc(AppDatasetJoin.created_at)) + .all() + ) class DocumentService: DEFAULT_RULES = { - 'mode': 'custom', - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + }, } DOCUMENT_METADATA_SCHEMA = { @@ -482,58 +498,55 @@ class DocumentService: "commit_date": str, "commit_author": str, }, - "others": dict + "others": dict, } @staticmethod def get_document(dataset_id: str, document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) return document @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() return document @staticmethod def get_document_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.enabled == True - ).all() + documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all() return documents @staticmethod def get_error_documents_by_dataset_id(dataset_id: str) -> 
list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.indexing_status.in_(['error', 'paused']) - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + .all() + ) return documents @staticmethod def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.batch == batch, - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter( + Document.batch == batch, + Document.dataset_id == dataset_id, + Document.tenant_id == current_user.current_tenant_id, + ) + .all() + ) return documents @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == file_id). \ - one_or_none() + file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none() return file_detail @staticmethod @@ -547,13 +560,14 @@ def check_archived(document): def delete_document(document): # trigger document_was_deleted signal file_id = None - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - document_was_deleted.send(document.id, dataset_id=document.dataset_id, - doc_form=document.doc_form, file_id=file_id) + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + document_was_deleted.send( + document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=file_id + ) db.session.delete(document) db.session.commit() @@ -562,15 +576,15 @@ def delete_document(document): def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise ValueError('Dataset not found.') + raise ValueError("Dataset not found.") document = DocumentService.get_document(dataset_id, document_id) if not document: - raise ValueError('Document not found.') + raise ValueError("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise ValueError('No permission.') + raise ValueError("No permission.") document.name = name @@ -581,7 +595,7 @@ def rename_document(dataset_id: str, document_id: str, name: str) -> Document: @staticmethod def pause_document(document): - if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]: + if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: raise DocumentIndexingError() # update document to be paused document.is_paused = True @@ -591,7 +605,7 @@ def pause_document(document): db.session.add(document) db.session.commit() # set document paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.setnx(indexing_cache_key, "True") @staticmethod @@ -606,7 +620,7 @@ def recover_document(document): db.session.add(document) db.session.commit() # delete paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.delete(indexing_cache_key) 
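The pause/recover/retry paths above share one convention: a short-lived Redis key per document acts as a flag or a concurrency guard. A minimal sketch with an in-memory stand-in for `redis_client`, so it runs without a Redis server; the key format follows the diff:

```python
class FakeRedis:
    """In-memory stand-in for the redis_client used above."""

    def __init__(self):
        self._store: dict[str, str] = {}

    def setnx(self, key: str, value: str) -> bool:
        # SETNX semantics: set only if absent, report whether we set it.
        if key in self._store:
            return False
        self._store[key] = value
        return True

    def get(self, key: str):
        return self._store.get(key)

    def delete(self, key: str) -> None:
        self._store.pop(key, None)


redis_client = FakeRedis()
document_id = "doc-123"  # hypothetical id
indexing_cache_key = "document_{}_is_paused".format(document_id)

redis_client.setnx(indexing_cache_key, "True")           # pause_document sets the flag
assert redis_client.get(indexing_cache_key) is not None  # guard: flag already present
redis_client.delete(indexing_cache_key)                  # recover_document clears it
```

The dataset-level `redis_client.lock(lock_name, timeout=600)` added to `save_document_with_dataset_id` below builds on the same client.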
# trigger async task recover_document_indexing_task.delay(document.dataset_id, document.id) @@ -615,12 +629,12 @@ def recover_document(document): def retry_document(dataset_id: str, documents: list[Document]): for document in documents: # add retry flag - retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) + retry_indexing_cache_key = "document_{}_is_retried".format(document.id) cache_result = redis_client.get(retry_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being retried, please try again later") # retry document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" db.session.add(document) db.session.commit() @@ -632,14 +646,14 @@ def retry_document(dataset_id: str, documents: list[Document]): @staticmethod def sync_website_document(dataset_id: str, document: Document): # add sync flag - sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id) + sync_indexing_cache_key = "document_{}_is_sync".format(document.id) cache_result = redis_client.get(sync_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being synced, please try again later") # sync document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" data_source_info = document.data_source_info_dict - data_source_info['mode'] = 'scrape' + data_source_info["mode"] = "scrape" document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) db.session.add(document) db.session.commit() @@ -658,27 +672,28 @@ def get_documents_position(dataset_id): @staticmethod def save_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account | Any, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): - # check document limit features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if 'original_document_id' not in document_data or not document_data['original_document_id']: + if "original_document_id" not in document_data or not document_data["original_document_id"]: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -690,46 +705,42 @@ def save_document_with_dataset_id( dataset.data_source_type = document_data["data_source"]["type"] if not dataset.indexing_technique: - if 
'indexing_technique' not in document_data \ - or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST: + if ( + "indexing_technique" not in document_data + or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST + ): raise ValueError("Indexing technique is required") dataset.indexing_technique = document_data["indexing_technique"] - if document_data["indexing_technique"] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get( - 'retrieval_model' - ) else default_retrieval_model + dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model documents = [] - batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) if document_data.get("original_document_id"): document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) documents.append(document) + batch = document.batch else: + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) # save process rule if not dataset_process_rule: process_rule = document_data["process_rule"] @@ -738,159 +749,177 @@ def save_document_with_dataset_id( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() - position = DocumentService.get_documents_position(dataset.id) - document_ids = [] - duplicate_document_ids = [] - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] - for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) + with redis_client.lock(lock_name, timeout=600): + position = DocumentService.get_documents_position(dataset.id) + document_ids = [] + duplicate_document_ids = [] + if document_data["data_source"]["type"] == "upload_file": + 
upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) - # raise error if file not found - if not file: - raise FileNotExistsError() + # raise error if file not found + if not file: + raise FileNotExistsError() - file_name = file.name - data_source_info = { - "upload_file_id": file_id, - } - # check duplicate - if document_data.get('duplicate', False): - document = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type='upload_file', - enabled=True, - name=file_name - ).first() - if document: - document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = datetime.datetime.utcnow() - document.created_from = created_from - document.doc_form = document_data['doc_form'] - document.doc_language = document_data['doc_language'] - document.data_source_info = json.dumps(data_source_info) - document.batch = batch - document.indexing_status = 'waiting' - db.session.add(document) - documents.append(document) - duplicate_document_ids.append(document.id) - continue - document = DocumentService.build_document( - dataset, dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, created_from, position, - account, file_name, batch - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] - exist_page_ids = [] - exist_document = {} - documents = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True - ).all() - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) - exist_document[data_source_info['notion_page_id']] = document.id - for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + file_name = file.name + data_source_info = { + "upload_file_id": file_id, + } + # check duplicate + if document_data.get("duplicate", False): + document = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ).first() + if document: + document.dataset_process_rule_id = dataset_process_rule.id + document.updated_at = datetime.datetime.utcnow() + document.created_from = created_from + document.doc_form = document_data["doc_form"] + document.doc_language = document_data["doc_language"] + document.data_source_info = json.dumps(data_source_info) + document.batch = batch + document.indexing_status = "waiting" + db.session.add(document) + documents.append(document) + duplicate_document_ids.append(document.id) + continue + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + 
document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + file_name, + batch, ) - ).first() - if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: - if page['page_id'] not in exist_page_ids: - data_source_info = { - "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] - } - document = DocumentService.build_document( - dataset, dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, created_from, position, - account, page['page_name'], batch + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + elif document_data["data_source"]["type"] == "notion_import": + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + exist_page_ids = [] + exist_document = {} + documents = Document.query.filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ).all() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id + for notion_info in notion_info_list: + workspace_id = notion_info["workspace_id"] + data_source_binding = DataSourceOauthBinding.query.filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 + ).first() + if not data_source_binding: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: + if page["page_id"] not in exist_page_ids: + data_source_info = { + "notion_workspace_id": workspace_id, + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], + } + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + page["page_name"], + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + else: + exist_document.pop(page["page_id"]) + # delete not selected documents + if len(exist_document) > 0: + clean_notion_document_task.delay(list(exist_document.values()), dataset.id) + elif document_data["data_source"]["type"] == "website_crawl": + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] + for url in urls: + data_source_info = { + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", + } + if len(url) > 255: + document_name = url[:200] + "..." 
else: - exist_document.pop(page['page_id']) - # delete not selected documents - if len(exist_document) > 0: - clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] - for url in urls: - data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', - } - if len(url) > 255: - document_name = url[:200] + '...' - else: - document_name = url - document = DocumentService.build_document( - dataset, dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], - data_source_info, created_from, position, - account, document_name, batch - ) - db.session.add(document) - db.session.flush() - document_ids.append(document.id) - documents.append(document) - position += 1 - db.session.commit() + document_name = url + document = DocumentService.build_document( + dataset, + dataset_process_rule.id, + document_data["data_source"]["type"], + document_data["doc_form"], + document_data["doc_language"], + data_source_info, + created_from, + position, + account, + document_name, + batch, + ) + db.session.add(document) + db.session.flush() + document_ids.append(document.id) + documents.append(document) + position += 1 + db.session.commit() - # trigger async task - if document_ids: - document_indexing_task.delay(dataset.id, document_ids) - if duplicate_document_ids: - duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) + # trigger async task + if document_ids: + document_indexing_task.delay(dataset.id, document_ids) + if duplicate_document_ids: + duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) return documents, batch @@ -899,15 +928,22 @@ def check_documents_upload_quota(count: int, features: FeatureModel): can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size if count > can_upload_size: raise ValueError( - f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.' + f"You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded." 
) @staticmethod def build_document( - dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str, - document_language: str, data_source_info: dict, created_from: str, position: int, + dataset: Dataset, + process_rule_id: str, + data_source_type: str, + document_form: str, + document_language: str, + data_source_info: dict, + created_from: str, + position: int, account: Account, - name: str, batch: str + name: str, + batch: str, ): document = Document( tenant_id=dataset.tenant_id, @@ -921,7 +957,7 @@ def build_document( created_from=created_from, created_by=account.id, doc_form=document_form, - doc_language=document_language + doc_language=document_language, ) return document @@ -931,54 +967,56 @@ def get_tenant_documents_count(): Document.completed_at.isnot(None), Document.enabled == True, Document.archived == False, - Document.tenant_id == current_user.current_tenant_id + Document.tenant_id == current_user.current_tenant_id, ).count() return documents_count @staticmethod def update_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) - if document.display_status != 'available': + if document is None: + raise NotFound("Document not found") + if document.display_status != "available": raise ValueError("Document is not available") - # update document name - if document_data.get('name'): - document.name = document_data['name'] # save process rule - if document_data.get('process_rule'): + if document_data.get("process_rule"): process_rule = document_data["process_rule"] if process_rule["mode"] == "custom": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if document_data.get('data_source'): - file_name = '' + if document_data.get("data_source"): + file_name = "" data_source_info = {} if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -989,42 +1027,46 @@ def update_document_with_dataset_id( "upload_file_id": file_id, } elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = 
document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } document.data_source_type = document_data["data_source"]["type"] document.data_source_info = json.dumps(data_source_info) document.name = file_name + + # update document name + if document_data.get("name"): + document.name = document_data["name"] # update document to be waiting - document.indexing_status = 'waiting' + document.indexing_status = "waiting" document.completed_at = None document.processing_started_at = None document.parsing_completed_at = None @@ -1032,13 +1074,11 @@ def update_document_with_dataset_id( document.splitting_completed_at = None document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data['doc_form'] + document.doc_form = document_data["doc_form"] db.session.add(document) db.session.commit() # update document segment - update_params = { - DocumentSegment.status: 're_segment' - } + update_params = {DocumentSegment.status: "re_segment"} DocumentSegment.query.filter_by(document_id=document.id).update(update_params) db.session.commit() # trigger async task @@ -1052,60 +1092,50 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun if features.billing.enabled: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - 
count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") DocumentService.check_documents_upload_quota(count, features) - embedding_model = None dataset_collection_binding_id = None retrieval_model = None - if document_data['indexing_technique'] == 'high_quality': - model_manager = ModelManager() - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING - ) + if document_data["indexing_technique"] == "high_quality": dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + document_data["embedding_model_provider"], document_data["embedding_model"] ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get('retrieval_model'): - retrieval_model = document_data['retrieval_model'] + if document_data.get("retrieval_model"): + retrieval_model = document_data["retrieval_model"] else: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } retrieval_model = default_retrieval_model # save dataset dataset = Dataset( tenant_id=tenant_id, - name='', + name="", data_source_type=document_data["data_source"]["type"], - indexing_technique=document_data["indexing_technique"], + indexing_technique=document_data.get("indexing_technique", "high_quality"), created_by=account.id, - embedding_model=embedding_model.model if embedding_model else None, - embedding_model_provider=embedding_model.provider if embedding_model else None, + embedding_model=document_data.get("embedding_model"), + embedding_model_provider=document_data.get("embedding_model_provider"), collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model + retrieval_model=retrieval_model, ) db.session.add(dataset) @@ -1115,236 +1145,259 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun cut_length = 18 cut_name = documents[0].name[:cut_length] - dataset.name = cut_name + '...' - dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name + dataset.name = cut_name + "..." 
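+ # the auto-generated dataset name is the first document's name truncated to cut_length (18) characters, plus an ellipsis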
+ dataset.description = "useful for when you want to answer queries about the " + documents[0].name db.session.commit() return dataset, documents, batch @classmethod def document_create_args_validate(cls, args: dict): - if 'original_document_id' not in args or not args['original_document_id']: + if "original_document_id" not in args or not args["original_document_id"]: DocumentService.data_source_args_validate(args) DocumentService.process_rule_args_validate(args) else: - if ('data_source' not in args and not args['data_source']) \ - and ('process_rule' not in args and not args['process_rule']): + if ("data_source" not in args or not args["data_source"]) and ( + "process_rule" not in args or not args["process_rule"] + ): raise ValueError("Data source or Process rule is required") else: - if args.get('data_source'): + if args.get("data_source"): DocumentService.data_source_args_validate(args) - if args.get('process_rule'): + if args.get("process_rule"): DocumentService.process_rule_args_validate(args) @classmethod def data_source_args_validate(cls, args: dict): - if 'data_source' not in args or not args['data_source']: + if "data_source" not in args or not args["data_source"]: raise ValueError("Data source is required") - if not isinstance(args['data_source'], dict): + if not isinstance(args["data_source"], dict): raise ValueError("Data source is invalid") - if 'type' not in args['data_source'] or not args['data_source']['type']: + if "type" not in args["data_source"] or not args["data_source"]["type"]: raise ValueError("Data source type is required") - if args['data_source']['type'] not in Document.DATA_SOURCES: + if args["data_source"]["type"] not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if 'info_list' not in args['data_source'] or not args['data_source']['info_list']: + if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: raise ValueError("Data source info is required") - if args['data_source']['type'] == 'upload_file': - if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'file_info_list']: + if args["data_source"]["type"] == "upload_file": + if ( + "file_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["file_info_list"] + ): raise ValueError("File source info is required") - if args['data_source']['type'] == 'notion_import': - if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'notion_info_list']: + if args["data_source"]["type"] == "notion_import": + if ( + "notion_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["notion_info_list"] + ): raise ValueError("Notion source info is required") - if args['data_source']['type'] == 'website_crawl': - if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'website_info_list']: + if args["data_source"]["type"] == "website_crawl": + if ( + "website_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["website_info_list"] + ): raise ValueError("Website source info is required") @classmethod def process_rule_args_validate(cls, args: dict): - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise 
ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in 
args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): - if 'info_list' not in args or not args['info_list']: + if "info_list" not in args or not args["info_list"]: raise ValueError("Data source info is required") - if not isinstance(args['info_list'], dict): + if not isinstance(args["info_list"], dict): raise ValueError("Data info is invalid") - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for 
pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == 'qa_model': - if 'answer' not in args or not args['answer']: + if document.doc_form == "qa_model": + if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") - if not args['answer'].strip(): + if not args["answer"].strip(): raise ValueError("Answer is empty") - if 'content' not in args or not args['content'] or not args['content'].strip(): + if "content" not in args or not args["content"] or not args["content"].strip(): raise 
ValueError("Content is empty") @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): - content = args['content'] + content = args["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) - lock_name = 'add_segment_lock_document_id_{}'.format(document.id) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) + lock_name = "add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1355,25 +1408,29 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset): content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = args['answer'] + if document.doc_form == "qa_model": + segment_document.word_count += len(args["answer"]) + segment_document.answer = args["answer"] db.session.add(segment_document) + # update document word count + document.word_count += segment_document.word_count + db.session.add(document) db.session.commit() # save vector index try: - VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() @@ -1381,33 +1438,37 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset): @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): - lock_name = 'multi_add_segment_lock_document_id_{}'.format(document.id) + lock_name = "multi_add_segment_lock_document_id_{}".format(document.id) + increment_word_count = 0 with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, 
model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) pre_segment_data_list = [] segment_data_list = [] keywords_list = [] for segment_item in segments: - content = segment_item['content'] + content = segment_item["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality' and embedding_model: + if dataset.indexing_technique == "high_quality" and embedding_model: # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + if document.doc_form == "qa_model": + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]]) + else: + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1418,22 +1479,26 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = segment_item['answer'] + if document.doc_form == "qa_model": + segment_document.answer = segment_item["answer"] + segment_document.word_count += len(segment_item["answer"]) + increment_word_count += segment_document.word_count db.session.add(segment_document) segment_data_list.append(segment_document) pre_segment_data_list.append(segment_document) - if 'keywords' in segment_item: - keywords_list.append(segment_item['keywords']) + if "keywords" in segment_item: + keywords_list.append(segment_item["keywords"]) else: keywords_list.append(None) - + # update document word count + document.word_count += increment_word_count + db.session.add(document) try: # save vector index VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) @@ -1442,19 +1507,20 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas for segment_document in segment_data_list: segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() return segment_data_list @classmethod def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + segment_update_entity = SegmentUpdateEntity(**args) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - if 'enabled' in args and args['enabled'] is not None: - action = args['enabled'] + if segment_update_entity.enabled is not None: + action = segment_update_entity.enabled if segment.enabled != action: if not 
action: segment.enabled = action @@ -1467,58 +1533,56 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if 'enabled' in args and args['enabled'] is not None: - if not args['enabled']: + if segment_update_entity.enabled is not None: + if not segment_update_entity.enabled: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: - content = args['content'] + word_count_change = segment.word_count + content = segment_update_entity.content if segment.content == content: - if document.doc_form == 'qa_model': - segment.answer = args['answer'] - if args.get('keywords'): - segment.keywords = args['keywords'] + segment.word_count = len(content) + if document.doc_form == "qa_model": + segment.answer = segment_update_entity.answer + segment.word_count += len(segment_update_entity.answer) + word_count_change = segment.word_count - word_count_change + if segment_update_entity.keywords: + segment.keywords = segment_update_entity.keywords segment.enabled = True segment.disabled_at = None segment.disabled_by = None db.session.add(segment) db.session.commit() + # update document word count + if word_count_change != 0: + document.word_count = max(0, document.word_count + word_count_change) + db.session.add(document) # update segment index task - if 'keywords' in args: - keyword = Keyword(dataset) - keyword.delete_by_ids([segment.index_node_id]) - document = RAGDocument( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - } - ) - keyword.add_texts([document], keywords_list=[args['keywords']]) + if segment_update_entity.enabled: + VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset) else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + if document.doc_form == "qa_model": + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer]) + else: + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment.content = content segment.index_node_hash = segment_hash segment.word_count = len(content) segment.tokens = tokens - segment.status = 'completed' + segment.status = "completed" segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.updated_by = current_user.id @@ -1526,18 +1590,24 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == 'qa_model': - segment.answer = args['answer'] + if document.doc_form == "qa_model": + segment.answer = segment_update_entity.answer + segment.word_count += len(segment_update_entity.answer) + word_count_change = segment.word_count - 
word_count_change + # update document word count + if word_count_change != 0: + document.word_count = max(0, document.word_count + word_count_change) + db.session.add(document) db.session.add(segment) db.session.commit() # update segment vector index - VectorService.update_segment_vector(args['keywords'], segment, dataset) + VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset) except Exception as e: logging.exception("update segment index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() @@ -1545,7 +1615,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is deleting.") @@ -1556,30 +1626,34 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: D redis_client.setex(indexing_cache_key, 600, 1) delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id) db.session.delete(segment) + # update document word count + document.word_count -= segment.word_count + db.session.add(document) db.session.commit() class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding( - cls, provider_name: str, model_name: str, - collection_type: str = 'dataset' + cls, provider_name: str, model_name: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.provider_name == provider_name, - DatasetCollectionBinding.model_name == model_name, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). \ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name == provider_name, + DatasetCollectionBinding.model_name == model_name, + DatasetCollectionBinding.type == collection_type, + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) if not dataset_collection_binding: dataset_collection_binding = DatasetCollectionBinding( provider_name=provider_name, model_name=model_name, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type=collection_type + type=collection_type, ) db.session.add(dataset_collection_binding) db.session.commit() @@ -1587,16 +1661,16 @@ def get_dataset_collection_binding( @classmethod def get_dataset_collection_binding_by_id_and_type( - cls, collection_binding_id: str, - collection_type: str = 'dataset' + cls, collection_binding_id: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.id == collection_binding_id, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). 
\ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) return dataset_collection_binding @@ -1604,11 +1678,13 @@ def get_dataset_collection_binding_by_id_and_type( class DatasetPermissionService: @classmethod def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = db.session.query( - DatasetPermission.account_id, - ).filter( - DatasetPermission.dataset_id == dataset_id - ).all() + user_list_query = ( + db.session.query( + DatasetPermission.account_id, + ) + .filter(DatasetPermission.dataset_id == dataset_id) + .all() + ) user_list = [] for user in user_list_query: @@ -1625,7 +1701,7 @@ def update_partial_member_list(cls, tenant_id, dataset_id, user_list): permission = DatasetPermission( tenant_id=tenant_id, dataset_id=dataset_id, - account_id=user['user_id'], + account_id=user["user_id"], ) permissions.append(permission) @@ -1638,19 +1714,19 @@ def update_partial_member_list(cls, tenant_id, dataset_id, user_list): @classmethod def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list): if not user.is_dataset_editor: - raise NoPermissionError('User does not have permission to edit this dataset.') + raise NoPermissionError("User does not have permission to edit this dataset.") if user.is_dataset_operator and dataset.permission != requested_permission: - raise NoPermissionError('Dataset operators cannot change the dataset permissions.') + raise NoPermissionError("Dataset operators cannot change the dataset permissions.") - if user.is_dataset_operator and requested_permission == 'partial_members': + if user.is_dataset_operator and requested_permission == "partial_members": if not requested_partial_member_list: - raise ValueError('Partial member list is required when setting to partial members.') + raise ValueError("Partial member list is required when setting to partial members.") local_member_list = cls.get_dataset_partial_member_list(dataset.id) - request_member_list = [user['user_id'] for user in requested_partial_member_list] + request_member_list = [user["user_id"] for user in requested_partial_member_list] if set(local_member_list) != set(request_member_list): - raise ValueError('Dataset operators cannot change the dataset permissions.') + raise ValueError("Dataset operators cannot change the dataset permissions.") @classmethod def clear_partial_member_list(cls, dataset_id): diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index c483d28152c570..92098f06cca538 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -4,17 +4,17 @@ class EnterpriseRequest: - base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') - secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') + base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") + secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") + + proxies = { + "http": None, + "https": None, + } @classmethod def send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Enterprise-Api-Secret-Key": cls.secret_key - } - + headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = requests.request(method, 
url, json=json, params=params, headers=headers) - + response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) return response.json() diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 115d0d55232c29..abc01ddf8f58b0 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -2,7 +2,10 @@ class EnterpriseService: - @classmethod def get_info(cls): - return EnterpriseRequest.send_request('GET', '/info') + return EnterpriseRequest.send_request("GET", "/info") + + @classmethod + def get_app_web_sso_enabled(cls, app_code): + return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") diff --git a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py new file mode 100644 index 00000000000000..4545f385eb9891 --- /dev/null +++ b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py @@ -0,0 +1,26 @@ +from typing import Literal, Optional, Union + +from pydantic import BaseModel + + +class AuthorizationConfig(BaseModel): + type: Literal[None, "basic", "bearer", "custom"] + api_key: Union[None, str] = None + header: Union[None, str] = None + + +class Authorization(BaseModel): + type: Literal["no-auth", "api-key"] + config: Optional[AuthorizationConfig] = None + + +class ProcessStatusSetting(BaseModel): + request_method: str + url: str + + +class ExternalKnowledgeApiSetting(BaseModel): + url: str + request_method: str + headers: Optional[dict] = None + params: Optional[dict] = None diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py new file mode 100644 index 00000000000000..449b79f339b9a9 --- /dev/null +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -0,0 +1,10 @@ +from typing import Optional + +from pydantic import BaseModel + + +class SegmentUpdateEntity(BaseModel): + content: str + answer: Optional[str] = None + keywords: Optional[list[str]] = None + enabled: Optional[bool] = None diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index e5e4d7e23586d5..c519f0b0e51b68 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum): """ Enum class for custom configuration status. """ - ACTIVE = 'active' - NO_CONFIGURE = 'no-configure' + + ACTIVE = "active" + NO_CONFIGURE = "no-configure" class CustomConfigurationResponse(BaseModel): """ Model class for provider custom configuration response. """ + status: CustomConfigurationStatus @@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel): """ Model class for provider system configuration response. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -46,6 +49,7 @@ class ProviderResponse(BaseModel): """ Model class for provider response. 
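Icon URLs are rewritten in __init__ to point at the console's model-provider icon endpoints.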
""" + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -67,18 +71,15 @@ class ProviderResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel): """ Model class for provider with models response. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: SimpleProviderEntityResponse @@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): """ Model with provider entity. 
""" + provider: SimpleProviderEntityResponse def __init__(self, model: ModelWithProviderEntity) -> None: diff --git a/api/services/errors/account.py b/api/services/errors/account.py index ddc2dbdea86265..5aca12ffeb9891 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -1,7 +1,7 @@ from services.errors.base import BaseServiceError -class AccountNotFound(BaseServiceError): +class AccountNotFoundError(BaseServiceError): pass @@ -13,6 +13,10 @@ class AccountLoginError(BaseServiceError): pass +class AccountPasswordError(BaseServiceError): + pass + + class AccountNotLinkTenantError(BaseServiceError): pass @@ -25,7 +29,7 @@ class LinkAccountIntegrateError(BaseServiceError): pass -class TenantNotFound(BaseServiceError): +class TenantNotFoundError(BaseServiceError): pass @@ -55,4 +59,3 @@ class RoleAlreadyAssignedError(BaseServiceError): class RateLimitExceededError(BaseServiceError): pass - diff --git a/api/services/errors/base.py b/api/services/errors/base.py index f5d41e17f1142d..4d39f956b8c932 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,3 +1,6 @@ +from typing import Optional + + class BaseServiceError(Exception): - def __init__(self, description: str = None): - self.description = description \ No newline at end of file + def __init__(self, description: Optional[str] = None): + self.description = description diff --git a/api/services/errors/llm.py b/api/services/errors/llm.py new file mode 100644 index 00000000000000..e4fac6f7450040 --- /dev/null +++ b/api/services/errors/llm.py @@ -0,0 +1,19 @@ +from typing import Optional + + +class InvokeError(Exception): + """Base class for all LLM exceptions.""" + + description: Optional[str] = None + + def __init__(self, description: Optional[str] = None) -> None: + self.description = description + + def __str__(self): + return self.description or self.__class__.__name__ + + +class InvokeRateLimitError(InvokeError): + """Raised when the Invoke returns rate limit error.""" + + description = "Rate Limit Error" diff --git a/api/services/errors/workspace.py b/api/services/errors/workspace.py new file mode 100644 index 00000000000000..714064ffdf8c3d --- /dev/null +++ b/api/services/errors/workspace.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class WorkSpaceNotAllowedCreateError(BaseServiceError): + pass + + +class WorkSpaceNotFoundError(BaseServiceError): + pass diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py new file mode 100644 index 00000000000000..98e5d9face03d1 --- /dev/null +++ b/api/services/external_knowledge_service.py @@ -0,0 +1,276 @@ +import json +from copy import deepcopy +from datetime import datetime, timezone +from typing import Any, Optional, Union + +import httpx +import validators + +from constants import HIDDEN_VALUE +from core.helper import ssrf_proxy +from extensions.ext_database import db +from models.dataset import ( + Dataset, + ExternalKnowledgeApis, + ExternalKnowledgeBindings, +) +from services.entities.external_knowledge_entities.external_knowledge_entities import ( + Authorization, + ExternalKnowledgeApiSetting, +) +from services.errors.dataset import DatasetNameDuplicateError + + +class ExternalDatasetService: + @staticmethod + def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: + query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by( + 
ExternalKnowledgeApis.created_at.desc() + ) + if search: + query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) + + external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + + return external_knowledge_apis.items, external_knowledge_apis.total + + @classmethod + def validate_api_list(cls, api_settings: dict): + if not api_settings: + raise ValueError("api list is empty") + if "endpoint" not in api_settings or not api_settings["endpoint"]: + raise ValueError("endpoint is required") + if "api_key" not in api_settings or not api_settings["api_key"]: + raise ValueError("api_key is required") + + @staticmethod + def create_external_knowledge_api(tenant_id: str, user_id: str, args: dict) -> ExternalKnowledgeApis: + ExternalDatasetService.check_endpoint_and_api_key(args.get("settings")) + external_knowledge_api = ExternalKnowledgeApis( + tenant_id=tenant_id, + created_by=user_id, + updated_by=user_id, + name=args.get("name"), + description=args.get("description", ""), + settings=json.dumps(args.get("settings"), ensure_ascii=False), + ) + + db.session.add(external_knowledge_api) + db.session.commit() + return external_knowledge_api + + @staticmethod + def check_endpoint_and_api_key(settings: dict): + if "endpoint" not in settings or not settings["endpoint"]: + raise ValueError("endpoint is required") + if "api_key" not in settings or not settings["api_key"]: + raise ValueError("api_key is required") + + endpoint = f"{settings['endpoint']}/retrieval" + api_key = settings["api_key"] + if not validators.url(endpoint, simple_host=True): + raise ValueError(f"invalid endpoint: {endpoint}") + try: + response = httpx.post(endpoint, headers={"Authorization": f"Bearer {api_key}"}) + except Exception as e: + raise ValueError(f"failed to connect to the endpoint: {endpoint}") from e + if response.status_code == 502: + raise ValueError(f"Bad Gateway: failed to connect to the endpoint: {endpoint}") + if response.status_code == 404: + raise ValueError(f"Not Found: failed to connect to the endpoint: {endpoint}") + if response.status_code == 403: + raise ValueError(f"Forbidden: Authorization failed with api_key: {api_key}") + + @staticmethod + def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: + return ExternalKnowledgeApis.query.filter_by(id=external_knowledge_api_id).first() + + @staticmethod + def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: + external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id, tenant_id=tenant_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: + args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") + + external_knowledge_api.name = args.get("name") + external_knowledge_api.description = args.get("description", "") + external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False) + external_knowledge_api.updated_by = user_id + external_knowledge_api.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + db.session.commit() + + return external_knowledge_api + + @staticmethod + def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str): + external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id, tenant_id=tenant_id + ).first() + if 
external_knowledge_api is None: + raise ValueError("api template not found") + + db.session.delete(external_knowledge_api) + db.session.commit() + + @staticmethod + def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: + count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count() + if count > 0: + return True, count + return False, 0 + + @staticmethod + def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: + external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + dataset_id=dataset_id, tenant_id=tenant_id + ).first() + if not external_knowledge_binding: + raise ValueError("external knowledge binding not found") + return external_knowledge_binding + + @staticmethod + def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict): + external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_api_id, tenant_id=tenant_id + ).first() + if external_knowledge_api is None: + raise ValueError("api template not found") + settings = json.loads(external_knowledge_api.settings) + for setting in settings: + custom_parameters = setting.get("document_process_setting") + if custom_parameters: + for parameter in custom_parameters: + if parameter.get("required", False) and not process_parameter.get(parameter.get("name")): + raise ValueError(f'{parameter.get("name")} is required') + + @staticmethod + def process_external_api( + settings: ExternalKnowledgeApiSetting, files: Union[None, dict[str, Any]] + ) -> httpx.Response: + """ + do http request depending on api bundle + """ + + kwargs = { + "url": settings.url, + "headers": settings.headers, + "follow_redirects": True, + } + + response = getattr(ssrf_proxy, settings.request_method)(data=json.dumps(settings.params), files=files, **kwargs) + + return response + + @staticmethod + def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]: + authorization = deepcopy(authorization) + if headers: + headers = deepcopy(headers) + else: + headers = {} + if authorization.type == "api-key": + if authorization.config is None: + raise ValueError("authorization config is required") + + if authorization.config.api_key is None: + raise ValueError("api_key is required") + + if not authorization.config.header: + authorization.config.header = "Authorization" + + if authorization.config.type == "bearer": + headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" + elif authorization.config.type == "basic": + headers[authorization.config.header] = f"Basic {authorization.config.api_key}" + elif authorization.config.type == "custom": + headers[authorization.config.header] = authorization.config.api_key + + return headers + + @staticmethod + def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting: + return ExternalKnowledgeApiSetting.parse_obj(settings) + + @staticmethod + def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: + # check if dataset name already exists + if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first(): + raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.") + external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + id=args.get("external_knowledge_api_id"), tenant_id=tenant_id + ).first() + + if external_knowledge_api is None: + raise 
ValueError("api template not found") + + dataset = Dataset( + tenant_id=tenant_id, + name=args.get("name"), + description=args.get("description", ""), + provider="external", + retrieval_model=args.get("external_retrieval_model"), + created_by=user_id, + ) + + db.session.add(dataset) + db.session.flush() + + external_knowledge_binding = ExternalKnowledgeBindings( + tenant_id=tenant_id, + dataset_id=dataset.id, + external_knowledge_api_id=args.get("external_knowledge_api_id"), + external_knowledge_id=args.get("external_knowledge_id"), + created_by=user_id, + ) + db.session.add(external_knowledge_binding) + + db.session.commit() + + return dataset + + @staticmethod + def fetch_external_knowledge_retrieval( + tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict + ) -> list: + external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + dataset_id=dataset_id, tenant_id=tenant_id + ).first() + if not external_knowledge_binding: + raise ValueError("external knowledge binding not found") + + external_knowledge_api = ExternalKnowledgeApis.query.filter_by( + id=external_knowledge_binding.external_knowledge_api_id + ).first() + if not external_knowledge_api: + raise ValueError("external api template not found") + + settings = json.loads(external_knowledge_api.settings) + headers = {"Content-Type": "application/json"} + if settings.get("api_key"): + headers["Authorization"] = f"Bearer {settings.get('api_key')}" + score_threshold_enabled = external_retrieval_parameters.get("score_threshold_enabled") or False + score_threshold = external_retrieval_parameters.get("score_threshold", 0.0) if score_threshold_enabled else 0.0 + request_params = { + "retrieval_setting": { + "top_k": external_retrieval_parameters.get("top_k"), + "score_threshold": score_threshold, + }, + "query": query, + "knowledge_id": external_knowledge_binding.external_knowledge_id, + } + + external_knowledge_api_setting = { + "url": f"{settings.get('endpoint')}/retrieval", + "request_method": "post", + "headers": headers, + "params": request_params, + } + response = ExternalDatasetService.process_external_api( + ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None + ) + if response.status_code == 200: + return response.json().get("records", []) + return [] diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 83e675a9d2e43a..c321393bc53f66 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -6,8 +6,8 @@ class SubscriptionModel(BaseModel): - plan: str = 'sandbox' - interval: str = '' + plan: str = "sandbox" + interval: str = "" class BillingModel(BaseModel): @@ -27,7 +27,7 @@ class FeatureModel(BaseModel): vector_space: LimitationModel = LimitationModel(size=0, limit=5) annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) - docs_processing: str = 'standard' + docs_processing: str = "standard" can_replace_logo: bool = False model_load_balancing_enabled: bool = False dataset_operator_enabled: bool = False @@ -38,13 +38,18 @@ class FeatureModel(BaseModel): class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False - sso_enforced_for_signin_protocol: str = '' + sso_enforced_for_signin_protocol: str = "" sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = '' + sso_enforced_for_web_protocol: str = "" + enable_web_sso_switch_component: bool = False + enable_email_code_login: bool = False + 
enable_email_password_login: bool = True + enable_social_oauth_login: bool = False + is_allow_register: bool = False + is_allow_create_workspace: bool = False class FeatureService: - @classmethod def get_features(cls, tenant_id: str) -> FeatureModel: features = FeatureModel() @@ -60,11 +65,23 @@ def get_features(cls, tenant_id: str) -> FeatureModel: def get_system_features(cls) -> SystemFeatureModel: system_features = SystemFeatureModel() + cls._fulfill_system_params_from_env(system_features) + if dify_config.ENTERPRISE_ENABLED: + system_features.enable_web_sso_switch_component = True + cls._fulfill_params_from_enterprise(system_features) return system_features + @classmethod + def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel): + system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN + system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN + system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN + system_features.is_allow_register = dify_config.ALLOW_REGISTER + system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE + @classmethod def _fulfill_params_from_env(cls, features: FeatureModel): features.can_replace_logo = dify_config.CAN_REPLACE_LOGO @@ -75,44 +92,56 @@ def _fulfill_params_from_env(cls, features: FeatureModel): def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features.billing.enabled = billing_info['enabled'] - features.billing.subscription.plan = billing_info['subscription']['plan'] - features.billing.subscription.interval = billing_info['subscription']['interval'] + features.billing.enabled = billing_info["enabled"] + features.billing.subscription.plan = billing_info["subscription"]["plan"] + features.billing.subscription.interval = billing_info["subscription"]["interval"] - if 'members' in billing_info: - features.members.size = billing_info['members']['size'] - features.members.limit = billing_info['members']['limit'] + if "members" in billing_info: + features.members.size = billing_info["members"]["size"] + features.members.limit = billing_info["members"]["limit"] - if 'apps' in billing_info: - features.apps.size = billing_info['apps']['size'] - features.apps.limit = billing_info['apps']['limit'] + if "apps" in billing_info: + features.apps.size = billing_info["apps"]["size"] + features.apps.limit = billing_info["apps"]["limit"] - if 'vector_space' in billing_info: - features.vector_space.size = billing_info['vector_space']['size'] - features.vector_space.limit = billing_info['vector_space']['limit'] + if "vector_space" in billing_info: + features.vector_space.size = billing_info["vector_space"]["size"] + features.vector_space.limit = billing_info["vector_space"]["limit"] - if 'documents_upload_quota' in billing_info: - features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] - features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] + if "documents_upload_quota" in billing_info: + features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] + features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"] - if 'annotation_quota_limit' in billing_info: - features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] - features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] + if 
"annotation_quota_limit" in billing_info: + features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"] + features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"] - if 'docs_processing' in billing_info: - features.docs_processing = billing_info['docs_processing'] + if "docs_processing" in billing_info: + features.docs_processing = billing_info["docs_processing"] - if 'can_replace_logo' in billing_info: - features.can_replace_logo = billing_info['can_replace_logo'] + if "can_replace_logo" in billing_info: + features.can_replace_logo = billing_info["can_replace_logo"] - if 'model_load_balancing_enabled' in billing_info: - features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled'] + if "model_load_balancing_enabled" in billing_info: + features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] @classmethod def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info() - features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] - features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] - features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] - features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] + if "sso_enforced_for_signin" in enterprise_info: + features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + if "sso_enforced_for_signin_protocol" in enterprise_info: + features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + if "sso_enforced_for_web" in enterprise_info: + features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + if "sso_enforced_for_web_protocol" in enterprise_info: + features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] + if "enable_email_code_login" in enterprise_info: + features.enable_email_code_login = enterprise_info["enable_email_code_login"] + if "enable_email_password_login" in enterprise_info: + features.enable_email_password_login = enterprise_info["enable_email_password_login"] + if "is_allow_register" in enterprise_info: + features.is_allow_register = enterprise_info["is_allow_register"] + if "is_allow_create_workspace" in enterprise_info: + features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 9139962240991c..976111502c4ebf 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,64 +1,58 @@ import datetime import hashlib import uuid -from collections.abc import Generator -from typing import Union +from typing import Any, Literal, Union from flask_login import current_user -from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound from configs import dify_config -from core.file.upload_file_parser import UploadFileParser +from constants import ( + AUDIO_EXTENSIONS, + DOCUMENT_EXTENSIONS, + IMAGE_EXTENSIONS, + VIDEO_EXTENSIONS, +) +from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Account +from models.enums import CreatedByRole from models.model import EndUser, UploadFile -from services.errors.file import FileTooLargeError, UnsupportedFileTypeError -IMAGE_EXTENSIONS = ['jpg', 
'jpeg', 'png', 'webp', 'gif', 'svg'] -IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) - -ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv'] -UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', - 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub'] +from .errors.file import FileTooLargeError, UnsupportedFileTypeError PREVIEW_WORDS_LIMIT = 3000 class FileService: - @staticmethod - def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: - filename = file.filename - extension = file.filename.split('.')[-1] + def upload_file( + *, + filename: str, + content: bytes, + mimetype: str, + user: Union[Account, EndUser, Any], + source: Literal["datasets"] | None = None, + source_url: str = "", + ) -> UploadFile: + # get file extension + extension = filename.split(".")[-1].lower() if len(filename) > 200: - filename = filename.split('.')[0][:200] + '.' + extension - etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ - else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS - if extension.lower() not in allowed_extensions: - raise UnsupportedFileTypeError() - elif only_image and extension.lower() not in IMAGE_EXTENSIONS: - raise UnsupportedFileTypeError() + filename = filename.split(".")[0][:200] + "." + extension - # read file content - file_content = file.read() + if source == "datasets" and extension not in DOCUMENT_EXTENSIONS: + raise UnsupportedFileTypeError() # get file size - file_size = len(file_content) - - if extension.lower() in IMAGE_EXTENSIONS: - file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 - else: - file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 + file_size = len(content) - if file_size > file_size_limit: - message = f'File size exceeded. {file_size} > {file_size_limit}' - raise FileTooLargeError(message) + # check if the file size is exceeded + if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size): + raise FileTooLargeError - # user uuid as file name + # generate file key file_uuid = str(uuid.uuid4()) if isinstance(user, Account): @@ -67,10 +61,10 @@ def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bo # end_user current_tenant_id = user.tenant_id - file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension + file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." 
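Two details of the rewritten `upload_file` are easy to miss: the extension is lower-cased once up front, and over-long names keep only the first dot-delimited stem. A standalone sketch of that handling (helper name is mine):

```python
def normalize_filename(filename: str) -> tuple[str, str]:
    """Mirror the filename handling in FileService.upload_file."""
    extension = filename.split(".")[-1].lower()
    if len(filename) > 200:
        # Only the text before the first "." survives truncation, so an
        # over-long "a.b.c.txt" becomes "a.txt".
        filename = filename.split(".")[0][:200] + "." + extension
    return filename, extension
```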
+ extension # save file to storage - storage.save(file_key, file_content) + storage.save(file_key, content) # save file to db upload_file = UploadFile( @@ -80,12 +74,13 @@ def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bo name=filename, size=file_size, extension=extension, - mime_type=file.mimetype, - created_by_role=('account' if isinstance(user, Account) else 'end_user'), + mime_type=mimetype, + created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), created_by=user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=False, - hash=hashlib.sha3_256(file_content).hexdigest() + hash=hashlib.sha3_256(content).hexdigest(), + source_url=source_url, ) db.session.add(upload_file) @@ -93,16 +88,29 @@ def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bo return upload_file + @staticmethod + def is_file_size_within_limit(*, extension: str, file_size: int) -> bool: + if extension in IMAGE_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in VIDEO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in AUDIO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 + else: + file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 + + return file_size <= file_size_limit + @staticmethod def upload_text(text: str, text_name: str) -> UploadFile: if len(text_name) > 200: text_name = text_name[:200] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt' + file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt" # save file to storage - storage.save(file_key, text.encode('utf-8')) + storage.save(file_key, text.encode("utf-8")) # save file to db upload_file = UploadFile( @@ -111,13 +119,14 @@ def upload_text(text: str, text_name: str) -> UploadFile: key=file_key, name=text_name, size=len(text), - extension='txt', - mime_type='text/plain', + extension="txt", + mime_type="text/plain", created_by=current_user.id, + created_by_role=CreatedByRole.ACCOUNT, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) @@ -126,35 +135,31 @@ def upload_text(text: str, text_name: str) -> UploadFile: return upload_file @staticmethod - def get_file_preview(file_id: str) -> str: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + def get_file_preview(file_id: str): + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found") # extract text from file extension = upload_file.extension - etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS - if extension.lower() not in allowed_extensions: + if extension.lower() not in DOCUMENT_EXTENSIONS: raise UnsupportedFileTypeError() text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) - text = text[0:PREVIEW_WORDS_LIMIT] if text else '' + text = text[0:PREVIEW_WORDS_LIMIT] if text else "" return text @staticmethod - def 
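The extracted `is_file_size_within_limit` picks a per-type ceiling before comparing. The sketch below hard-codes example megabyte limits where the service reads them from `dify_config`, and uses tiny stand-in extension sets in place of the constants imported in the diff:

```python
IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "webp", "gif", "svg"}  # stand-in sets; the real
VIDEO_EXTENSIONS = {"mp4", "mov", "webm"}                        # ones are imported from
AUDIO_EXTENSIONS = {"mp3", "wav", "m4a"}                         # the constants module

LIMITS_MB = {"image": 10, "video": 100, "audio": 50, "other": 15}  # example values only


def is_file_size_within_limit(extension: str, file_size: int) -> bool:
    if extension in IMAGE_EXTENSIONS:
        limit_mb = LIMITS_MB["image"]
    elif extension in VIDEO_EXTENSIONS:
        limit_mb = LIMITS_MB["video"]
    elif extension in AUDIO_EXTENSIONS:
        limit_mb = LIMITS_MB["audio"]
    else:
        limit_mb = LIMITS_MB["other"]
    return file_size <= limit_mb * 1024 * 1024
```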
get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> tuple[Generator, str]: - result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign) + def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str): + result = file_helpers.verify_image_signature( + upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign + ) if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -169,10 +174,23 @@ def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> tu return generator, upload_file.mime_type @staticmethod - def get_public_image_preview(file_id: str) -> tuple[Generator, str]: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + def get_file_generator_by_file_id(file_id: str, timestamp: str, nonce: str, sign: str): + result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign) + if not result: + raise NotFound("File not found or signature is invalid") + + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + + if not upload_file: + raise NotFound("File not found or signature is invalid") + + generator = storage.load(upload_file.key, stream=True) + + return generator, upload_file + + @staticmethod + def get_public_image_preview(file_id: str): + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index de5f6994b0ebca..7957b4dc82dfd4 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,63 +3,66 @@ from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document -from core.rag.retrieval.retrival_methods import RetrievalMethod +from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } class HitTestingService: @classmethod - def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: + def retrieve( + cls, + dataset: Dataset, + query: str, + account: Account, + retrieval_model: dict, + external_retrieval_model: dict, + limit: int = 10, + ) -> dict: if dataset.available_document_count == 0 or dataset.available_segment_count == 0: return { "query": { "content": query, - "tsne_position": {'x': 0, 'y': 0}, + "tsne_position": {"x": 0, "y": 0}, }, - "records": [] + "records": [], } start = time.perf_counter() # get retrieval model , if the model is not setting , using 
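The preview endpoints now defer to `file_helpers.verify_image_signature` and `verify_file_signature`. The diff does not show those helpers, so the following is only a generic sketch of how timestamp/nonce/sign triples are typically validated; the HMAC scheme, message layout, secret, and expiry window are all assumptions:

```python
import base64
import hashlib
import hmac
import time

SECRET_KEY = b"app-secret-key"  # assumption: the app's configured secret
EXPIRY_SECONDS = 300            # assumption: five-minute link validity


def verify_signature(upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    message = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    expected = base64.urlsafe_b64encode(
        hmac.new(SECRET_KEY, message.encode(), hashlib.sha256).digest()
    ).decode()
    # Constant-time comparison first, then a freshness check on the timestamp.
    if not hmac.compare_digest(expected, sign):
        return False
    return int(time.time()) - int(timestamp) <= EXPIRY_SECONDS
```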
default if not retrieval_model: - retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - - all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=cls.escape_query_for_search(query), - top_k=retrieval_model.get('top_k', 2), - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + retrieval_model = dataset.retrieval_model or default_retrieval_model + + all_documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=cls.escape_query_for_search(query), + top_k=retrieval_model.get("top_k", 2), + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source='hit_testing', - created_by_role='account', - created_by=account.id + dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id ) db.session.add(dataset_query) @@ -67,46 +70,100 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_mode return cls.compact_retrieve_response(dataset, query, all_documents) + @classmethod + def external_retrieve( + cls, + dataset: Dataset, + query: str, + account: Account, + external_retrieval_model: dict, + ) -> dict: + if dataset.provider != "external": + return { + "query": {"content": query}, + "records": [], + } + + start = time.perf_counter() + + all_documents = RetrievalService.external_retrieve( + dataset_id=dataset.id, + query=cls.escape_query_for_search(query), + external_retrieval_model=external_retrieval_model, + ) + + end = time.perf_counter() + logging.debug(f"External knowledge hit testing retrieve in {end - start:0.4f} seconds") + + dataset_query = DatasetQuery( + dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id + ) + + db.session.add(dataset_query) + db.session.commit() + + return cls.compact_external_retrieve_response(dataset, query, all_documents) + @classmethod def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): - i = 0 records = [] - for document in documents: - index_node_id = document.metadata['doc_id'] - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.enabled == True, - DocumentSegment.status == 'completed', - DocumentSegment.index_node_id == index_node_id - ).first() + for document in documents: + index_node_id = document.metadata["doc_id"] + + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.enabled == True, + DocumentSegment.status == 
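The reworked `RetrievalService.retrieve` call above normalizes its inputs: the dataset's stored model (or the module-level default) wins when none is passed, the score threshold collapses to 0.0 unless explicitly enabled, and the reranking mode falls back to "reranking_model". Expressed as a pure function (the function name is mine):

```python
def resolve_retrieval_params(retrieval_model: dict | None, default_model: dict) -> dict:
    model = retrieval_model or default_model
    return {
        "retrieval_method": model.get("search_method", "semantic_search"),
        "top_k": model.get("top_k", 2),
        "score_threshold": (
            model.get("score_threshold", 0.0) if model["score_threshold_enabled"] else 0.0
        ),
        "reranking_model": model.get("reranking_model") if model["reranking_enable"] else None,
        "reranking_mode": model.get("reranking_mode") or "reranking_model",
        "weights": model.get("weights"),
    }
```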
"completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) if not segment: - i += 1 continue record = { "segment": segment, - "score": document.metadata.get('score', None), + "score": document.metadata.get("score", None), } records.append(record) - i += 1 - return { "query": { "content": query, }, - "records": records + "records": records, } + @classmethod + def compact_external_retrieve_response(cls, dataset: Dataset, query: str, documents: list): + records = [] + if dataset.provider == "external": + for document in documents: + record = { + "content": document.get("content", None), + "title": document.get("title", None), + "score": document.get("score", None), + "metadata": document.get("metadata", None), + } + records.append(record) + return { + "query": { + "content": query, + }, + "records": records, + } + @classmethod def hit_testing_args_check(cls, args): - query = args['query'] + query = args["query"] if not query or len(query) > 250: - raise ValueError('Query is required and cannot exceed 250 characters') + raise ValueError("Query is required and cannot exceed 250 characters") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py new file mode 100644 index 00000000000000..02fe1d19bc42be --- /dev/null +++ b/api/services/knowledge_service.py @@ -0,0 +1,45 @@ +import boto3 + +from configs import dify_config + + +class ExternalDatasetTestService: + # this service is only for internal testing + @staticmethod + def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): + # get bedrock client + client = boto3.client( + "bedrock-agent-runtime", + aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY, + aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID, + # example: us-east-1 + region_name="us-east-1", + ) + # fetch external knowledge retrieval + response = client.retrieve( + knowledgeBaseId=knowledge_id, + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": retrieval_setting.get("top_k"), + "overrideSearchType": "HYBRID", + } + }, + retrievalQuery={"text": query}, + ) + # parse response + results = [] + if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200: + if response.get("retrievalResults"): + retrieval_results = response.get("retrievalResults") + for retrieval_result in retrieval_results: + # filter out results with score less than threshold + if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + continue + result = { + "metadata": retrieval_result.get("metadata"), + "score": retrieval_result.get("score"), + "title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"), + "content": retrieval_result.get("content").get("text"), + } + results.append(result) + return {"records": results} diff --git a/api/services/message_service.py b/api/services/message_service.py index 491a914c776387..f432a77c80e511 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -27,8 +27,15 @@ class MessageService: @classmethod - def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination: + def pagination_by_first_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + conversation_id: str, + first_id: Optional[str], + limit: int, + order: str = "asc", + ) -> InfiniteScrollPagination: if not 
user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -36,52 +43,70 @@ def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, En return InfiniteScrollPagination(data=[], limit=limit, has_more=False) conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) if first_id: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == first_id).first() + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == first_id) + .first() + ) if not first_message: raise FirstMessageNotExistsError() - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) has_more = False if len(history_messages) == limit: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True - history_messages = list(reversed(history_messages)) + if order == "asc": + history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, conversation_id: Optional[str] = None, - include_ids: Optional[list] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + conversation_id: Optional[str] = None, + include_ids: Optional[list] = None, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -89,9 +114,7 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if conversation_id is not None: conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) base_query = base_query.filter(Message.conversation_id == conversation.id) @@ -105,10 +128,12 @@ def 
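The pagination contract that `pagination_by_first_id` implements (and that `pagination_by_last_id` shares, minus the reversal) is: fetch newest-first, detect `has_more` by counting rows strictly older than the page's last message, and, for the new `order="asc"` parameter, reverse into chronological order only on request. Schematically, under the assumption that the query callbacks are supplied:

```python
from typing import Callable


def paginate(
    fetch_page: Callable[[int], list],
    count_older: Callable[[object], int],
    limit: int,
    order: str = "asc",
) -> tuple[list, bool]:
    page = fetch_page(limit)  # newest-first, at most `limit` rows
    has_more = len(page) == limit and count_older(page[-1]) > 0
    if order == "asc":
        page = list(reversed(page))  # present oldest-first
    return page, has_more
```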
pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if not last_message: raise LastMessageNotExistsError() - history_messages = base_query.filter( - Message.created_at < last_message.created_at, - Message.id != last_message.id - ).order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all() @@ -116,30 +141,22 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End if len(history_messages) == limit: current_page_first_message = history_messages[-1] rest_count = base_query.filter( - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id + Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], - rating: Optional[str]) -> MessageFeedback: + def create_feedback( + cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] + ) -> MessageFeedback: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback @@ -148,14 +165,14 @@ def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[A elif rating and feedback: feedback.rating = rating elif not rating and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, rating=rating, - from_source=('user' if isinstance(user, EndUser) else 'admin'), + from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), ) @@ -167,13 +184,17 @@ def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[A @classmethod def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if 
isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -181,27 +202,22 @@ def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], me return message @classmethod - def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]], - message_id: str, invoke_from: InvokeFrom) -> list[Message]: + def get_suggested_questions_after_answer( + cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom + ) -> list[Message]: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=message.conversation_id, - user=user + app_model=app_model, conversation_id=message.conversation_id, user=user ) if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() model_manager = ModelManager() @@ -216,24 +232,23 @@ def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Uni if workflow is None: return [] - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) if not app_config.additional_features.suggested_questions_after_answer: raise SuggestedQuestionsAfterAnswerDisabledError() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.LLM + tenant_id=app_model.tenant_id, model_type=ModelType.LLM ) else: if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter( + AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id + ) + .first() + ) else: conversation_override_model_configs = json.loads(conversation.override_model_configs) app_model_config = AppModelConfig( @@ -249,16 +264,13 @@ def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Uni model_instance = model_manager.get_model_instance( tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict['provider'], + provider=app_model_config.model_dict["provider"], model_type=ModelType.LLM, - model=app_model_config.model_dict['name'] + model=app_model_config.model_dict["name"], ) # get memory of conversation (read-only) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) histories = memory.get_history_prompt_text( max_token_limit=3000, @@ -267,18 +279,14 @@ def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Uni with measure_time() as timer: questions = LLMGenerator.generate_suggested_questions_after_answer( - tenant_id=app_model.tenant_id, - histories=histories + tenant_id=app_model.tenant_id, histories=histories ) # get tracing instance trace_manager = TraceQueueManager(app_id=app_model.id) trace_manager.add_trace_task( TraceTask( - 
TraceTaskName.SUGGESTED_QUESTION_TRACE, - message_id=message_id, - suggested_question=questions, - timer=timer + TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer ) ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 80eb72140d19b5..e7b9422cfe1e08 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -23,7 +23,6 @@ class ModelLoadBalancingService: - def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -46,10 +45,7 @@ def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, raise ValueError(f"Provider {provider} does not exist.") # Enable model load balancing - provider_configuration.enable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -70,13 +66,11 @@ def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str raise ValueError(f"Provider {provider} does not exist.") # disable model load balancing - provider_configuration.disable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) - def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \ - -> tuple[bool, list[dict]]: + def get_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str + ) -> tuple[bool, list[dict]]: """ Get load balancing configurations. 
:param tenant_id: workspace id @@ -107,20 +101,24 @@ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, is_load_balancing_enabled = True # Get load balancing configurations - load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).order_by(LoadBalancingModelConfig.created_at).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .order_by(LoadBalancingModelConfig.created_at) + .all() + ) if provider_configuration.custom_configuration.provider: # check if the inherit configuration exists, # inherit is represented for the provider or model custom credentials inherit_config_exists = False for load_balancing_config in load_balancing_configs: - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config_exists = True break @@ -133,7 +131,7 @@ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, else: # move the inherit configuration to the first for i, load_balancing_config in enumerate(load_balancing_configs[:]): - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config = load_balancing_configs.pop(i) load_balancing_configs.insert(0, inherit_config) @@ -151,7 +149,7 @@ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, provider=provider, model=model, model_type=model_type, - config_id=load_balancing_config.id + config_id=load_balancing_config.id, ) try: @@ -172,32 +170,32 @@ def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, if variable in credentials: try: credentials[variable] = encrypter.decrypt_token_with_decoding( - credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa + credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa ) except ValueError: pass # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) - datas.append({ - 'id': load_balancing_config.id, - 'name': load_balancing_config.name, - 'credentials': credentials, - 'enabled': load_balancing_config.enabled, - 'in_cooldown': in_cooldown, - 'ttl': ttl - }) + datas.append( + { + "id": load_balancing_config.id, + "name": load_balancing_config.name, + "credentials": credentials, + "enabled": load_balancing_config.enabled, + "in_cooldown": in_cooldown, + "ttl": ttl, + } + ) return is_load_balancing_enabled, datas - def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \ - -> Optional[dict]: + def get_load_balancing_config( + self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str + ) -> Optional[dict]: """ Get load balancing configuration. 
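Within `get_load_balancing_configs`, the `__inherit__` entry represents the provider- or model-level custom credentials: it is created on first access when the tenant has custom credentials but no inherit row yet, and it must always lead the returned list. The reordering step in isolation (helper name is mine):

```python
def move_inherit_first(configs: list) -> list:
    """Place the '__inherit__' config at index 0, preserving the order of the rest."""
    for i, cfg in enumerate(configs):
        if cfg.name == "__inherit__":
            configs.insert(0, configs.pop(i))
            break
    return configs
```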
:param tenant_id: workspace id @@ -219,14 +217,17 @@ def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, m model_type = ModelType.value_of(model_type) # Get load balancing configurations - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: return None @@ -244,19 +245,19 @@ def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, m # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) return { - 'id': load_balancing_model_config.id, - 'name': load_balancing_model_config.name, - 'credentials': credentials, - 'enabled': load_balancing_model_config.enabled + "id": load_balancing_model_config.id, + "name": load_balancing_model_config.name, + "credentials": credentials, + "enabled": load_balancing_model_config.enabled, } - def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \ - -> LoadBalancingModelConfig: + def _init_inherit_config( + self, tenant_id: str, provider: str, model: str, model_type: ModelType + ) -> LoadBalancingModelConfig: """ Initialize the inherit configuration. :param tenant_id: workspace id @@ -271,18 +272,16 @@ def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_ provider_name=provider, model_type=model_type.to_origin_model_type(), model_name=model, - name='__inherit__' + name="__inherit__", ) db.session.add(inherit_config) db.session.commit() return inherit_config - def update_load_balancing_configs(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - configs: list[dict]) -> None: + def update_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] + ) -> None: """ Update load balancing configurations. 
:param tenant_id: workspace id @@ -304,15 +303,18 @@ def update_load_balancing_configs(self, tenant_id: str, model_type = ModelType.value_of(model_type) if not isinstance(configs, list): - raise ValueError('Invalid load balancing configs') + raise ValueError("Invalid load balancing configs") - current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + current_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .all() + ) # id as key, config as value current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} @@ -320,25 +322,25 @@ def update_load_balancing_configs(self, tenant_id: str, for config in configs: if not isinstance(config, dict): - raise ValueError('Invalid load balancing config') + raise ValueError("Invalid load balancing config") - config_id = config.get('id') - name = config.get('name') - credentials = config.get('credentials') - enabled = config.get('enabled') + config_id = config.get("id") + name = config.get("name") + credentials = config.get("credentials") + enabled = config.get("enabled") if not name: - raise ValueError('Invalid load balancing config name') + raise ValueError("Invalid load balancing config name") if enabled is None: - raise ValueError('Invalid load balancing config enabled') + raise ValueError("Invalid load balancing config enabled") # is config exists if config_id: config_id = str(config_id) if config_id not in current_load_balancing_configs_dict: - raise ValueError('Invalid load balancing config id: {}'.format(config_id)) + raise ValueError("Invalid load balancing config id: {}".format(config_id)) updated_config_ids.add(config_id) @@ -347,11 +349,11 @@ def update_load_balancing_configs(self, tenant_id: str, # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if credentials: if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -361,7 +363,7 @@ def update_load_balancing_configs(self, tenant_id: str, model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, - validate=False + validate=False, ) # update load balancing config @@ -375,19 +377,19 @@ def update_load_balancing_configs(self, tenant_id: str, self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name == '__inherit__': - raise ValueError('Invalid load balancing config name') + if name == "__inherit__": + raise ValueError("Invalid load balancing config name") # check duplicate name for current_load_balancing_config in 
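`update_load_balancing_configs` reconciles the submitted list against what is stored: an entry carrying a known `id` updates that row (after a duplicate-name check), an entry without one is created, and `__inherit__` is reserved as a name for new entries. The branching in miniature, as a hypothetical helper that only performs the validation and split:

```python
def split_updates_and_creates(existing_ids: set, incoming: list[dict]) -> tuple[list, list]:
    updates, creates = [], []
    for cfg in incoming:
        if not cfg.get("name"):
            raise ValueError("Invalid load balancing config name")
        if cfg.get("enabled") is None:
            raise ValueError("Invalid load balancing config enabled")
        cfg_id = str(cfg["id"]) if cfg.get("id") else None
        if cfg_id:
            if cfg_id not in existing_ids:
                raise ValueError("Invalid load balancing config id: {}".format(cfg_id))
            updates.append(cfg)
        elif cfg["name"] == "__inherit__":
            raise ValueError("Invalid load balancing config name")
        else:
            creates.append(cfg)
    return updates, creates
```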
current_load_balancing_configs: if current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if not credentials: - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -396,7 +398,7 @@ def update_load_balancing_configs(self, tenant_id: str, model_type=model_type, model=model, credentials=credentials, - validate=False + validate=False, ) # create load balancing config @@ -406,7 +408,7 @@ def update_load_balancing_configs(self, tenant_id: str, model_type=model_type.to_origin_model_type(), model_name=model, name=name, - encrypted_config=json.dumps(credentials) + encrypted_config=json.dumps(credentials), ) db.session.add(load_balancing_model_config) @@ -420,12 +422,15 @@ def update_load_balancing_configs(self, tenant_id: str, self._clear_credentials_cache(tenant_id, config_id) - def validate_load_balancing_credentials(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - credentials: dict, - config_id: Optional[str] = None) -> None: + def validate_load_balancing_credentials( + self, + tenant_id: str, + provider: str, + model: str, + model_type: str, + credentials: dict, + config_id: Optional[str] = None, + ) -> None: """ Validate load balancing credentials. :param tenant_id: workspace id @@ -450,14 +455,17 @@ def validate_load_balancing_credentials(self, tenant_id: str, load_balancing_model_config = None if config_id: # Get load balancing config - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: raise ValueError(f"Load balancing config {config_id} does not exist.") @@ -469,16 +477,19 @@ def validate_load_balancing_credentials(self, tenant_id: str, model_type=model_type, model=model, credentials=credentials, - load_balancing_model_config=load_balancing_model_config + load_balancing_model_config=load_balancing_model_config, ) - def _custom_credentials_validate(self, tenant_id: str, - provider_configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict, - load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, - validate: bool = True) -> dict: + def _custom_credentials_validate( + self, + tenant_id: str, + provider_configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict, + load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + validate: bool = True, + ) -> dict: """ Validate custom 
credentials. :param tenant_id: workspace id @@ -521,12 +532,11 @@ def _custom_credentials_validate(self, tenant_id: str, provider=provider_configuration.provider.provider, model_type=model_type, model=model, - credentials=credentials + credentials=credentials, ) else: credentials = model_provider_factory.provider_credentials_validate( - provider=provider_configuration.provider.provider, - credentials=credentials + provider=provider_configuration.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -535,8 +545,9 @@ def _custom_credentials_validate(self, tenant_id: str, return credentials - def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \ - -> ModelCredentialSchema | ProviderCredentialSchema: + def _get_credential_schema( + self, provider_configuration: ProviderConfiguration + ) -> ModelCredentialSchema | ProviderCredentialSchema: """ Get form schemas. :param provider_configuration: provider configuration @@ -558,9 +569,7 @@ def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None: :return: """ provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=config_id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL ) provider_model_credentials_cache.delete() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 385af685f98f44..384a072b371fdd 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,6 +1,7 @@ import logging import mimetypes import os +from pathlib import Path from typing import Optional, cast import requests @@ -30,6 +31,7 @@ class ModelProviderService: """ Model Provider Service """ + def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -72,8 +74,8 @@ def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, current_quota_type=provider_configuration.system_configuration.current_quota_type, - quota_configurations=provider_configuration.system_configuration.quota_configurations - ) + quota_configurations=provider_configuration.system_configuration.quota_configurations, + ), ) provider_responses.append(provider_response) @@ -94,9 +96,9 @@ def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWit provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models( - provider=provider - )] + return [ + ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) + ] def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: """ @@ -194,13 +196,12 @@ def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, # Get model custom credentials from ProviderModel if exists return provider_configuration.get_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - obfuscated=True + model_type=ModelType.value_of(model_type), model=model, obfuscated=True ) - def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def model_credentials_validate( + self, tenant_id: str, 
provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ validate model credentials. @@ -221,13 +222,12 @@ def model_credentials_validate(self, tenant_id: str, provider: str, model_type: # Validate model credentials provider_configuration.custom_model_credentials_validate( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) - def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def save_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ save model credentials. @@ -248,9 +248,7 @@ def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, # Add or update custom model credentials provider_configuration.add_or_update_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: @@ -272,10 +270,7 @@ def remove_model_credentials(self, tenant_id: str, provider: str, model_type: st raise ValueError(f"Provider {provider} does not exist.") # Remove custom model credentials - provider_configuration.delete_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model - ) + provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ @@ -289,9 +284,7 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - models = provider_configurations.get_models( - model_type=ModelType.value_of(model_type) - ) + models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider provider_models = {} @@ -322,16 +315,19 @@ def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[Prov icon_small=first_model.provider.icon_small, icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, - models=[ProviderModelWithStatusEntity( - model=model.model, - label=model.label, - model_type=model.model_type, - features=model.features, - fetch_from=model.fetch_from, - model_properties=model.model_properties, - status=model.status, - load_balancing_enabled=model.load_balancing_enabled - ) for model in models] + models=[ + ProviderModelWithStatusEntity( + model=model.model, + label=model.label, + model_type=model.model_type, + features=model.features, + fetch_from=model.fetch_from, + model_properties=model.model_properties, + status=model.status, + load_balancing_enabled=model.load_balancing_enabled, + ) + for model in models + ], ) ) @@ -360,19 +356,13 @@ def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) - model_type_instance = cast(LargeLanguageModel, model_type_instance) # fetch credentials - credentials = provider_configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model - ) + credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) if not credentials: return [] # Call get_parameter_rules method 
of model instance to get model parameter rules - return model_type_instance.get_parameter_rules( - model=model, - credentials=credentials - ) + return model_type_instance.get_parameter_rules(model=model, credentials=credentials) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -383,22 +373,26 @@ def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Op :return: """ model_type_enum = ModelType.value_of(model_type) - result = self.provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type_enum - ) - - return DefaultModelResponse( - model=result.model, - model_type=result.model_type, - provider=SimpleProviderEntityResponse( - provider=result.provider.provider, - label=result.provider.label, - icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, - supported_model_types=result.provider.supported_model_types + result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + try: + return ( + DefaultModelResponse( + model=result.model, + model_type=result.model_type, + provider=SimpleProviderEntityResponse( + provider=result.provider.provider, + label=result.provider.label, + icon_small=result.provider.icon_small, + icon_large=result.provider.icon_large, + supported_model_types=result.provider.supported_model_types, + ), + ) + if result + else None ) - ) if result else None + except Exception as e: + logger.info(f"get_default_model_of_model_type error: {e}") + return None def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: """ @@ -412,13 +406,12 @@ def update_default_model_of_model_type(self, tenant_id: str, model_type: str, pr """ model_type_enum = ModelType.value_of(model_type) self.provider_manager.update_default_model_record( - tenant_id=tenant_id, - model_type=model_type_enum, - provider=provider, - model=model + tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) - def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]: + def get_model_provider_icon( + self, provider: str, icon_type: str, lang: str + ) -> tuple[Optional[bytes], Optional[str]]: """ get model provider icon. 
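Worth noting in `get_default_model_of_model_type`: the construction of the response is now wrapped in a try/except, so malformed provider data degrades to `None` (logged at info level) rather than failing the request; the lookup itself happens before the guard. The shape of that guard, generalized over the conversion step:

```python
import logging
from typing import Callable, Optional

logger = logging.getLogger(__name__)


def guarded_response(result, to_response: Callable) -> Optional[object]:
    """Mirror of the new guard: only the response construction is protected,
    and a failure there yields None instead of a 500."""
    try:
        return to_response(result) if result else None
    except Exception as e:
        logger.info(f"get_default_model_of_model_type error: {e}")
        return None
```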
@@ -430,11 +423,11 @@ def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> t provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() - if icon_type.lower() == 'icon_small': + if icon_type.lower() == "icon_small": if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_small.zh_Hans else: file_name = provider_schema.icon_small.en_US @@ -442,13 +435,15 @@ def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> t if not provider_schema.icon_large: raise ValueError(f"Provider {provider} does not have large icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US root_path = current_app.root_path - provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/'))) + provider_instance_path = os.path.dirname( + os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/")) + ) file_path = os.path.join(provider_instance_path, "_assets") file_path = os.path.join(file_path, file_name) @@ -456,12 +451,11 @@ def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> t return None, None mimetype, _ = mimetypes.guess_type(file_path) - mimetype = mimetype or 'application/octet-stream' + mimetype = mimetype or "application/octet-stream" # read binary from file - with open(file_path, 'rb') as f: - byte_data = f.read() - return byte_data, mimetype + byte_data = Path(file_path).read_bytes() + return byte_data, mimetype def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None: """ @@ -505,10 +499,7 @@ def enable_model(self, tenant_id: str, provider: str, model: str, model_type: st raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.enable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -529,78 +520,49 @@ def disable_model(self, tenant_id: str, provider: str, model: str, model_type: s raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.disable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/apply' + api_url = api_base_url + "/api/v1/providers/apply" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider}) + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider}) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise 
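`get_model_provider_icon` now resolves the provider package directory and reads the asset with `Path.read_bytes()` instead of an explicit `open`. Condensed into a standalone function (name and parameterization are mine), the resolution looks like:

```python
import mimetypes
import os
from pathlib import Path
from typing import Optional


def load_provider_asset(
    root_path: str, provider_module: str, file_name: str
) -> tuple[Optional[bytes], Optional[str]]:
    # The provider's directory is derived from its module path, e.g. "a.b.c" -> "a/b".
    provider_dir = os.path.dirname(os.path.join(root_path, provider_module.replace(".", "/")))
    file_path = Path(provider_dir) / "_assets" / file_name
    if not file_path.exists():
        return None, None
    mimetype, _ = mimetypes.guess_type(str(file_path))
    return file_path.read_bytes(), mimetype or "application/octet-stream"
```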
ValueError(f"Error: {response.status_code} ") - if response.json()["code"] != 'success': - raise ValueError( - f"error: {response.json()['message']}" - ) + if response.json()["code"] != "success": + raise ValueError(f"error: {response.json()['message']}") rst = response.json() - if rst['type'] == 'redirect': - return { - 'type': rst['type'], - 'redirect_url': rst['redirect_url'] - } + if rst["type"] == "redirect": + return {"type": rst["type"], "redirect_url": rst["redirect_url"]} else: - return { - 'type': rst['type'], - 'result': 'success' - } + return {"type": rst["type"], "result": "success"} def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/qualification-verify' + api_url = api_base_url + "/api/v1/providers/qualification-verify" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - json_data = {'workspace_id': tenant_id, 'provider_name': provider} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + json_data = {"workspace_id": tenant_id, "provider_name": provider} if token: - json_data['token'] = token - response = requests.post(api_url, headers=headers, - json=json_data) + json_data["token"] = token + response = requests.post(api_url, headers=headers, json=json_data) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") rst = response.json() - if rst["code"] != 'success': - raise ValueError( - f"error: {rst['message']}" - ) + if rst["code"] != "success": + raise ValueError(f"error: {rst['message']}") - data = rst['data'] - if data['qualified'] is True: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': True - } + data = rst["data"] + if data["qualified"] is True: + return {"result": "success", "provider_name": provider, "flag": True} else: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': False, - 'reason': data['reason'] - } + return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index d472f8cfbca435..dfb21e767fc9b9 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -4,17 +4,18 @@ class ModerationService: - def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: app_model_config: AppModelConfig = None - app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) if not app_model_config: raise ValueError("app model config not found") - name = app_model_config.sensitive_word_avoidance_dict['type'] - config = app_model_config.sensitive_word_avoidance_dict['config'] + name = app_model_config.sensitive_word_avoidance_dict["type"] + config = app_model_config.sensitive_word_avoidance_dict["config"] moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) return moderation.moderation_for_outputs(text) diff --git a/api/services/operation_service.py b/api/services/operation_service.py index 39f249dc24eb09..8c8b64bcd5d344 100644 --- 
a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -4,15 +4,12 @@ class OperationService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -22,11 +19,11 @@ def _send_request(cls, method, endpoint, json=None, params=None): @classmethod def record_utm(cls, tenant_id: str, utm_info: dict): params = { - 'tenant_id': tenant_id, - 'utm_source': utm_info.get('utm_source', ''), - 'utm_medium': utm_info.get('utm_medium', ''), - 'utm_campaign': utm_info.get('utm_campaign', ''), - 'utm_content': utm_info.get('utm_content', ''), - 'utm_term': utm_info.get('utm_term', '') + "tenant_id": tenant_id, + "utm_source": utm_info.get("utm_source", ""), + "utm_medium": utm_info.get("utm_medium", ""), + "utm_campaign": utm_info.get("utm_campaign", ""), + "utm_content": utm_info.get("utm_content", ""), + "utm_term": utm_info.get("utm_term", ""), } - return cls._send_request('POST', '/tenant_utms', params=params) + return cls._send_request("POST", "/tenant_utms", params=params) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index ffc12a9acdb42c..1160a1f2751d74 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -12,20 +12,49 @@ def get_tracing_app_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None # decrypt_token and obfuscated_token tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id - decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config) - decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) - - trace_config_data.tracing_config = decrypt_tracing_config - + decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( + tenant_id, tracing_provider, trace_config_data.tracing_config + ) + new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) + + if tracing_provider == "langfuse" and ( + "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") + ): + try: + project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update( + { + "project_url": "{host}/project/{key}".format( + host=decrypt_tracing_config.get("host"), key=project_key + ) + } + ) + except Exception: + new_decrypt_tracing_config.update( + {"project_url": 
"{host}/".format(host=decrypt_tracing_config.get("host"))} + ) + + if tracing_provider == "langsmith" and ( + "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url") + ): + try: + project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider) + new_decrypt_tracing_config.update({"project_url": project_url}) + except Exception: + new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"}) + + trace_config_data.tracing_config = new_decrypt_tracing_config return trace_config_data.to_dict() @classmethod @@ -37,11 +66,13 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys() and tracing_provider != None: + if tracing_provider not in provider_config_map and tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} - config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['other_keys'] + config_class, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["other_keys"], + ) default_config_instance = config_class(**tracing_config) for key in other_keys: if key in tracing_config and tracing_config[key] == "": @@ -51,10 +82,21 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c if not OpsTraceManager.check_trace_config_is_effective(tracing_config, tracing_provider): return {"error": "Invalid Credentials"} + # get project url + if tracing_provider == "langfuse": + project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) + project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key) + elif tracing_provider == "langsmith": + project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) + else: + project_url = None + # check if trace config already exists - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if trace_config_data: return None @@ -62,6 +104,8 @@ def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c # get tenant id tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) + if project_url: + tracing_config["project_url"] = project_url trace_config_data = TraceAppConfig( app_id=app_id, tracing_provider=tracing_provider, @@ -81,13 +125,15 @@ def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_c :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map.keys(): + if tracing_provider not in provider_config_map: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists - current_trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + current_trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == 
app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not current_trace_config: return None @@ -117,9 +163,11 @@ def delete_tracing_app_config(cls, app_id: str, tracing_provider: str): :param tracing_provider: tracing provider :return: """ - trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config: return None diff --git a/api/services/recommend_app/__init__.py b/api/services/recommend_app/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/recommend_app/buildin/__init__.py b/api/services/recommend_app/buildin/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/recommend_app/buildin/buildin_retrieval.py b/api/services/recommend_app/buildin/buildin_retrieval.py new file mode 100644 index 00000000000000..4704d533a950ed --- /dev/null +++ b/api/services/recommend_app/buildin/buildin_retrieval.py @@ -0,0 +1,64 @@ +import json +from os import path +from pathlib import Path +from typing import Optional + +from flask import current_app + +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType + + +class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase): + """ + Retrieve recommended apps from the built-in data file at constants/recommended_apps.json + """ + + builtin_data: Optional[dict] = None + + def get_type(self) -> str: + return RecommendAppType.BUILDIN + + def get_recommended_apps_and_categories(self, language: str) -> dict: + result = self.fetch_recommended_apps_from_builtin(language) + return result + + def get_recommend_app_detail(self, app_id: str): + result = self.fetch_recommended_app_detail_from_builtin(app_id) + return result + + @classmethod + def _get_builtin_data(cls) -> dict: + """ + Get builtin data. + :return: + """ + if cls.builtin_data: + return cls.builtin_data + + root_path = current_app.root_path + cls.builtin_data = json.loads( + Path(path.join(root_path, "constants", "recommended_apps.json")).read_text(encoding="utf-8") + ) + + return cls.builtin_data + + @classmethod + def fetch_recommended_apps_from_builtin(cls, language: str) -> dict: + """ + Fetch recommended apps from builtin. + :param language: language + :return: + """ + builtin_data = cls._get_builtin_data() + return builtin_data.get("recommended_apps", {}).get(language) + + @classmethod + def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from builtin.
+ :param app_id: App ID + :return: + """ + builtin_data = cls._get_builtin_data() + return builtin_data.get("app_details", {}).get(app_id) diff --git a/api/services/recommend_app/database/__init__.py b/api/services/recommend_app/database/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py new file mode 100644 index 00000000000000..995d3755bb5b10 --- /dev/null +++ b/api/services/recommend_app/database/database_retrieval.py @@ -0,0 +1,111 @@ +from typing import Optional + +from constants.languages import languages +from extensions.ext_database import db +from models.model import App, RecommendedApp +from services.app_dsl_service import AppDslService +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType + + +class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): + """ + Retrieve recommended apps from the database + """ + + def get_recommended_apps_and_categories(self, language: str) -> dict: + result = self.fetch_recommended_apps_from_db(language) + return result + + def get_recommend_app_detail(self, app_id: str): + result = self.fetch_recommended_app_detail_from_db(app_id) + return result + + def get_type(self) -> str: + return RecommendAppType.DATABASE + + @classmethod + def fetch_recommended_apps_from_db(cls, language: str) -> dict: + """ + Fetch recommended apps from db. + :param language: language + :return: + """ + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .all() + ) + + if len(recommended_apps) == 0: + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .all() + ) + + categories = set() + recommended_apps_result = [] + for recommended_app in recommended_apps: + app = recommended_app.app + if not app or not app.is_public: + continue + + site = app.site + if not site: + continue + + recommended_app_result = { + "id": recommended_app.id, + "app": { + "id": app.id, + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background, + }, + "app_id": recommended_app.app_id, + "description": site.description, + "copyright": site.copyright, + "privacy_policy": site.privacy_policy, + "custom_disclaimer": site.custom_disclaimer, + "category": recommended_app.category, + "position": recommended_app.position, + "is_listed": recommended_app.is_listed, + } + recommended_apps_result.append(recommended_app_result) + + categories.add(recommended_app.category) + + return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + + @classmethod + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from db.
+ :param app_id: App ID + :return: + """ + # is in public recommended list + recommended_app = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .first() + ) + + if not recommended_app: + return None + + # get app detail + app_model = db.session.query(App).filter(App.id == app_id).first() + if not app_model or not app_model.is_public: + return None + + return { + "id": app_model.id, + "name": app_model.name, + "icon": app_model.icon, + "icon_background": app_model.icon_background, + "mode": app_model.mode, + "export_data": AppDslService.export_dsl(app_model=app_model), + } diff --git a/api/services/recommend_app/recommend_app_base.py b/api/services/recommend_app/recommend_app_base.py new file mode 100644 index 00000000000000..00c037710e869c --- /dev/null +++ b/api/services/recommend_app/recommend_app_base.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod + + +class RecommendAppRetrievalBase(ABC): + """Interface for recommend app retrieval.""" + + @abstractmethod + def get_recommended_apps_and_categories(self, language: str) -> dict: + raise NotImplementedError + + @abstractmethod + def get_recommend_app_detail(self, app_id: str): + raise NotImplementedError + + @abstractmethod + def get_type(self) -> str: + raise NotImplementedError diff --git a/api/services/recommend_app/recommend_app_factory.py b/api/services/recommend_app/recommend_app_factory.py new file mode 100644 index 00000000000000..e53667c0b06dd6 --- /dev/null +++ b/api/services/recommend_app/recommend_app_factory.py @@ -0,0 +1,23 @@ +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType +from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval + + +class RecommendAppRetrievalFactory: + @staticmethod + def get_recommend_app_factory(mode: str) -> type[RecommendAppRetrievalBase]: + match mode: + case RecommendAppType.REMOTE: + return RemoteRecommendAppRetrieval + case RecommendAppType.DATABASE: + return DatabaseRecommendAppRetrieval + case RecommendAppType.BUILDIN: + return BuildInRecommendAppRetrieval + case _: + raise ValueError(f"invalid fetch recommended apps mode: {mode}") + + @staticmethod + def get_buildin_recommend_app_retrieval(): + return BuildInRecommendAppRetrieval diff --git a/api/services/recommend_app/recommend_app_type.py b/api/services/recommend_app/recommend_app_type.py new file mode 100644 index 00000000000000..7ea93b3f64b1d4 --- /dev/null +++ b/api/services/recommend_app/recommend_app_type.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class RecommendAppType(str, Enum): + REMOTE = "remote" + BUILDIN = "builtin" + DATABASE = "db" diff --git a/api/services/recommend_app/remote/__init__.py b/api/services/recommend_app/remote/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/services/recommend_app/remote/remote_retrieval.py b/api/services/recommend_app/remote/remote_retrieval.py new file mode 100644 index 00000000000000..b0607a21323acb --- /dev/null +++ b/api/services/recommend_app/remote/remote_retrieval.py @@ -0,0 +1,71 @@ +import logging +from typing import Optional + +import requests + +from configs import dify_config +from services.recommend_app.buildin.buildin_retrieval import 
BuildInRecommendAppRetrieval +from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase +from services.recommend_app.recommend_app_type import RecommendAppType + +logger = logging.getLogger(__name__) + + +class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase): + """ + Retrieve recommended apps from the official Dify endpoint + """ + + def get_recommend_app_detail(self, app_id: str): + try: + result = self.fetch_recommended_app_detail_from_dify_official(app_id) + except Exception as e: + logger.warning(f"fetch recommended app detail from dify official failed: {e}, switching to built-in.") + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id) + return result + + def get_recommended_apps_and_categories(self, language: str) -> dict: + try: + result = self.fetch_recommended_apps_from_dify_official(language) + except Exception as e: + logger.warning(f"fetch recommended apps from dify official failed: {e}, switching to built-in.") + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin(language) + return result + + def get_type(self) -> str: + return RecommendAppType.REMOTE + + @classmethod + def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]: + """ + Fetch recommended app detail from dify official. + :param app_id: App ID + :return: + """ + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/apps/{app_id}" + response = requests.get(url, timeout=(3, 10)) + if response.status_code != 200: + return None + + return response.json() + + @classmethod + def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: + """ + Fetch recommended apps from dify official. + :param language: language + :return: + """ + domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN + url = f"{domain}/apps?language={language}" + response = requests.get(url, timeout=(3, 10)) + if response.status_code != 200: + raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") + + result = response.json() + + if "categories" in result: + result["categories"] = sorted(result["categories"]) + + return result diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 1c1c5be17c64b3..4660316fcfcf71 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,24 +1,10 @@ -import json -import logging -from os import path from typing import Optional -import requests -from flask import current_app - from configs import dify_config -from constants.languages import languages -from extensions.ext_database import db -from models.model import App, RecommendedApp -from services.app_dsl_service import AppDslService - -logger = logging.getLogger(__name__) +from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory class RecommendedAppService: - - builtin_data: Optional[dict] = None - @classmethod def get_recommended_apps_and_categories(cls, language: str) -> dict: """ @@ -27,107 +13,17 @@ def get_recommended_apps_and_categories(cls, language: str) -> dict: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': - try: - result = cls._fetch_recommended_apps_from_dify_official(language) - except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.') - result = cls._fetch_recommended_apps_from_builtin(language) - elif mode == 'db': - result = 
cls._fetch_recommended_apps_from_db(language) - elif mode == 'builtin': - result = cls._fetch_recommended_apps_from_builtin(language) - else: - raise ValueError(f'invalid fetch recommended apps mode: {mode}') - - if not result.get('recommended_apps') and language != 'en-US': - result = cls._fetch_recommended_apps_from_builtin('en-US') + retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() + result = retrieval_instance.get_recommended_apps_and_categories(language) + if not result.get("recommended_apps") and language != "en-US": + result = ( + RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval().fetch_recommended_apps_from_builtin( + "en-US" + ) + ) return result - @classmethod - def _fetch_recommended_apps_from_db(cls, language: str) -> dict: - """ - Fetch recommended apps from db. - :param language: language - :return: - """ - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == language - ).all() - - if len(recommended_apps) == 0: - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == languages[0] - ).all() - - categories = set() - recommended_apps_result = [] - for recommended_app in recommended_apps: - app = recommended_app.app - if not app or not app.is_public: - continue - - site = app.site - if not site: - continue - - recommended_app_result = { - 'id': recommended_app.id, - 'app': { - 'id': app.id, - 'name': app.name, - 'mode': app.mode, - 'icon': app.icon, - 'icon_background': app.icon_background - }, - 'app_id': recommended_app.app_id, - 'description': site.description, - 'copyright': site.copyright, - 'privacy_policy': site.privacy_policy, - 'custom_disclaimer': site.custom_disclaimer, - 'category': recommended_app.category, - 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed - } - recommended_apps_result.append(recommended_app_result) - - categories.add(recommended_app.category) # add category to categories - - return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} - - @classmethod - def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: - """ - Fetch recommended apps from dify official. - :param language: language - :return: - """ - domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps?language={language}' - response = requests.get(url, timeout=(3, 10)) - if response.status_code != 200: - raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}') - - result = response.json() - - if "categories" in result: - result["categories"] = sorted(result["categories"]) - - return result - - @classmethod - def _fetch_recommended_apps_from_builtin(cls, language: str) -> dict: - """ - Fetch recommended apps from builtin. 
- :param language: language - :return: - """ - builtin_data = cls._get_builtin_data() - return builtin_data.get('recommended_apps', {}).get(language) - @classmethod def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: """ @@ -136,120 +32,6 @@ def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': - try: - result = cls._fetch_recommended_app_detail_from_dify_official(app_id) - except Exception as e: - logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.') - result = cls._fetch_recommended_app_detail_from_builtin(app_id) - elif mode == 'db': - result = cls._fetch_recommended_app_detail_from_db(app_id) - elif mode == 'builtin': - result = cls._fetch_recommended_app_detail_from_builtin(app_id) - else: - raise ValueError(f'invalid fetch recommended app detail mode: {mode}') - + retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)() + result = retrieval_instance.get_recommend_app_detail(app_id) return result - - @classmethod - def _fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]: - """ - Fetch recommended app detail from dify official. - :param app_id: App ID - :return: - """ - domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps/{app_id}' - response = requests.get(url, timeout=(3, 10)) - if response.status_code != 200: - return None - - return response.json() - - @classmethod - def _fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]: - """ - Fetch recommended app detail from db. - :param app_id: App ID - :return: - """ - # is in public recommended list - recommended_app = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.app_id == app_id - ).first() - - if not recommended_app: - return None - - # get app detail - app_model = db.session.query(App).filter(App.id == app_id).first() - if not app_model or not app_model.is_public: - return None - - return { - 'id': app_model.id, - 'name': app_model.name, - 'icon': app_model.icon, - 'icon_background': app_model.icon_background, - 'mode': app_model.mode, - 'export_data': AppDslService.export_dsl(app_model=app_model) - } - - @classmethod - def _fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]: - """ - Fetch recommended app detail from builtin. - :param app_id: App ID - :return: - """ - builtin_data = cls._get_builtin_data() - return builtin_data.get('app_details', {}).get(app_id) - - @classmethod - def _get_builtin_data(cls) -> dict: - """ - Get builtin data. 
- :return: - """ - if cls.builtin_data: - return cls.builtin_data - - root_path = current_app.root_path - with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f: - json_data = f.read() - data = json.loads(json_data) - cls.builtin_data = data - - return cls.builtin_data - - @classmethod - def fetch_all_recommended_apps_and_export_datas(cls): - """ - Fetch all recommended apps and export datas - :return: - """ - templates = { - "recommended_apps": {}, - "app_details": {} - } - for language in languages: - try: - result = cls._fetch_recommended_apps_from_dify_official(language) - except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.') - continue - - templates['recommended_apps'][language] = result - - for recommended_app in result.get('recommended_apps'): - app_id = recommended_app.get('app_id') - - # get app detail - app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id) - if not app_detail: - continue - - templates['app_details'][app_id] = app_detail - - return templates diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index f1113c1505e197..9fe3cecce7546d 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -10,46 +10,48 @@ class SavedMessageService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int) -> InfiniteScrollPagination: - saved_messages = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).order_by(SavedMessage.created_at.desc()).all() + def pagination_by_last_id( + cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int + ) -> InfiniteScrollPagination: + saved_messages = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .order_by(SavedMessage.created_at.desc()) + .all() + ) message_ids = [sm.message_id for sm in saved_messages] return MessageService.pagination_by_last_id( - app_model=app_model, - user=user, - last_id=last_id, - limit=limit, - include_ids=message_ids + app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids ) @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if saved_message: return - message = MessageService.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id) saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role='account' if 
isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(saved_message) @@ -57,12 +59,16 @@ def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_i @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if not saved_message: return diff --git a/api/services/tag_service.py b/api/services/tag_service.py index d6eba38fbdef3e..a374bdcf002bef 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,4 +1,5 @@ import uuid +from typing import Optional from flask_login import current_user from sqlalchemy import func @@ -11,39 +12,33 @@ class TagService: @staticmethod - def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list: - query = db.session.query( - Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count') - ).outerjoin( - TagBinding, Tag.id == TagBinding.tag_id - ).filter( - Tag.type == tag_type, - Tag.tenant_id == current_tenant_id + def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list: + query = ( + db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) + .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) + .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%'))) - query = query.group_by( - Tag.id - ) + query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) + query = query.group_by(Tag.id) results = query.order_by(Tag.created_at.desc()).all() return results @staticmethod def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: - tags = db.session.query(Tag).filter( - Tag.id.in_(tag_ids), - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .all() + ) if not tags: return [] tag_ids = [tag.id for tag in tags] - tag_bindings = db.session.query( - TagBinding.target_id - ).filter( - TagBinding.tag_id.in_(tag_ids), - TagBinding.tenant_id == current_tenant_id - ).all() + tag_bindings = ( + db.session.query(TagBinding.target_id) + .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .all() + ) if not tag_bindings: return [] results = [tag_binding.target_id for tag_binding in tag_bindings] @@ -51,27 +46,28 @@ def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: li @staticmethod def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == target_id, - TagBinding.tenant_id == current_tenant_id, - Tag.tenant_id == current_tenant_id, - 
Tag.type == tag_type - ).all() - - return tags if tags else [] + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == target_id, + TagBinding.tenant_id == current_tenant_id, + Tag.tenant_id == current_tenant_id, + Tag.type == tag_type, + ) + .all() + ) + return tags or [] @staticmethod def save_tags(args: dict) -> Tag: tag = Tag( id=str(uuid.uuid4()), - name=args['name'], - type=args['type'], + name=args["name"], + type=args["type"], created_by=current_user.id, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) db.session.add(tag) db.session.commit() @@ -82,7 +78,7 @@ def update_tags(args: dict, tag_id: str) -> Tag: tag = db.session.query(Tag).filter(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") - tag.name = args['name'] + tag.name = args["name"] db.session.commit() return tag @@ -107,20 +103,21 @@ def delete_tag(tag_id: str): @staticmethod def save_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # save tag binding - for tag_id in args['tag_ids']: - tag_binding = db.session.query(TagBinding).filter( - TagBinding.tag_id == tag_id, - TagBinding.target_id == args['target_id'] - ).first() + for tag_id in args["tag_ids"]: + tag_binding = ( + db.session.query(TagBinding) + .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .first() + ) if tag_binding: continue new_tag_binding = TagBinding( tag_id=tag_id, - target_id=args['target_id'], + target_id=args["target_id"], tenant_id=current_user.current_tenant_id, - created_by=current_user.id + created_by=current_user.id, ) db.session.add(new_tag_binding) db.session.commit() @@ -128,34 +125,34 @@ def save_tag_binding(args): @staticmethod def delete_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # delete tag binding - tag_bindings = db.session.query(TagBinding).filter( - TagBinding.target_id == args['target_id'], - TagBinding.tag_id == (args['tag_id']) - ).first() + tag_bindings = ( + db.session.query(TagBinding) + .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) + .first() + ) if tag_bindings: db.session.delete(tag_bindings) db.session.commit() - - @staticmethod def check_target_exists(type: str, target_id: str): - if type == 'knowledge': - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == current_user.current_tenant_id, - Dataset.id == target_id - ).first() + if type == "knowledge": + dataset = ( + db.session.query(Dataset) + .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) + .first() + ) if not dataset: raise NotFound("Dataset not found") - elif type == 'app': - app = db.session.query(App).filter( - App.tenant_id == current_user.current_tenant_id, - App.id == target_id - ).first() + elif type == "app": + app = ( + db.session.query(App) + .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) + .first() + ) if not app: raise NotFound("App not found") else: raise NotFound("Invalid binding type") - diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ecc065d521c703..4a938918550ab8 100644 --- a/api/services/tools/api_tools_manage_service.py +++ 
b/api/services/tools/api_tools_manage_service.py @@ -1,5 +1,6 @@ import json import logging +from typing import Optional from httpx import get @@ -29,111 +30,109 @@ class ApiToolManageService: @staticmethod def parser_api_schema(schema: str) -> list[ApiToolBundle]: """ - parse api schema to tool bundle + parse api schema to tool bundle """ try: warnings = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') - + raise ValueError(f"invalid schema: {str(e)}") + credentials_schema = [ ToolProviderCredentials( - name='auth_type', + name="auth_type", type=ToolProviderCredentials.CredentialsType.SELECT, required=True, - default='none', + default="none", options=[ - ToolCredentialsOption(value='none', label=I18nObject( - en_US='None', - zh_Hans='无' - )), - ToolCredentialsOption(value='api_key', label=I18nObject( - en_US='Api Key', - zh_Hans='Api Key' - )), + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), ], - placeholder=I18nObject( - en_US='Select auth type', - zh_Hans='选择认证方式' - ) + placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), ), ToolProviderCredentials( - name='api_key_header', + name="api_key_header", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key header', - zh_Hans='输入 api key header,如:X-API-KEY' - ), - default='api_key', - help=I18nObject( - en_US='HTTP header name for api key', - zh_Hans='HTTP 头部字段名,用于传递 api key' - ) + placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), + default="api_key", + help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), ), ToolProviderCredentials( - name='api_key_value', + name="api_key_value", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key', - zh_Hans='输入 api key' - ), - default='' + placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), + default="", ), ] - return jsonable_encoder({ - 'schema_type': schema_type, - 'parameters_schema': tool_bundles, - 'credentials_schema': credentials_schema, - 'warning': warnings - }) + return jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: + def convert_schema_to_tool_bundles( + schema: str, extra_info: Optional[dict] = None + ) -> tuple[list[ApiToolBundle], str]: """ - convert schema to tool bundles + convert schema to tool bundles - :return: the list of tool bundles, description + :return: the list of tool bundles, description """ try: tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) return tool_bundles except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def create_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, 
labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], + ): """ - create api tool provider + create api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema_type}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is not None: - raise ValueError(f'provider {provider_name} already exists') + raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + if len(tool_bundles) > 100: - raise ValueError('the number of apis should be less than 100') + raise ValueError("the number of apis should be less than 100") # create db provider db_provider = ApiToolProvider( @@ -142,19 +141,19 @@ def create_api_tool_provider( name=provider_name, icon=json.dumps(icon), schema=schema, - description=extra_info.get('description', ''), + description=extra_info.get("description", ""), schema_type_str=schema_type, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str={}, privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer + custom_disclaimer=custom_disclaimer, ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -172,101 +171,114 @@ def create_api_tool_provider( # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider_remote_schema( - user_id: str, tenant_id: str, url: str - ): + def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): """ - get api tool provider remote schema + get api tool provider remote schema """ headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)" + " Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", "Accept": "*/*", } try: response = get(url, headers=headers, timeout=10) if response.status_code != 200: - raise ValueError(f'Got status code {response.status_code}') + raise ValueError(f"Got status code {response.status_code}") schema = response.text # try to parse schema, avoid SSRF attack ApiToolManageService.parser_api_schema(schema) except Exception as e: - logger.error(f"parse api schema error: {str(e)}") - raise ValueError('invalid schema, please check the url you provided') - - return { - 
'schema': schema - } + logger.exception(f"parse api schema error: {str(e)}") + raise ValueError("invalid schema, please check the url you provided") + + return {"schema": schema} @staticmethod - def list_api_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list api tool provider tools + list api tool provider tools """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') - + raise ValueError(f"you have not added provider {provider}") + controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(controller) - + return [ ToolTransformService.tool_to_user_tool( tool_bundle, labels=labels, - ) for tool_bundle in provider.tools + ) + for tool_bundle in provider.tools ] @staticmethod def update_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + original_provider: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - update api tool provider + update api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema_type}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .first() + ) if provider is None: - raise ValueError(f'api provider {provider_name} does not exists') + raise ValueError(f"api provider {provider_name} does not exist") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + # update db provider provider.name = provider_name provider.icon = json.dumps(icon) provider.schema = schema - provider.description = extra_info.get('description', '') + provider.description = extra_info.get("description", "") provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = 
ApiToolProviderController.from_db(provider, auth_type) @@ -295,84 +307,91 @@ def update_api_tool_provider( # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def delete_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - + raise ValueError(f"you have not added provider {provider_name}") + db.session.delete(provider) db.session.commit() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider( - user_id: str, tenant_id: str, provider: str - ): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): """ - get api tool provider + get api tool provider """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) - + @staticmethod def test_api_tool_preview( - tenant_id: str, + tenant_id: str, provider_name: str, - tool_name: str, - credentials: dict, - parameters: dict, - schema_type: str, - schema: str + tool_name: str, + credentials: dict, + parameters: dict, + schema_type: str, + schema: str, ): """ - test api tool before adding api tool provider + test api tool before adding api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema_type}') - + raise ValueError(f"invalid schema type {schema_type}") + try: tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) except Exception as e: - raise ValueError('invalid schema') - + raise ValueError("invalid schema") + # get tool bundle tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) if tool_bundle is None: - raise ValueError(f'invalid tool name {tool_name}') - - db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + raise ValueError(f"invalid tool name {tool_name}") + + db_provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if not db_provider: # create a fake db provider db_provider = ApiToolProvider( - tenant_id='', user_id='', name='', icon='', + tenant_id="", + user_id="", + name="", + icon="", schema=schema, - description='', + description="", schema_type_str=ApiProviderSchemaType.OPENAPI.value, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = 
ApiToolProviderController.from_db(db_provider, auth_type) @@ -381,10 +400,7 @@ def test_api_tool_preview( # decrypt credentials if db_provider.id: - tool_configuration = ToolConfigurationManager( - tenant_id=tenant_id, - provider_controller=provider_controller - ) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) # check if the credential has changed, save the original credential masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) @@ -396,27 +412,27 @@ def test_api_tool_preview( provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) result = tool.validate_credentials(credentials, parameters) except Exception as e: - return { 'error': str(e) } - - return { 'result': result or 'empty response' } - + return {"error": str(e)} + + return {"result": result or "empty response"} + @staticmethod - def list_api_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list api tools + list api tools """ # get all api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] + ) result: list[UserToolProvider] = [] @@ -425,26 +441,21 @@ def list_api_tools( provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller, - db_provider=provider, - decrypt_credentials=True + provider_controller, db_provider=provider, decrypt_credentials=True ) user_provider.labels = labels # add icon ToolTransformService.repack_provider(user_provider) - tools = provider_controller.get_tools( - user_id=user_id, tenant_id=tenant_id - ) + tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) for tool in tools: - user_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_provider.original_credentials, - labels=labels - )) + user_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + ) + ) result.append(user_provider) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index ea6ecf0c6985f5..e2e49d017ef167 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,6 +1,9 @@ import json import logging +from pathlib import Path +from configs import dify_config +from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.api_entities import UserTool, UserToolProvider from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError @@ -18,21 +21,25 
@@ class BuiltinToolManageService: @staticmethod - def list_builtin_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list builtin tool provider tools + list builtin tool provider tools """ provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) tools = provider_controller.get_tools() - tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_provider_configurations = ToolConfigurationManager( + tenant_id=tenant_id, provider_controller=provider_controller + ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) credentials = {} if builtin_provider is not None: @@ -42,47 +49,47 @@ def list_builtin_tool_provider_tools( result = [] for tool in tools: - result.append(ToolTransformService.tool_to_user_tool( - tool=tool, - credentials=credentials, - tenant_id=tenant_id, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + result.append( + ToolTransformService.tool_to_user_tool( + tool=tool, + credentials=credentials, + tenant_id=tenant_id, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) return result - + @staticmethod - def list_builtin_provider_credentials_schema( - provider_name - ): + def list_builtin_provider_credentials_schema(provider_name): """ - list builtin provider credentials schema + list builtin provider credentials schema - :return: the list of tool providers + :return: the credentials schema of the tool provider """ provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([ - v for _, v in (provider.credentials_schema or {}).items() - ]) + return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict - ): + def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): """ - update builtin tool provider + update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) - try: + try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: - raise ValueError(f'provider {provider_name} does not need credentials') + raise ValueError(f"provider {provider_name} does not need credentials") tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) # get original credentials if exists if provider is not None: @@ -119,23 +126,25 @@ def update_builtin_tool_provider( # delete cache tool_configuration.delete_tool_credentials_cache() - return {
'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_builtin_tool_provider_credentials( - user_id: str, tenant_id: str, provider: str - ): + def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): """ - get builtin tool provider credentials + get builtin tool provider credentials """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) if provider is None: return {} - + provider_controller = ToolManager.get_builtin_provider(provider.provider) tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) @@ -143,20 +152,22 @@ def get_builtin_tool_provider_credentials( return credentials @staticmethod - def delete_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - + raise ValueError(f"you have not added provider {provider_name}") + db.session.delete(provider) db.session.commit() @@ -165,48 +176,54 @@ def delete_builtin_tool_provider( tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_builtin_tool_provider_icon( - provider: str - ): + def get_builtin_tool_provider_icon(provider: str): """ - get tool provider icon and it's mimetype + get tool provider icon and its mimetype """ icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) - with open(icon_path, 'rb') as f: - icon_bytes = f.read() + icon_bytes = Path(icon_path).read_bytes() return icon_bytes, mime_type - + @staticmethod - def list_builtin_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list builtin tools + list builtin tools """ # get all builtin providers provider_controllers = ToolManager.list_builtin_providers() # get all user added providers - db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] + ) # find provider - find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + find_provider = lambda provider: next( + filter(lambda db_provider: db_provider.provider == provider,
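The find_provider lambda reformatted above is a plain first-match lookup: next with a default returns None when nothing matches. A self-contained equivalent, with a hypothetical DBProvider tuple standing in for the BuiltinToolProvider model:

from typing import NamedTuple, Optional

class DBProvider(NamedTuple):  # hypothetical stand-in for BuiltinToolProvider
    provider: str

db_providers = [DBProvider("google"), DBProvider("bing")]

def find_provider(name: str) -> Optional[DBProvider]:
    # next(<generator>, None) yields the first match or the default,
    # exactly like the next(filter(...), None) form in the patch
    return next((p for p in db_providers if p.provider == name), None)

assert find_provider("bing") == DBProvider("bing")
assert find_provider("duckduckgo") is None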
db_providers), None + ) result: list[UserToolProvider] = [] for provider_controller in provider_controllers: try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, + data=provider_controller, + name_func=lambda x: x.identity.name, + ): + continue + # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=find_provider(provider_controller.identity.name), - decrypt_credentials=True + decrypt_credentials=True, ) # add icon @@ -214,16 +231,17 @@ def list_builtin_tools( tools = provider_controller.get_tools() for tool in tools: - user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_builtin_provider.original_credentials, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + user_builtin_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) result.append(user_builtin_provider) except Exception as e: raise e return BuiltinToolProviderSort.sort(result) - \ No newline at end of file diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py index 8a6aa025f27c6b..35e58b5adec58f 100644 --- a/api/services/tools/tool_labels_service.py +++ b/api/services/tools/tool_labels_service.py @@ -5,4 +5,4 @@ class ToolLabelsService: @classmethod def list_tool_labels(cls) -> list[ToolLabel]: - return default_tool_labels \ No newline at end of file + return default_tool_labels diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 76d2f53ae86535..1c67f7648ca99f 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -11,13 +11,11 @@ class ToolCommonService: @staticmethod def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): """ - list tool providers + list tool providers - :return: the list of tool providers + :return: the list of tool providers """ - providers = ToolManager.user_list_providers( - user_id, tenant_id, typ - ) + providers = ToolManager.user_list_providers(user_id, tenant_id, typ) # add icon for provider in providers: @@ -26,4 +24,3 @@ def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeL result = [provider.to_dict() for provider in providers] return result - \ No newline at end of file diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index cfce3fbd019a2c..e535ddb575814c 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -22,46 +22,39 @@ logger = logging.getLogger(__name__) + class ToolTransformService: @staticmethod def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: """ - get tool provider icon url + get tool provider icon url """ - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/") - + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/" + if provider_type == ToolProviderType.BUILT_IN.value: - return url_prefix + 'builtin/' + provider_name + '/icon' - 
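The new is_filtered gate skips providers according to the POSITION_TOOL_INCLUDES_SET / POSITION_TOOL_EXCLUDES_SET settings. The real helper lives in core.helper.position_helper and also takes a name_func; the sketch below is only a plausible reading of the set semantics being assumed:

def is_filtered(include_set: set[str], exclude_set: set[str], name: str) -> bool:
    if include_set and name not in include_set:
        return True   # an allow-list is configured and the name is absent
    if exclude_set and name in exclude_set:
        return True   # the name is explicitly denied
    return False

assert is_filtered({"google"}, set(), "bing") is True
assert is_filtered(set(), {"bing"}, "bing") is True
assert is_filtered(set(), set(), "bing") is False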
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: + return url_prefix + "builtin/" + provider_name + "/icon" + elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: return json.loads(icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } - - return '' - + return {"background": "#252525", "content": "\ud83d\ude01"} + + return "" + @staticmethod def repack_provider(provider: Union[dict, UserToolProvider]): """ - repack provider + repack provider - :param provider: the provider dict + :param provider: the provider dict """ - if isinstance(provider, dict) and 'icon' in provider: - provider['icon'] = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider['type'], - provider_name=provider['name'], - icon=provider['icon'] + if isinstance(provider, dict) and "icon" in provider: + provider["icon"] = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, - provider_name=provider.name, - icon=provider.icon + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) @staticmethod @@ -81,25 +74,26 @@ def builtin_provider_to_user_provider( en_US=provider_controller.identity.description.en_US, zh_Hans=provider_controller.identity.description.zh_Hans, pt_BR=provider_controller.identity.description.pt_BR, + ja_JP=provider_controller.identity.description.ja_JP, ), icon=provider_controller.identity.icon, label=I18nObject( en_US=provider_controller.identity.label.en_US, zh_Hans=provider_controller.identity.label.zh_Hans, pt_BR=provider_controller.identity.label.pt_BR, + ja_JP=provider_controller.identity.label.ja_JP, ), type=ToolProviderType.BUILT_IN, masked_credentials={}, is_team_authorization=False, tools=[], - labels=provider_controller.tool_labels + labels=provider_controller.tool_labels, ) # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = \ - ToolProviderCredentials.CredentialsType.default(value.type) + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) # check if the provider need credentials if not provider_controller.need_credentials: @@ -113,8 +107,7 @@ def builtin_provider_to_user_provider( # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) @@ -124,7 +117,7 @@ def builtin_provider_to_user_provider( result.original_credentials = decrypted_credentials return result - + @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, @@ -135,25 +128,23 @@ def api_provider_to_controller( # package tool provider controller controller = ApiToolProviderController.from_db( db_provider=db_provider, - auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + auth_type=ApiProviderAuthType.API_KEY + if db_provider.credentials["auth_type"] == "api_key" + else ApiProviderAuthType.NONE, ) return controller - + 
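masked_credentials above is pre-seeded with a type-appropriate default for every field in the schema via ToolProviderCredentials.CredentialsType.default(value.type). The defaults table below is hypothetical (the real mapping lives on that enum); it only illustrates the shape of the operation:

TYPE_DEFAULTS = {"secret-input": "", "text-input": "", "select": None, "boolean": False}  # hypothetical values

schema = {"api_key": "secret-input", "endpoint": "text-input"}
masked_credentials = {name: TYPE_DEFAULTS[field_type] for name, field_type in schema.items()}
assert masked_credentials == {"api_key": "", "endpoint": ""}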
@staticmethod - def workflow_provider_to_controller( - db_provider: WorkflowToolProvider - ) -> WorkflowToolProviderController: + def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: """ convert provider controller to provider """ return WorkflowToolProviderController.from_db(db_provider) - + @staticmethod def workflow_provider_to_user_provider( - provider_controller: WorkflowToolProviderController, - labels: list[str] = None + provider_controller: WorkflowToolProviderController, labels: Optional[list[str]] = None ): """ convert provider controller to user provider @@ -175,7 +166,7 @@ def workflow_provider_to_user_provider( masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) @staticmethod @@ -183,16 +174,16 @@ def api_provider_to_user_provider( provider_controller: ApiToolProviderController, db_provider: ApiToolProvider, decrypt_credentials: bool = True, - labels: list[str] = None + labels: Optional[list[str]] = None, ) -> UserToolProvider: """ convert provider controller to user provider """ - username = 'Anonymous' + username = "Anonymous" try: username = db_provider.user.name except Exception as e: - logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}') + logger.exception(f"failed to get user name for api provider {db_provider.id}: {str(e)}") # add provider into providers credentials = db_provider.credentials result = UserToolProvider( @@ -212,14 +203,13 @@ def api_provider_to_user_provider( masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) if decrypt_credentials: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials @@ -229,23 +219,25 @@ def api_provider_to_user_provider( result.masked_credentials = masked_credentials return result - + @staticmethod def tool_to_user_tool( - tool: Union[ApiToolBundle, WorkflowTool, Tool], - credentials: dict = None, - tenant_id: str = None, - labels: list[str] = None + tool: Union[ApiToolBundle, WorkflowTool, Tool], + credentials: Optional[dict] = None, + tenant_id: Optional[str] = None, + labels: Optional[list[str]] = None, ) -> UserTool: """ convert tool to user tool """ if isinstance(tool, Tool): # fork tool runtime - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) # get tool parameters parameters = tool.parameters or [] @@ -270,20 +262,14 @@ def tool_to_user_tool( label=tool.identity.label, description=tool.description.human, parameters=current_parameters, - labels=labels + labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, name=tool.operation_id, - label=I18nObject( - en_US=tool.operation_id, - zh_Hans=tool.operation_id - ), - description=I18nObject( - en_US=tool.summary or '', - zh_Hans=tool.summary or '' - ), + label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, - labels=labels - ) \ No newline at end of file + labels=labels, + ) diff --git a/api/services/tools/workflow_tools_manage_service.py 
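Several signatures above move from labels: list[str] = None to Optional[list[str]] = None, which matches the default that was always being passed, and the related labels or [] idiom keeps the classic shared-mutable-default pitfall at bay. A short self-contained illustration of that pitfall and the guard:

from typing import Optional

def buggy(labels: list[str] = []) -> list[str]:
    labels.append("x")           # mutates the one shared default object
    return labels

def safe(labels: Optional[list[str]] = None) -> list[str]:
    labels = list(labels or [])  # fresh list on every call
    labels.append("x")
    return labels

assert buggy() == ["x"]
assert buggy() == ["x", "x"]     # state leaked across calls
assert safe() == ["x"]
assert safe() == ["x"]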
b/api/services/tools/workflow_tools_manage_service.py index 185483a71cf82e..833881b668b383 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -1,5 +1,7 @@ import json +from collections.abc import Mapping from datetime import datetime +from typing import Any, Optional from sqlalchemy import or_ @@ -19,46 +21,45 @@ class WorkflowToolManageService: """ Service class for managing workflow tools. """ - @classmethod - def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, - label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: - """ - Create a workflow tool. - :param user_id: the user id - :param tenant_id: the tenant id - :param name: the name - :param icon: the icon - :param description: the description - :param parameters: the parameters - :param privacy_policy: the privacy policy - :param labels: labels - :return: the created tool - """ + + @staticmethod + def create_workflow_tool( + *, + user_id: str, + tenant_id: str, + workflow_app_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: Mapping[str, Any], + privacy_policy: str = "", + labels: Optional[list[str]] = None, + ) -> dict: WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id) - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists') - - app: App = db.session.query(App).filter( - App.id == workflow_app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") + app = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() if app is None: - raise ValueError(f'App {workflow_app_id} not found') - - workflow: Workflow = app.workflow + raise ValueError(f"App {workflow_app_id} not found") + + workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_app_id}') - + raise ValueError(f"Workflow not found for app {workflow_app_id}") + workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -76,19 +77,26 @@ def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - + db.session.add(workflow_tool_provider) db.session.commit() - return { - 'result': 'success' - } - + return {"result": "success"} @classmethod - def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str, - name: str, label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def update_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_tool_id: str, + name: str, + label: str, + icon: dict, + description: str, + 
parameters: list[dict], + privacy_policy: str = "", + labels: Optional[list[str]] = None, + ) -> dict: """ Update a workflow tool. :param user_id: the user id @@ -106,35 +114,39 @@ def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: st WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} already exists') - - workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + raise ValueError(f"Tool with name {name} already exists") + + workflow_tool_provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if workflow_tool_provider is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - app: App = db.session.query(App).filter( - App.id == workflow_tool_provider.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + app: App = ( + db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() + ) if app is None: - raise ValueError(f'App {workflow_tool_provider.app_id} not found') - + raise ValueError(f"App {workflow_tool_provider.app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}') - + raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) @@ -154,13 +166,10 @@ def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: st if labels is not None: ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), - labels + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: @@ -170,9 +179,7 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id - ).all() + db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() tools = [] for provider in db_tools: @@ -188,14 +195,12 @@ def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserTo for tool in tools: user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=tool, - labels=labels.get(tool.provider_id, []) + 
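create_workflow_tool now takes keyword-only arguments (the bare * at the top of its parameter list), so call sites must name every field. A minimal demonstration with a hypothetical create_tool:

def create_tool(*, user_id: str, tenant_id: str, name: str) -> dict:
    return {"user_id": user_id, "tenant_id": tenant_id, "name": name}

create_tool(user_id="u1", tenant_id="t1", name="demo")      # accepted
try:
    create_tool("u1", "t1", "demo")                         # positional call is rejected
except TypeError as exc:
    print(exc)   # create_tool() takes 0 positional arguments but 3 were given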
provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) user_tool_provider.tools = [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=labels.get(tool.provider_id, []) + tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) ) ] result.append(user_tool_provider) @@ -211,15 +216,12 @@ def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: st :param workflow_app_id: the workflow app id """ db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id ).delete() db.session.commit() - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: @@ -230,40 +232,37 @@ def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_too :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy, + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: """ @@ -273,40 +272,37 @@ def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_ :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, 
- WorkflowToolProvider.app_id == workflow_app_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_app_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_app_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: """ @@ -316,19 +312,19 @@ def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_ :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') + raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ) - ] \ No newline at end of file + ] diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 232d2943256cad..3c67351335359d 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -7,10 +7,10 @@ class VectorService: - @classmethod - def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], - segments: list[DocumentSegment], dataset: Dataset): + def create_segments_vector( + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset + ): documents = [] for segment in segments: 
document = Document( @@ -20,14 +20,12 @@ def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # save vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.add_texts(documents, duplicate_check=True) # save keyword index @@ -50,13 +48,11 @@ def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentS "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # update vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) vector.add_texts([document], duplicate_check=True) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index cba048ccdbbbf1..d7ccc964cb70f8 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -11,17 +11,29 @@ class WebConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - pinned: Optional[bool] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + pinned: Optional[bool] = None, + sort_by="-updated_at", + ) -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: - pinned_conversations = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).order_by(PinnedConversation.created_at.desc()).all() + pinned_conversations = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .order_by(PinnedConversation.created_at.desc()) + .all() + ) pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] if pinned: include_ids = pinned_conversation_ids @@ -36,31 +48,34 @@ def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, End invoke_from=invoke_from, include_ids=include_ids, exclude_ids=exclude_ids, + sort_by=sort_by, ) @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + 
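The pinned filter in pagination_by_last_id above resolves to either an include list (show only pinned conversations) or an exclude list (hide them) before pagination runs. A sketch of that branching, with plain dicts standing in for conversation rows:

from typing import Optional

def filter_by_pinned(conversations, pinned_ids, pinned: Optional[bool]):
    if pinned is None:
        return conversations                                              # no pin filter requested
    if pinned:
        return [c for c in conversations if c["id"] in pinned_ids]        # include_ids
    return [c for c in conversations if c["id"] not in pinned_ids]        # exclude_ids

convs = [{"id": "a"}, {"id": "b"}, {"id": "c"}]
assert filter_by_pinned(convs, {"b"}, True) == [{"id": "b"}]
assert filter_by_pinned(convs, {"b"}, False) == [{"id": "a"}, {"id": "c"}]
assert filter_by_pinned(convs, {"b"}, None) == convs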
PinnedConversation.created_by == user.id, + ) + .first() + ) if pinned_conversation: return conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=conversation_id, - user=user + app_model=app_model, conversation_id=conversation_id, user=user ) pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(pinned_conversation) @@ -68,12 +83,16 @@ def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if not pinned_conversation: return diff --git a/api/services/website_service.py b/api/services/website_service.py index c166b01237b6c4..13cc9c679adb90 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,6 +1,7 @@ import datetime import json +import requests from flask_login import current_user from core.helper import encrypter @@ -11,161 +12,225 @@ class WebsiteService: - @classmethod def document_create_args_validate(cls, args: dict): - if 'url' not in args or not args['url']: - raise ValueError('url is required') - if 'options' not in args or not args['options']: - raise ValueError('options is required') - if 'limit' not in args['options'] or not args['options']['limit']: - raise ValueError('limit is required') + if "url" not in args or not args["url"]: + raise ValueError("url is required") + if "options" not in args or not args["options"]: + raise ValueError("options is required") + if "limit" not in args["options"] or not args["options"]["limit"]: + raise ValueError("limit is required") @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get('provider') - url = args.get('url') - options = args.get('options') - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + provider = args.get("provider") + url = args.get("url") + options = args.get("options") + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - crawl_sub_pages = options.get('crawl_sub_pages', False) - only_main_content = options.get('only_main_content', False) + firecrawl_app = FirecrawlApp(api_key=api_key, 
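document_create_args_validate above insists on url, options, and options.limit before any crawl starts. The same checks as a tiny standalone validator (the function name here is hypothetical):

def validate_crawl_args(args: dict) -> None:
    if not args.get("url"):
        raise ValueError("url is required")
    options = args.get("options")
    if not options:
        raise ValueError("options is required")
    if not options.get("limit"):
        raise ValueError("limit is required")

validate_crawl_args({"url": "https://example.com", "options": {"limit": 1}})  # passes
try:
    validate_crawl_args({"url": "https://example.com", "options": {}})
except ValueError as exc:
    print(exc)   # options is required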
base_url=credentials.get("config").get("base_url", None)) + crawl_sub_pages = options.get("crawl_sub_pages", False) + only_main_content = options.get("only_main_content", False) if not crawl_sub_pages: params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "limit": 1, - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } else: - includes = options.get('includes').split(',') if options.get('includes') else [] - excludes = options.get('excludes').split(',') if options.get('excludes') else [] + includes = options.get("includes").split(",") if options.get("includes") else [] + excludes = options.get("excludes").split(",") if options.get("excludes") else [] params = { - 'crawlerOptions': { - "includes": includes if includes else [], - "excludes": excludes if excludes else [], + "crawlerOptions": { + "includes": includes or [], + "excludes": excludes or [], "generateImgAltText": True, - "limit": options.get('limit', 1), - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "limit": options.get("limit", 1), + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } - if options.get('max_depth'): - params['crawlerOptions']['maxDepth'] = options.get('max_depth') + if options.get("max_depth"): + params["crawlerOptions"]["maxDepth"] = options.get("max_depth") job_id = firecrawl_app.crawl_url(url, params) - website_crawl_time_cache_key = f'website_crawl_{job_id}' + website_crawl_time_cache_key = f"website_crawl_{job_id}" time = str(datetime.datetime.now().timestamp()) redis_client.setex(website_crawl_time_cache_key, 3600, time) - return { - 'status': 'active', - 'job_id': job_id - } + return {"status": "active", "job_id": job_id} + elif provider == "jinareader": + api_key = encrypter.decrypt_token( + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + ) + crawl_sub_pages = options.get("crawl_sub_pages", False) + if not crawl_sub_pages: + response = requests.get( + f"https://r.jina.ai/{url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "data": response.json().get("data")} + else: + response = requests.post( + "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", + json={ + "url": url, + "maxPages": options.get("limit", 1), + "useSitemap": options.get("use_sitemap", True), + }, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_crawl_status(cls, job_id: str, provider: str) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - 
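The firecrawl branch above times a crawl by caching a start timestamp under website_crawl_<job_id> with a one-hour TTL and reading it back when the job completes. The same bookkeeping against a dict standing in for redis_client (the TTL is not modeled here):

import datetime

fake_redis: dict[str, str] = {}   # stand-in for redis_client

def start_crawl_timer(job_id: str) -> None:
    fake_redis[f"website_crawl_{job_id}"] = str(datetime.datetime.now().timestamp())

def stop_crawl_timer(job_id: str) -> str:
    started = fake_redis.pop(f"website_crawl_{job_id}", None)
    if started is None:
        return "unknown"
    elapsed = abs(datetime.datetime.now().timestamp() - float(started))
    return f"{elapsed:.2f}"       # same formatting as time_consuming above

start_crawl_timer("job-1")
print(stop_crawl_timer("job-1"))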
token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) crawl_status_data = { - 'status': result.get('status', 'active'), - 'job_id': job_id, - 'total': result.get('total', 0), - 'current': result.get('current', 0), - 'data': result.get('data', []) + "status": result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), } - if crawl_status_data['status'] == 'completed': - website_crawl_time_cache_key = f'website_crawl_{job_id}' + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" start_time = redis_client.get(website_crawl_time_cache_key) if start_time: end_time = datetime.datetime.now().timestamp() time_consuming = abs(end_time - float(start_time)) - crawl_status_data['time_consuming'] = f"{time_consuming:.2f}" + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" redis_client.delete(website_crawl_time_cache_key) + elif provider == "jinareader": + api_key = encrypter.decrypt_token( + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + ) + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + data = response.json().get("data", {}) + crawl_status_data = { + "status": data.get("status", "active"), + "job_id": job_id, + "total": len(data.get("urls", [])), + "current": len(data.get("processed", [])) + len(data.get("failed", [])), + "data": [], + "time_consuming": data.get("duration", 0) / 1000, + } + + if crawl_status_data["status"] == "completed": + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, + ) + data = response.json().get("data", {}) + formatted_data = [ + { + "title": item.get("data", {}).get("title"), + "source_url": item.get("data", {}).get("url"), + "description": item.get("data", {}).get("description"), + "markdown": item.get("data", {}).get("content"), + } + for item in data.get("processed", {}).values() + ] + crawl_status_data["data"] = formatted_data else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") return crawl_status_data @classmethod def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': - file_key = 'website_files/' + job_id + '.txt' + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + # decrypt api_key + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + if provider == "firecrawl": + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): data = storage.load_once(file_key) if data: - data = json.loads(data.decode('utf-8')) + data = json.loads(data.decode("utf-8")) else: - # decrypt api_key - 
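The Jina Reader status mapping above derives its progress counters from the task payload: total counts discovered URLs, current counts processed plus failed, and duration arrives in milliseconds. The arithmetic on a sample payload (the payload values are invented for illustration):

data = {
    "status": "processing",
    "urls": ["u1", "u2", "u3", "u4"],
    "processed": {"u1": {}, "u2": {}},
    "failed": ["u3"],
    "duration": 1500,
}
crawl_status = {
    "status": data.get("status", "active"),
    "total": len(data.get("urls", [])),
    "current": len(data.get("processed", [])) + len(data.get("failed", [])),
    "time_consuming": data.get("duration", 0) / 1000,   # ms -> seconds
}
assert crawl_status["total"] == 4 and crawl_status["current"] == 3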
api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) - if result.get('status') != 'completed': - raise ValueError('Crawl job is not completed') - data = result.get('data') + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + data = result.get("data") if data: for item in data: - if item.get('source_url') == url: + if item.get("source_url") == url: return item return None + elif provider == "jinareader": + file_key = "website_files/" + job_id + ".txt" + if storage.exists(file_key): + data = storage.load_once(file_key) + if data: + data = json.loads(data.decode("utf-8")) + elif not job_id: + response = requests.get( + f"https://r.jina.ai/{url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return response.json().get("data") + else: + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + data = response.json().get("data", {}) + if data.get("status") != "completed": + raise ValueError("Crawl job is not completed") + + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, + ) + data = response.json().get("data", {}) + for item in data.get("processed", {}).values(): + if item.get("data", {}).get("url") == url: + return item.get("data", {}) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - params = { - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } - } + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}} result = firecrawl_app.scrape_url(url, params) return result else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 9ca09c7f0d3918..90b5cc48362f3b 100644 --- a/api/services/workflow/workflow_converter.py +++ 
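get_crawl_url_data above ultimately returns the single crawl record whose source_url matches the requested page, or None when the job did not cover it. That lookup in isolation:

from typing import Optional

def pick_record(data: list[dict], url: str) -> Optional[dict]:
    for item in data:
        if item.get("source_url") == url:
            return item
    return None

records = [{"source_url": "https://a.example"}, {"source_url": "https://b.example"}]
assert pick_record(records, "https://b.example") == {"source_url": "https://b.example"}
assert pick_record(records, "https://c.example") is None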
b/api/services/workflow/workflow_converter.py @@ -13,12 +13,12 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from core.file.file_obj import FileExtraConfig +from core.file.models import FileUploadConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform -from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes import NodeType from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account @@ -32,11 +32,9 @@ class WorkflowConverter: App Convert to Workflow Mode """ - def convert_to_workflow(self, app_model: App, - account: Account, - name: str, - icon: str, - icon_background: str) -> App: + def convert_to_workflow( + self, app_model: App, account: Account, name: str, icon_type: str, icon: str, icon_background: str + ): """ Convert app to workflow @@ -50,30 +48,34 @@ def convert_to_workflow(self, app_model: App, :param account: Account :param name: new app name :param icon: new app icon + :param icon_type: new app icon type :param icon_background: new app icon background :return: new App instance """ # convert app model config + if not app_model.app_model_config: + raise ValueError("App model config is required") + workflow = self.convert_app_model_config_to_workflow( - app_model=app_model, - app_model_config=app_model.app_model_config, - account_id=account.id + app_model=app_model, app_model_config=app_model.app_model_config, account_id=account.id ) # create new app new_app = App() new_app.tenant_id = app_model.tenant_id - new_app.name = name if name else app_model.name + '(workflow)' - new_app.mode = AppMode.ADVANCED_CHAT.value \ - if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value - new_app.icon = icon if icon else app_model.icon - new_app.icon_background = icon_background if icon_background else app_model.icon_background + new_app.name = name or app_model.name + "(workflow)" + new_app.mode = AppMode.ADVANCED_CHAT.value if app_model.mode == AppMode.CHAT.value else AppMode.WORKFLOW.value + new_app.icon_type = icon_type or app_model.icon_type + new_app.icon = icon or app_model.icon + new_app.icon_background = icon_background or app_model.icon_background new_app.enable_site = app_model.enable_site new_app.enable_api = app_model.enable_api new_app.api_rpm = app_model.api_rpm new_app.api_rph = app_model.api_rph new_app.is_demo = False new_app.is_public = app_model.is_public + new_app.created_by = account.id + new_app.updated_by = account.id db.session.add(new_app) db.session.flush() db.session.commit() @@ -85,30 +87,21 @@ def convert_to_workflow(self, app_model: App, return new_app - def convert_app_model_config_to_workflow(self, app_model: App, - app_model_config: AppModelConfig, - account_id: str) -> Workflow: + def convert_app_model_config_to_workflow(self, app_model: App, app_model_config: AppModelConfig, account_id: str): """ Convert app model config to workflow mode :param app_model: App instance :param app_model_config: AppModelConfig instance :param account_id: Account ID - :return: """ # get new app mode new_app_mode = self._get_new_app_mode(app_model) # convert app model config - app_config = self._convert_to_app_config( - 
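The mode mapping in convert_to_workflow above is a single branch: chat apps become advanced-chat workflows and every other easy-UI mode becomes a plain workflow. In isolation, assuming the usual AppMode string values:

def new_mode(old_mode: str) -> str:
    # "chat" and "advanced-chat" are assumed AppMode.*.value strings
    return "advanced-chat" if old_mode == "chat" else "workflow"

assert new_mode("chat") == "advanced-chat"
assert new_mode("completion") == "workflow"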
app_model=app_model, - app_model_config=app_model_config - ) + app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph = { - "nodes": [], - "edges": [] - } + graph = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -120,11 +113,9 @@ def convert_app_model_config_to_workflow(self, app_model: App, # - show_retrieve_source -> knowledge-retrieval # convert to start node - start_node = self._convert_to_start_node( - variables=app_config.variables - ) + start_node = self._convert_to_start_node(variables=app_config.variables) - graph['nodes'].append(start_node) + graph["nodes"].append(start_node) # convert to http request node external_data_variable_node_mapping = {} @@ -132,7 +123,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, http_request_nodes, external_data_variable_node_mapping = self._convert_to_http_request_node( app_model=app_model, variables=app_config.variables, - external_data_variables=app_config.external_data_variables + external_data_variables=app_config.external_data_variables, ) for http_request_node in http_request_nodes: @@ -141,9 +132,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, # convert to knowledge retrieval node if app_config.dataset: knowledge_retrieval_node = self._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, - dataset_config=app_config.dataset, - model_config=app_config.model + new_app_mode=new_app_mode, dataset_config=app_config.dataset, model_config=app_config.model ) if knowledge_retrieval_node: @@ -157,7 +146,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, model_config=app_config.model, prompt_template=app_config.prompt_template, file_upload=app_config.additional_features.file_upload, - external_data_variable_node_mapping=external_data_variable_node_mapping + external_data_variable_node_mapping=external_data_variable_node_mapping, ) graph = self._append_node(graph, llm_node) @@ -196,7 +185,7 @@ def convert_app_model_config_to_workflow(self, app_model: App, tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(new_app_mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account_id, @@ -209,24 +198,18 @@ def convert_app_model_config_to_workflow(self, app_model: App, return workflow - def _convert_to_app_config(self, app_model: App, - app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: + def _convert_to_app_config(self, app_model: App, app_model_config: AppModelConfig) -> EasyUIBasedAppConfig: app_mode = AppMode.value_of(app_model.mode) if app_mode == AppMode.AGENT_CHAT or app_model.is_agent: app_model.mode = AppMode.AGENT_CHAT.value app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config + app_model=app_model, app_model_config=app_model_config ) elif app_mode == AppMode.CHAT: - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config - ) + app_config = ChatAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) elif app_mode == AppMode.COMPLETION: app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config + app_model=app_model, app_model_config=app_model_config ) else: raise ValueError("Invalid app mode") @@ -245,14 +228,13 @@ def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict: 
"data": { "title": "START", "type": NodeType.START.value, - "variables": [jsonable_encoder(v) for v in variables] - } + "variables": [jsonable_encoder(v) for v in variables], + }, } - def _convert_to_http_request_node(self, app_model: App, - variables: list[VariableEntity], - external_data_variables: list[ExternalDataVariableEntity]) \ - -> tuple[list[dict], dict[str, str]]: + def _convert_to_http_request_node( + self, app_model: App, variables: list[VariableEntity], external_data_variables: list[ExternalDataVariableEntity] + ) -> tuple[list[dict], dict[str, str]]: """ Convert API Based Extension to HTTP Request Node :param app_model: App instance @@ -274,40 +256,33 @@ def _convert_to_http_request_node(self, app_model: App, # get params from config api_based_extension_id = tool_config.get("api_based_extension_id") + if not api_based_extension_id: + continue # get api_based_extension api_based_extension = self._get_api_based_extension( - tenant_id=tenant_id, - api_based_extension_id=api_based_extension_id + tenant_id=tenant_id, api_based_extension_id=api_based_extension_id ) - if not api_based_extension: - raise ValueError("[External data tool] API query failed, variable: {}, " - "error: api_based_extension_id is invalid" - .format(tool_variable)) - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=api_based_extension.api_key - ) + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=api_based_extension.api_key) inputs = {} for v in variables: - inputs[v.variable] = '{{#start.' + v.variable + '#}}' + inputs[v.variable] = "{{#start." + v.variable + "#}}" request_body = { - 'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, - 'params': { - 'app_id': app_model.id, - 'tool_variable': tool_variable, - 'inputs': inputs, - 'query': '{{#sys.query#}}' if app_model.mode == AppMode.CHAT.value else '' - } + "point": APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, + "params": { + "app_id": app_model.id, + "tool_variable": tool_variable, + "inputs": inputs, + "query": "{{#sys.query#}}" if app_model.mode == AppMode.CHAT.value else "", + }, } request_body_json = json.dumps(request_body) - request_body_json = request_body_json.replace(r'\{\{', '{{').replace(r'\}\}', '}}') + request_body_json = request_body_json.replace(r"\{\{", "{{").replace(r"\}\}", "}}") http_request_node = { "id": f"http_request_{index}", @@ -317,20 +292,11 @@ def _convert_to_http_request_node(self, app_model: App, "type": NodeType.HTTP_REQUEST.value, "method": "post", "url": api_based_extension.api_endpoint, - "authorization": { - "type": "api-key", - "config": { - "type": "bearer", - "api_key": api_key - } - }, + "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, "headers": "", "params": "", - "body": { - "type": "json", - "data": request_body_json - } - } + "body": {"type": "json", "data": request_body_json}, + }, } nodes.append(http_request_node) @@ -342,32 +308,24 @@ def _convert_to_http_request_node(self, app_model: App, "data": { "title": f"Parse {api_based_extension.name} Response", "type": NodeType.CODE.value, - "variables": [{ - "variable": "response_json", - "value_selector": [http_request_node['id'], "body"] - }], + "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], "code_language": "python3", "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" - "response_json)\n return {\n \"result\": response_body[\"result\"]\n }", - "outputs": { 
- "result": { - "type": "string" - } - } - } + 'response_json)\n return {\n "result": response_body["result"]\n }', + "outputs": {"result": {"type": "string"}}, + }, } nodes.append(code_node) - external_data_variable_node_mapping[external_data_variable.variable] = code_node['id'] + external_data_variable_node_mapping[external_data_variable.variable] = code_node["id"] index += 1 return nodes, external_data_variable_node_mapping - def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, - dataset_config: DatasetEntity, - model_config: ModelConfigEntity) \ - -> Optional[dict]: + def _convert_to_knowledge_retrieval_node( + self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity + ) -> Optional[dict]: """ Convert datasets to Knowledge Retrieval Node :param new_app_mode: new app mode @@ -401,7 +359,7 @@ def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, "completion_params": { **model_config.parameters, "stop": model_config.stop, - } + }, } } if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE @@ -409,20 +367,23 @@ def _convert_to_knowledge_retrieval_node(self, new_app_mode: AppMode, "multiple_retrieval_config": { "top_k": retrieve_config.top_k, "score_threshold": retrieve_config.score_threshold, - "reranking_model": retrieve_config.reranking_model + "reranking_model": retrieve_config.reranking_model, } if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE else None, - } + }, } - def _convert_to_llm_node(self, original_app_mode: AppMode, - new_app_mode: AppMode, - graph: dict, - model_config: ModelConfigEntity, - prompt_template: PromptTemplateEntity, - file_upload: Optional[FileExtraConfig] = None, - external_data_variable_node_mapping: dict[str, str] = None) -> dict: + def _convert_to_llm_node( + self, + original_app_mode: AppMode, + new_app_mode: AppMode, + graph: dict, + model_config: ModelConfigEntity, + prompt_template: PromptTemplateEntity, + file_upload: Optional[FileUploadConfig] = None, + external_data_variable_node_mapping: dict[str, str] | None = None, + ) -> dict: """ Convert to LLM Node :param original_app_mode: original app mode @@ -434,17 +395,18 @@ def _convert_to_llm_node(self, original_app_mode: AppMode, :param external_data_variable_node_mapping: external data variable node mapping """ # fetch start and knowledge retrieval node - start_node = next(filter(lambda n: n['data']['type'] == NodeType.START.value, graph['nodes'])) - knowledge_retrieval_node = next(filter( - lambda n: n['data']['type'] == NodeType.KNOWLEDGE_RETRIEVAL.value, - graph['nodes'] - ), None) + start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START.value, graph["nodes"])) + knowledge_retrieval_node = next( + filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL.value, graph["nodes"]), None + ) role_prefix = None # Chat Model if model_config.mode == LLMMode.CHAT.value: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") # get prompt template prompt_transform = SimplePromptTransform() prompt_template_config = prompt_transform.get_prompt_template( @@ -453,45 +415,35 @@ def _convert_to_llm_node(self, original_app_mode: AppMode, model=model_config.model, pre_prompt=prompt_template.simple_prompt_template, has_context=knowledge_retrieval_node is not None, - query_in_prompt=False + query_in_prompt=False, ) - template = 
prompt_template_config['prompt_template'].template + template = prompt_template_config["prompt_template"].template if not template: prompts = [] else: template = self._replace_template_variables( - template, - start_node['data']['variables'], - external_data_variable_node_mapping + template, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts = [ - { - "role": 'user', - "text": template - } - ] + prompts = [{"role": "user", "text": template}] else: advanced_chat_prompt_template = prompt_template.advanced_chat_prompt_template prompts = [] - for m in advanced_chat_prompt_template.messages: - if advanced_chat_prompt_template: + if advanced_chat_prompt_template: + for m in advanced_chat_prompt_template.messages: text = m.text text = self._replace_template_variables( - text, - start_node['data']['variables'], - external_data_variable_node_mapping + text, start_node["data"]["variables"], external_data_variable_node_mapping ) - prompts.append({ - "role": m.role.value, - "text": text - }) + prompts.append({"role": m.role.value, "text": text}) # Completion Model else: if prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: + if not prompt_template.simple_prompt_template: + raise ValueError("Simple prompt template is required") # get prompt template prompt_transform = SimplePromptTransform() prompt_template_config = prompt_transform.get_prompt_template( @@ -500,57 +452,50 @@ def _convert_to_llm_node(self, original_app_mode: AppMode, model=model_config.model, pre_prompt=prompt_template.simple_prompt_template, has_context=knowledge_retrieval_node is not None, - query_in_prompt=False + query_in_prompt=False, ) - template = prompt_template_config['prompt_template'].template + template = prompt_template_config["prompt_template"].template template = self._replace_template_variables( - template, - start_node['data']['variables'], - external_data_variable_node_mapping + template=template, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, ) - prompts = { - "text": template - } + prompts = {"text": template} - prompt_rules = prompt_template_config['prompt_rules'] + prompt_rules = prompt_template_config["prompt_rules"] role_prefix = { - "user": prompt_rules.get('human_prefix', 'Human'), - "assistant": prompt_rules.get('assistant_prefix', 'Assistant') + "user": prompt_rules.get("human_prefix", "Human"), + "assistant": prompt_rules.get("assistant_prefix", "Assistant"), } else: advanced_completion_prompt_template = prompt_template.advanced_completion_prompt_template if advanced_completion_prompt_template: text = advanced_completion_prompt_template.prompt text = self._replace_template_variables( - text, - start_node['data']['variables'], - external_data_variable_node_mapping + template=text, + variables=start_node["data"]["variables"], + external_data_variable_node_mapping=external_data_variable_node_mapping, ) else: text = "" - text = text.replace('{{#query#}}', '{{#sys.query#}}') + text = text.replace("{{#query#}}", "{{#sys.query#}}") prompts = { "text": text, } - if advanced_completion_prompt_template.role_prefix: + if advanced_completion_prompt_template and advanced_completion_prompt_template.role_prefix: role_prefix = { "user": advanced_completion_prompt_template.role_prefix.user, - "assistant": advanced_completion_prompt_template.role_prefix.assistant + "assistant": advanced_completion_prompt_template.role_prefix.assistant, } memory = None if new_app_mode == AppMode.ADVANCED_CHAT: - memory = { 
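
For completion-mode models, the hunks above derive a role prefix from the prompt rules (falling back to "Human"/"Assistant"), rewrite the legacy `{{#query#}}` marker to the workflow system variable, and, for advanced-chat apps, attach the memory block this hunk collapses onto one line. A minimal sketch of those three steps, with a hypothetical `prompt_rules` dict standing in for the one returned by `SimplePromptTransform.get_prompt_template`:

```python
# Hypothetical stand-in for prompt_template_config["prompt_rules"].
prompt_rules = {"human_prefix": "Human", "assistant_prefix": "Assistant"}

role_prefix = {
    "user": prompt_rules.get("human_prefix", "Human"),
    "assistant": prompt_rules.get("assistant_prefix", "Assistant"),
}

# Once the app runs as a workflow, completion prompts reference the query
# through the sys.query system variable.
text = "Answer as concisely as possible: {{#query#}}"
text = text.replace("{{#query#}}", "{{#sys.query#}}")

# Advanced-chat apps carry a memory block; its window starts disabled.
memory = {"role_prefix": role_prefix, "window": {"enabled": False}}

print(text)    # Answer as concisely as possible: {{#sys.query#}}
print(memory)  # {'role_prefix': {'user': 'Human', 'assistant': 'Assistant'}, 'window': {'enabled': False}}
```
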
- "role_prefix": role_prefix, - "window": { - "enabled": False - } - } + memory = {"role_prefix": role_prefix, "window": {"enabled": False}} completion_params = model_config.parameters completion_params.update({"stop": model_config.stop}) @@ -564,28 +509,29 @@ def _convert_to_llm_node(self, original_app_mode: AppMode, "provider": model_config.provider, "name": model_config.model, "mode": model_config.mode, - "completion_params": completion_params + "completion_params": completion_params, }, "prompt_template": prompts, "memory": memory, "context": { "enabled": knowledge_retrieval_node is not None, "variable_selector": ["knowledge_retrieval", "result"] - if knowledge_retrieval_node is not None else None + if knowledge_retrieval_node is not None + else None, }, "vision": { "enabled": file_upload is not None, "variable_selector": ["sys", "files"] if file_upload is not None else None, - "configs": { - "detail": file_upload.image_config['detail'] - } if file_upload is not None else None - } - } + "configs": {"detail": file_upload.image_config.detail} + if file_upload is not None and file_upload.image_config is not None + else None, + }, + }, } - def _replace_template_variables(self, template: str, - variables: list[dict], - external_data_variable_node_mapping: dict[str, str] = None) -> str: + def _replace_template_variables( + self, template: str, variables: list[dict], external_data_variable_node_mapping: dict[str, str] | None = None + ) -> str: """ Replace Template Variables :param template: template @@ -594,12 +540,11 @@ def _replace_template_variables(self, template: str, :return: """ for v in variables: - template = template.replace('{{' + v['variable'] + '}}', '{{#start.' + v['variable'] + '#}}') + template = template.replace("{{" + v["variable"] + "}}", "{{#start." 
+ v["variable"] + "#}}") if external_data_variable_node_mapping: for variable, code_node_id in external_data_variable_node_mapping.items(): - template = template.replace('{{' + variable + '}}', - '{{#' + code_node_id + '.result#}}') + template = template.replace("{{" + variable + "}}", "{{#" + code_node_id + ".result#}}") return template @@ -615,11 +560,8 @@ def _convert_to_end_node(self) -> dict: "data": { "title": "END", "type": NodeType.END.value, - "outputs": [{ - "variable": "result", - "value_selector": ["llm", "text"] - }] - } + "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], + }, } def _convert_to_answer_node(self) -> dict: @@ -631,11 +573,7 @@ def _convert_to_answer_node(self) -> dict: return { "id": "answer", "position": None, - "data": { - "title": "ANSWER", - "type": NodeType.ANSWER.value, - "answer": "{{#llm.text#}}" - } + "data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"}, } def _create_edge(self, source: str, target: str) -> dict: @@ -645,11 +583,7 @@ def _create_edge(self, source: str, target: str) -> dict: :param target: target node id :return: """ - return { - "id": f"{source}-{target}", - "source": source, - "target": target - } + return {"id": f"{source}-{target}", "source": source, "target": target} def _append_node(self, graph: dict, node: dict) -> dict: """ @@ -659,9 +593,9 @@ def _append_node(self, graph: dict, node: dict) -> dict: :param node: Node to append :return: """ - previous_node = graph['nodes'][-1] - graph['nodes'].append(node) - graph['edges'].append(self._create_edge(previous_node['id'], node['id'])) + previous_node = graph["nodes"][-1] + graph["nodes"].append(node) + graph["edges"].append(self._create_edge(previous_node["id"], node["id"])) return graph def _get_new_app_mode(self, app_model: App) -> AppMode: @@ -675,14 +609,20 @@ def _get_new_app_mode(self, app_model: App) -> AppMode: else: return AppMode.ADVANCED_CHAT - def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str): """ Get API Based Extension :param tenant_id: tenant id :param api_based_extension_id: api based extension id :return: """ - return db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = ( + db.session.query(APIBasedExtension) + .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) + .first() + ) + + if not api_based_extension: + raise ValueError(f"API Based Extension not found, id: {api_based_extension_id}") + + return api_based_extension diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index c4d3d2763189a4..f89487415deef0 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -4,13 +4,12 @@ from sqlalchemy import and_, or_ from extensions.ext_database import db -from models import CreatedByRole -from models.model import App, EndUser -from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus +from models import App, EndUser, WorkflowAppLog, WorkflowRun +from models.enums import CreatedByRole +from models.workflow import WorkflowRunStatus class WorkflowAppService: - def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: """ Get paginate workflow app logs @@ -18,20 +17,14 @@ def get_paginate_workflow_app_logs(self, 
app_model: App, args: dict) -> Paginati :param args: request args :return: """ - query = ( - db.select(WorkflowAppLog) - .where( - WorkflowAppLog.tenant_id == app_model.tenant_id, - WorkflowAppLog.app_id == app_model.id - ) + query = db.select(WorkflowAppLog).where( + WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id ) - status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None - keyword = args['keyword'] + status = WorkflowRunStatus.value_of(args.get("status", "")) if args.get("status") else None + keyword = args["keyword"] if keyword or status: - query = query.join( - WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id - ) + query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) if keyword: keyword_like_val = f"%{args['keyword'][:30]}%" @@ -39,7 +32,7 @@ def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Paginati WorkflowRun.inputs.ilike(keyword_like_val), WorkflowRun.outputs.ilike(keyword_like_val), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_like_val)) + and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), ] # filter keyword by workflow run id @@ -49,23 +42,16 @@ def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Paginati query = query.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value) + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), ).filter(or_(*keyword_conditions)) if status: # join with workflow_run and filter by status - query = query.filter( - WorkflowRun.status == status.value - ) + query = query.filter(WorkflowRun.status == status.value) query = query.order_by(WorkflowAppLog.created_at.desc()) - pagination = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return pagination diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index ccce38ada05fcd..d8ee323908a844 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,11 +1,11 @@ from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import WorkflowRunTriggeredFrom from models.model import App from models.workflow import ( WorkflowNodeExecution, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, - WorkflowRunTriggeredFrom, ) @@ -18,6 +18,7 @@ def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) - :param app_model: app model :param args: request args """ + class WorkflowWithMessage: message_id: str conversation_id: str @@ -33,9 +34,7 @@ def __getattr__(self, item): with_message_workflow_runs = [] for workflow_run in pagination.data: message = workflow_run.message - with_message_workflow_run = WorkflowWithMessage( - workflow_run=workflow_run - ) + with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run) if message: with_message_workflow_run.message_id = message.id with_message_workflow_run.conversation_id = message.conversation_id @@ -53,26 +52,30 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro :param app_model: app model :param args: request args """ - limit = 
int(args.get('limit', 20)) + limit = int(args.get("limit", 20)) base_query = db.session.query(WorkflowRun).filter( WorkflowRun.tenant_id == app_model.tenant_id, WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, ) - if args.get('last_id'): + if args.get("last_id"): last_workflow_run = base_query.filter( - WorkflowRun.id == args.get('last_id'), + WorkflowRun.id == args.get("last_id"), ).first() if not last_workflow_run: - raise ValueError('Last workflow run not exists') - - workflow_runs = base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, - WorkflowRun.id != last_workflow_run.id - ).order_by(WorkflowRun.created_at.desc()).limit(limit).all() + raise ValueError("Last workflow run not exists") + + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) else: workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() @@ -81,17 +84,13 @@ def get_paginate_workflow_runs(self, app_model: App, args: dict) -> InfiniteScro current_page_first_workflow_run = workflow_runs[-1] rest_count = base_query.filter( WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id + WorkflowRun.id != current_page_first_workflow_run.id, ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=workflow_runs, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: """ @@ -100,11 +99,15 @@ def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: :param app_model: app model :param run_id: workflow run id """ - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ).first() + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + WorkflowRun.tenant_id == app_model.tenant_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.id == run_id, + ) + .first() + ) return workflow_run @@ -117,12 +120,17 @@ def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[ if not workflow_run: return [] - node_executions = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ).order_by(WorkflowNodeExecution.index.desc()).all() + node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == app_model.tenant_id, + WorkflowNodeExecution.app_id == app_model.id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == run_id, + ) + .order_by(WorkflowNodeExecution.index.desc()) + .all() + ) return node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 
2defb4cd6a7088..7187d405178a4e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -6,17 +6,20 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.segments import Variable from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import NodeType +from core.variables import Variable +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.nodes import NodeType +from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.node_mapping import node_type_classes_mapping +from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account +from models.enums import CreatedByRole from models.model import App, AppMode from models.workflow import ( - CreatedByRole, Workflow, WorkflowNodeExecution, WorkflowNodeExecutionStatus, @@ -37,11 +40,13 @@ def get_draft_workflow(self, app_model: App) -> Optional[Workflow]: Get draft workflow """ # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" + ) + .first() + ) # return draft workflow return workflow @@ -55,11 +60,15 @@ def get_published_workflow(self, app_model: App) -> Optional[Workflow]: return None # fetch published workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id, + ) + .first() + ) return workflow @@ -85,10 +94,7 @@ def sync_draft_workflow( raise WorkflowHashNotEqualError() # validate features structure - self.validate_features_structure( - app_model=app_model, - features=features - ) + self.validate_features_structure(app_model=app_model, features=features) # create draft workflow if not found if not workflow: @@ -96,7 +102,7 @@ def sync_draft_workflow( tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(app_model.mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account.id, @@ -122,9 +128,7 @@ def sync_draft_workflow( # return draft workflow return workflow - def publish_workflow(self, app_model: App, - account: Account, - draft_workflow: Optional[Workflow] = None) -> Workflow: + def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: """ Publish workflow from draft @@ -137,7 +141,7 @@ def publish_workflow(self, app_model: App, draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('No valid workflow found.') + raise ValueError("No valid workflow found.") # create new 
workflow workflow = Workflow( @@ -171,8 +175,13 @@ def get_default_block_configs(self) -> list[dict]: Get default block configs """ # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_configs() + default_block_configs = [] + for node_type, node_class in node_type_classes_mapping.items(): + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(default_config) + + return default_block_configs def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: """ @@ -181,100 +190,93 @@ def get_default_block_config(self, node_type: str, filters: Optional[dict] = Non :param filters: filter by node config parameters. :return: """ - node_type = NodeType.value_of(node_type) + node_type_enum: NodeType = NodeType(node_type) # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_config(node_type, filters) + node_class = node_type_classes_mapping.get(node_type_enum) + if not node_class: + return None + + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None - def run_draft_workflow_node(self, app_model: App, - node_id: str, - user_inputs: dict, - account: Account) -> WorkflowNodeExecution: + return default_config + + def run_draft_workflow_node( + self, app_model: App, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: """ Run draft workflow node """ # fetch draft workflow by app_model draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") # run draft workflow node - workflow_engine_manager = WorkflowEngineManager() start_at = time.perf_counter() try: - node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node( + node_instance, generator = WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, user_id=account.id, ) - except WorkflowNodeRunFailedError as e: - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=e.node_id, - node_type=e.node_type.value, - title=e.node_title, - status=WorkflowNodeExecutionStatus.FAILED.value, - error=e.error, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) - db.session.add(workflow_node_execution) - db.session.commit() - return workflow_node_execution + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, RunCompletedEvent): + node_run_result = event.run_result + + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break - if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if not node_run_result: + raise ValueError("Node run failed with no run result") + + run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False + error = node_run_result.error if not run_succeeded else None + except WorkflowNodeRunFailedError as e: + node_instance = e.node_instance + 
run_succeeded = False + node_run_result = None + error = e.error + + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = app_model.tenant_id + workflow_node_execution.app_id = app_model.id + workflow_node_execution.workflow_id = draft_workflow.id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value + workflow_node_execution.index = 1 + workflow_node_execution.node_id = node_id + workflow_node_execution.node_type = node_instance.node_type + workflow_node_execution.title = node_instance.node_data.title + workflow_node_execution.elapsed_time = time.perf_counter() - start_at + workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value + workflow_node_execution.created_by = account.id + workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + + if run_succeeded and node_run_result: # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, - process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, - outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, - execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) - if node_run_result.metadata else None), - status=WorkflowNodeExecutionStatus.SUCCEEDED.value, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None + workflow_node_execution.process_data = ( + json.dumps(node_run_result.process_data) if node_run_result.process_data else None + ) + workflow_node_execution.outputs = ( + json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None ) + workflow_node_execution.execution_metadata = ( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None + ) + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value else: # create workflow node execution - workflow_node_execution = WorkflowNodeExecution( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - workflow_id=draft_workflow.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value, - index=1, - node_id=node_id, - node_type=node_instance.node_type.value, - title=node_instance.node_data.title, - status=node_run_result.status.value, - error=node_run_result.error, - elapsed_time=time.perf_counter() - start_at, - created_by_role=CreatedByRole.ACCOUNT.value, - created_by=account.id, - created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) - ) + workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.error = error db.session.add(workflow_node_execution) db.session.commit() @@ -294,16 +296,17 @@ def convert_to_workflow(self, 
app_model: App, account: Account, args: dict) -> A # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: - raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') + if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: + raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow new_app = workflow_converter.convert_to_workflow( app_model=app_model, account=account, - name=args.get('name'), - icon=args.get('icon'), - icon_background=args.get('icon_background'), + name=args.get("name"), + icon_type=args.get("icon_type"), + icon=args.get("icon"), + icon_background=args.get("icon_background"), ) return new_app @@ -311,37 +314,11 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A def validate_features_structure(self, app_model: App, features: dict) -> dict: if app_model.mode == AppMode.ADVANCED_CHAT.value: return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) elif app_model.mode == AppMode.WORKFLOW.value: return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) else: raise ValueError(f"Invalid app mode: {app_model.mode}") - - @classmethod - def get_elapsed_time(cls, workflow_run_id: str) -> float: - """ - Get elapsed time - """ - elapsed_time = 0.0 - - # fetch workflow node execution by workflow_run_id - workflow_nodes = ( - db.session.query(WorkflowNodeExecution) - .filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id) - .order_by(WorkflowNodeExecution.created_at.asc()) - .all() - ) - if not workflow_nodes: - return elapsed_time - - for node in workflow_nodes: - elapsed_time += node.elapsed_time - - return elapsed_time diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 2bcbe5c6f6318a..8fcb12b1cb9664 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,3 @@ - from flask_login import current_user from configs import dify_config @@ -14,34 +13,40 @@ def get_tenant_info(cls, tenant: Tenant): if not tenant: return None tenant_info = { - 'id': tenant.id, - 'name': tenant.name, - 'plan': tenant.plan, - 'status': tenant.status, - 'created_at': tenant.created_at, - 'in_trail': True, - 'trial_end_reason': None, - 'role': 'normal', + "id": tenant.id, + "name": tenant.name, + "plan": tenant.plan, + "status": tenant.status, + "created_at": tenant.created_at, + "in_trail": True, + "trial_end_reason": None, + "role": "normal", } # Get role of user - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == current_user.id - ).first() - tenant_info['role'] = tenant_account_join.role - - can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo - - if can_replace_logo and TenantService.has_roles(tenant, - [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) + .first() + ) + 
tenant_info["role"] = tenant_account_join.role + + can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo + + if can_replace_logo and TenantService.has_roles( + tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN] + ): base_url = dify_config.FILES_URL - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) - - tenant_info['custom_config'] = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': replace_webapp_logo, + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) + + tenant_info["custom_config"] = { + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } return tenant_info diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index e0a1b219095b92..b50876cc794c55 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -14,7 +14,7 @@ from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def add_document_to_index_task(dataset_document_id: str): """ Async Add document to index @@ -22,24 +22,25 @@ def add_document_to_index_task(dataset_document_id: str): Usage: add_document_to_index.delay(document_id) """ - logging.info(click.style('Start add document to index: {}'.format(dataset_document_id), fg='green')) + logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) start_at = time.perf_counter() dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first() if not dataset_document: - raise NotFound('Document not found') + raise NotFound("Document not found") - if dataset_document.indexing_status != 'completed': + if dataset_document.indexing_status != "completed": return - indexing_cache_key = 'document_{}_indexing'.format(dataset_document.id) + indexing_cache_key = "document_{}_indexing".format(dataset_document.id) try: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ) \ - .order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) documents = [] for segment in segments: @@ -50,7 +51,7 @@ def add_document_to_index_task(dataset_document_id: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) @@ -58,7 +59,7 @@ def add_document_to_index_task(dataset_document_id: str): dataset = dataset_document.dataset if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -66,12 +67,15 @@ def add_document_to_index_task(dataset_document_id: str): end_at = time.perf_counter() logging.info( - click.style('Document added to index: {} latency: {}'.format(dataset_document.id, end_at - start_at), 
fg='green')) + click.style( + "Document added to index: {} latency: {}".format(dataset_document.id, end_at - start_at), fg="green" + ) + ) except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False dataset_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - dataset_document.status = 'error' + dataset_document.status = "error" dataset_document.error = str(e) db.session.commit() finally: diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index b3aa8b596c8a3c..25c55bcfafe11c 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -10,9 +10,10 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def add_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): """ Add annotation to index. :param annotation_id: annotation id @@ -23,38 +24,34 @@ def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: s Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green')) + logging.info(click.style("Start build index for annotation: {}".format(annotation_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) document = Document( - page_content=question, - metadata={ - "annotation_id": annotation_id, - "app_id": app_id, - "doc_id": annotation_id - } + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.create([document], duplicate_check=True) end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), - fg='green')) + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Build index for annotation failed") diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 6e6b16045dad9b..fa7e5ac9190f3c 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -14,9 +14,8 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, - 
user_id: str): +@shared_task(queue="dataset") +def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, user_id: str): """ Add annotation to index. :param job_id: job_id @@ -26,72 +25,66 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: :param user_id: user_id """ - logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green')) + logging.info(click.style("Start batch import annotation: {}".format(job_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if app: try: documents = [] for content in content_list: annotation = MessageAnnotation( - app_id=app.id, - content=content['answer'], - question=content['question'], - account_id=user_id + app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id ) db.session.add(annotation) db.session.flush() document = Document( - page_content=content['question'], - metadata={ - "annotation_id": annotation.id, - "app_id": app_id, - "doc_id": annotation.id - } + page_content=content["question"], + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) # if annotation reply is enabled , batch add annotations' index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - app_annotation_setting.collection_binding_id, - 'annotation' + dataset_collection_binding = ( + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + app_annotation_setting.collection_binding_id, "annotation" + ) ) if not dataset_collection_binding: raise NotFound("App annotation setting not found") dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) vector.create(documents, duplicate_check=True) db.session.commit() - redis_client.setex(indexing_cache_key, 600, 'completed') + redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at), - fg='green')) + "Build index successful for batch import annotation: {} latency: {}".format( + job_id, end_at - start_at + ), + fg="green", + ) + ) except Exception as e: db.session.rollback() - redis_client.setex(indexing_cache_key, 600, 'error') - indexing_error_msg_key = 
'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + redis_client.setex(indexing_cache_key, 600, "error") + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) redis_client.setex(indexing_error_msg_key, 600, str(e)) logging.exception("Build index for batch import annotations failed") diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 81155a35e42b0e..5758db53de820b 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -9,36 +9,33 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, collection_binding_id: str): """ Async delete annotation index task """ - logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start delete app annotation index: {}".format(app_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', - collection_binding_id=dataset_collection_binding.id + indexing_technique="high_quality", + collection_binding_id=dataset_collection_binding.id, ) try: - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('annotation_id', annotation_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", annotation_id) except Exception: logging.exception("Delete annotation index failed when annotation deleted.") end_at = time.perf_counter() logging.info( - click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation deleted index failed:{}".format(str(e))) - diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index e5e03c9b51758b..0f83dfdbd4a72f 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -12,49 +12,44 @@ from models.model import App, AppAnnotationSetting, MessageAnnotation -@shared_task(queue='dataset') +@shared_task(queue="dataset") def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): """ Async enable annotation reply task """ - logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() annotations_count = 
db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count() if not app: raise NotFound("App not found") - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if not app_annotation_setting: raise NotFound("App annotation setting not found") - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) try: - dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', - collection_binding_id=app_annotation_setting.collection_binding_id + indexing_technique="high_quality", + collection_binding_id=app_annotation_setting.collection_binding_id, ) try: if annotations_count > 0: - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('app_id', app_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("app_id", app_id) except Exception: logging.exception("Delete annotation index failed when annotation deleted.") - redis_client.setex(disable_app_annotation_job_key, 600, 'completed') + redis_client.setex(disable_app_annotation_job_key, 600, "completed") # delete annotation setting db.session.delete(app_annotation_setting) @@ -62,12 +57,12 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): end_at = time.perf_counter() logging.info( - click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations index deleted : {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation batch deleted index failed:{}".format(str(e))) - redis_client.setex(disable_app_annotation_job_key, 600, 'error') - disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(disable_app_annotation_job_key, 600, "error") + disable_app_annotation_error_key = "disable_app_annotation_error_{}".format(str(job_id)) redis_client.setex(disable_app_annotation_error_key, 600, str(e)) finally: redis_client.delete(disable_app_annotation_key) diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index fda8b7a250f128..82b70f6b71eddf 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -15,37 +15,39 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float, - embedding_provider_name: str, embedding_model_name: str): +@shared_task(queue="dataset") +def enable_annotation_reply_task( + job_id: str, + app_id: str, + user_id: str, + tenant_id: str, + score_threshold: float, + embedding_provider_name: str, + embedding_model_name: str, +): """ Async enable annotation reply task """ - logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green')) + logging.info(click.style("Start add app 
annotation to index: {}".format(app_id), fg="green")) start_at = time.perf_counter() # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == tenant_id, - App.status == 'normal' - ).first() + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() if not app: raise NotFound("App not found") annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) try: documents = [] dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, - embedding_model_name, - 'annotation' + embedding_provider_name, embedding_model_name, "annotation" + ) + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() ) - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() if annotation_setting: annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id @@ -58,48 +60,42 @@ def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_ score_threshold=score_threshold, collection_binding_id=dataset_collection_binding.id, created_user_id=user_id, - updated_user_id=user_id + updated_user_id=user_id, ) db.session.add(new_app_annotation_setting) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=embedding_provider_name, embedding_model=embedding_model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) if annotations: for annotation in annotations: document = Document( page_content=annotation.question, - metadata={ - "annotation_id": annotation.id, - "app_id": app_id, - "doc_id": annotation.id - } + metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id}, ) documents.append(document) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) try: - vector.delete_by_metadata_field('app_id', app_id) + vector.delete_by_metadata_field("app_id", app_id) except Exception as e: - logging.info( - click.style('Delete annotation index error: {}'.format(str(e)), - fg='red')) + logging.info(click.style("Delete annotation index error: {}".format(str(e)), fg="red")) vector.create(documents) db.session.commit() - redis_client.setex(enable_app_annotation_job_key, 600, 'completed') + redis_client.setex(enable_app_annotation_job_key, 600, "completed") end_at = time.perf_counter() logging.info( - click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at), - fg='green')) + click.style("App annotations added to index: {} latency: {}".format(app_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Annotation batch created index failed:{}".format(str(e))) - redis_client.setex(enable_app_annotation_job_key, 600, 'error') - enable_app_annotation_error_key = 
'enable_app_annotation_error_{}'.format(str(job_id)) + redis_client.setex(enable_app_annotation_job_key, 600, "error") + enable_app_annotation_error_key = "enable_app_annotation_error_{}".format(str(job_id)) redis_client.setex(enable_app_annotation_error_key, 600, str(e)) db.session.rollback() finally: diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 7219abd3cdf6fc..b685d84d07ad28 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -10,9 +10,10 @@ from services.dataset_service import DatasetCollectionBindingService -@shared_task(queue='dataset') -def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, - collection_binding_id: str): +@shared_task(queue="dataset") +def update_annotation_to_index_task( + annotation_id: str, question: str, tenant_id: str, app_id: str, collection_binding_id: str +): """ Update annotation to index. :param annotation_id: annotation id @@ -23,39 +24,35 @@ def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green')) + logging.info(click.style("Start update index for annotation: {}".format(annotation_id), fg="green")) start_at = time.perf_counter() try: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, - 'annotation' + collection_binding_id, "annotation" ) dataset = Dataset( id=app_id, tenant_id=tenant_id, - indexing_technique='high_quality', + indexing_technique="high_quality", embedding_model_provider=dataset_collection_binding.provider_name, embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id + collection_binding_id=dataset_collection_binding.id, ) document = Document( - page_content=question, - metadata={ - "annotation_id": annotation_id, - "app_id": app_id, - "doc_id": annotation_id - } + page_content=question, metadata={"annotation_id": annotation_id, "app_id": app_id, "doc_id": annotation_id} ) - vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id']) - vector.delete_by_metadata_field('annotation_id', annotation_id) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + vector.delete_by_metadata_field("annotation_id", annotation_id) vector.add_texts([document]) end_at = time.perf_counter() logging.info( click.style( - 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), - fg='green')) + "Build index successful for annotation: {} latency: {}".format(annotation_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Build index for annotation failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 67cc03bdebee54..d1b41f26751519 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -16,9 +16,10 @@ from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') -def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: str, document_id: str, - tenant_id: str, user_id: str): +@shared_task(queue="dataset") +def 
batch_create_segment_to_index_task( + job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str +): """ Async batch create segment to index :param job_id: @@ -30,44 +31,44 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s Usage: batch_create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start batch create segment jobId: {}'.format(job_id), fg='green')) + logging.info(click.style("Start batch create segment jobId: {}".format(job_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset not exist.') + raise ValueError("Dataset does not exist.") dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - raise ValueError('Document not exist.') + raise ValueError("Document does not exist.") - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - raise ValueError('Document is not available.') + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + raise ValueError("Document is not available.") document_segments = [] embedding_model = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) - + word_count_change = 0 for segment in content: - content = segment['content'] + content = segment["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) if embedding_model else 0 - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == dataset_document.id - ).scalar() + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0 + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == dataset_document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=tenant_id, dataset_id=dataset_id, @@ -80,20 +81,27 @@ def batch_create_segment_to_index_task(job_id: str, content: list, dataset_id: s tokens=tokens, created_by=user_id, indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - status='completed', - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + status="completed", + completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) - if dataset_document.doc_form == 'qa_model': - segment_document.answer = segment['answer'] + if dataset_document.doc_form == "qa_model": + segment_document.answer = segment["answer"] + segment_document.word_count += len(segment["answer"]) + word_count_change += segment_document.word_count db.session.add(segment_document) document_segments.append(segment_document) + # update document word count + dataset_document.word_count += word_count_change + db.session.add(dataset_document) # add index to db
indexing_runner = IndexingRunner() indexing_runner.batch_add_segments(document_segments, dataset) db.session.commit() - redis_client.setex(indexing_cache_key, 600, 'completed') + redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() - logging.info(click.style('Segment batch created job: {} latency: {}'.format(job_id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("Segments batch created index failed:{}".format(str(e))) - redis_client.setex(indexing_cache_key, 600, 'error') + redis_client.setex(indexing_cache_key, 600, "error") diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 1f26c966c46af5..36249038011747 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -19,9 +19,15 @@ # Add import statement for ValueError -@shared_task(queue='dataset') -def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, - index_struct: str, collection_binding_id: str, doc_form: str): +@shared_task(queue="dataset") +def clean_dataset_task( + dataset_id: str, + tenant_id: str, + indexing_technique: str, + index_struct: str, + collection_binding_id: str, + doc_form: str, +): """ Clean dataset when dataset deleted. :param dataset_id: dataset id @@ -33,7 +39,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) """ - logging.info(click.style('Start clean dataset when dataset deleted: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Start clean dataset when dataset deleted: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() try: @@ -48,9 +54,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() if documents is None or len(documents) == 0: - logging.info(click.style('No documents found for dataset: {}'.format(dataset_id), fg='green')) + logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green")) else: - logging.info(click.style('Cleaning documents for dataset: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Cleaning documents for dataset: {}".format(dataset_id), fg="green")) # Specify the index type before initializing the index processor if doc_form is None: raise ValueError("Index type must be specified.") @@ -71,15 +77,16 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, if documents: for document in documents: try: - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == document.tenant_id, - UploadFile.id == file_id - ).first() + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .first() + ) if not file: continue storage.delete(file.key) @@ -90,6 +97,9 @@ def clean_dataset_task(dataset_id: str, 
tenant_id: str, indexing_technique: str, db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Cleaned dataset when dataset deleted: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + click.style( + "Cleaned dataset when dataset deleted: {} latency: {}".format(dataset_id, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("Cleaned dataset when dataset deleted failed") diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 0fd05615b65430..ae2855aa2ebc4d 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -12,7 +12,7 @@ from models.model import UploadFile -@shared_task(queue='dataset') +@shared_task(queue="dataset") def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]): """ Clean document when document deleted. @@ -23,14 +23,14 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i Usage: clean_document_task.delay(document_id, dataset_id) """ - logging.info(click.style('Start clean document when document deleted: {}'.format(document_id), fg='green')) + logging.info(click.style("Start clean document when document deleted: {}".format(document_id), fg="green")) start_at = time.perf_counter() try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() # check segment is exist @@ -44,9 +44,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i db.session.commit() if file_id: - file = db.session.query(UploadFile).filter( - UploadFile.id == file_id - ).first() + file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if file: try: storage.delete(file.key) @@ -57,6 +55,10 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document deleted: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document deleted: {} latency: {}".format(document_id, end_at - start_at), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document deleted failed") diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 9b697b63511a12..75d9e031306381 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -9,7 +9,7 @@ from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def clean_notion_document_task(document_ids: list[str], dataset_id: str): """ Clean document when document deleted. 
@@ -18,20 +18,20 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): Usage: clean_notion_document_task.delay(document_ids, dataset_id) """ - logging.info(click.style('Start clean document when import form notion document deleted: {}'.format(dataset_id), fg='green')) + logging.info( + click.style("Start clean document when import from notion document deleted: {}".format(dataset_id), fg="green") + ) start_at = time.perf_counter() try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() db.session.delete(document) segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() @@ -44,8 +44,12 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Clean document when import form notion document deleted end :: {} latency: {}'.format( - dataset_id, end_at - start_at), - fg='green')) + click.style( + "Clean document when import from notion document deleted end: {} latency: {}".format( + dataset_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when import form notion document deleted failed") diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index d31286e4cc4c42..26375743b68e25 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -14,7 +14,7 @@ from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None): """ Async create segment to index @@ -22,23 +22,23 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] :param keywords: Usage: create_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start create segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'waiting': + if segment.status != "waiting": return - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: # update segment status to indexing update_params = { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -49,23 +49,23 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) dataset = 
segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_type = dataset.doc_form @@ -75,18 +75,20 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment to completed update_params = { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() end_at = time.perf_counter() - logging.info(click.style('Segment created to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment created to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("create segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() finally: diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index ce93e111e54aa4..cfc54920e23caa 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -11,7 +11,7 @@ from models.dataset import Document as DatasetDocument -@shared_task(queue='dataset') +@shared_task(queue="dataset") def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index @@ -19,41 +19,46 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): :param action: action Usage: deal_dataset_vector_index_task.delay(dataset_id, action) """ - logging.info(click.style('Start deal dataset vector index: {}'.format(dataset_id), fg='green')) + logging.info(click.style("Start deal dataset vector index: {}".format(dataset_id), fg="green")) start_at = time.perf_counter() try: - dataset = Dataset.query.filter_by( - id=dataset_id - ).first() + dataset = Dataset.query.filter_by(id=dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) elif action == "add": - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, 
- ).all() + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)) \ - .update({"indexing_status": "indexing"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) db.session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.enabled == True - ) .order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) if segments: documents = [] for segment in segments: @@ -64,32 +69,39 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "completed"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "error", "error": str(e)}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) db.session.commit() - elif action == 'update': - dataset_documents = db.session.query(DatasetDocument).filter( - DatasetDocument.dataset_id == dataset_id, - DatasetDocument.indexing_status == 'completed', - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).all() + elif action == "update": + dataset_documents = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.dataset_id == dataset_id, + DatasetDocument.indexing_status == "completed", + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .all() + ) # add new index if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)) \ - .update({"indexing_status": "indexing"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + {"indexing_status": "indexing"}, synchronize_session=False + ) db.session.commit() # clean index @@ -98,10 +110,12 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = db.session.query(DocumentSegment).filter( - DocumentSegment.document_id == 
dataset_document.id, - DocumentSegment.enabled == True - ).order_by(DocumentSegment.position.asc()).all() + segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .order_by(DocumentSegment.position.asc()) + .all() + ) if segments: documents = [] for segment in segments: @@ -112,23 +126,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "completed"}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "completed"}, synchronize_session=False + ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id) \ - .update({"indexing_status": "error", "error": str(e)}, synchronize_session=False) + db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + {"indexing_status": "error", "error": str(e)}, synchronize_session=False + ) db.session.commit() - end_at = time.perf_counter() logging.info( - click.style('Deal dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green") + ) except Exception: logging.exception("Deal dataset vector index failed") diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index d79286cf3d8a3c..c3e0ea5d9fbb77 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -10,7 +10,7 @@ from models.dataset import Dataset, Document -@shared_task(queue='dataset') +@shared_task(queue="dataset") def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): """ Async Remove segment from index @@ -21,22 +21,22 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ Usage: delete_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start delete segment from index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment_id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment_id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) return dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment_id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} 
document status is invalid, pass.'.format(segment_id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) return index_type = dataset_document.doc_form @@ -44,7 +44,9 @@ def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_ index_processor.clean(dataset, [index_node_id]) end_at = time.perf_counter() - logging.info(click.style('Segment deleted from index: {} latency: {}'.format(segment_id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") + ) except Exception: logging.exception("delete segment from index failed") finally: diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index 4788bf4e4b5960..15e1e50076e8c9 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -11,7 +11,7 @@ from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def disable_segment_from_index_task(segment_id: str): """ Async disable segment from index @@ -19,33 +19,33 @@ def disable_segment_from_index_task(segment_id: str): Usage: disable_segment_from_index_task.delay(segment_id) """ - logging.info(click.style('Start disable segment from index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'completed': - raise NotFound('Segment is not completed , disable action is not allowed.') + if segment.status != "completed": + raise NotFound("Segment is not completed, disable action is not allowed.") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_type = dataset_document.doc_form @@ -53,7 +53,9 @@ def disable_segment_from_index_task(segment_id: str): index_processor.clean(dataset, [segment.index_node_id]) end_at = time.perf_counter() - logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment removed from index: {} latency: 
{}".format(segment.id, end_at - start_at), fg="green") + ) except Exception: logging.exception("remove segment from index failed") segment.enabled = True diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 4cced36ecdd856..6dd755ab032201 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -6,7 +6,7 @@ from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db @@ -14,7 +14,7 @@ from models.source import DataSourceOauthBinding -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_sync_task(dataset_id: str, document_id: str): """ Async update document @@ -23,50 +23,50 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): Usage: document_indexing_sync_task.delay(dataset_id, document_id) """ - logging.info(click.style('Start sync document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start sync document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") data_source_info = document.data_source_info_dict - if document.data_source_type == 'notion_import': - if not data_source_info or 'notion_page_id' not in data_source_info \ - or 'notion_workspace_id' not in data_source_info: + if document.data_source_type == "notion_import": + if ( + not data_source_info + or "notion_page_id" not in data_source_info + or "notion_workspace_id" not in data_source_info + ): raise ValueError("no notion page found") - workspace_id = data_source_info['notion_workspace_id'] - page_id = data_source_info['notion_page_id'] - page_type = data_source_info['type'] - page_edited_time = data_source_info['last_edited_time'] + workspace_id = data_source_info["notion_workspace_id"] + page_id = data_source_info["notion_page_id"] + page_type = data_source_info["type"] + page_edited_time = data_source_info["last_edited_time"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == document.tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') + raise ValueError("Data source binding not found.") loader = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=document.tenant_id + tenant_id=document.tenant_id, ) last_edited_time = loader.get_notion_last_edited_time() # check the page is updated if last_edited_time != page_edited_time: - document.indexing_status = 'parsing' + 
document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -74,7 +74,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -89,7 +89,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") @@ -97,8 +103,10 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) - except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info( + click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index cc93a1341e180d..72c4674e0fccbb 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -6,13 +6,13 @@ from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Dataset, Document from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_task(dataset_id: str, document_ids: list): """ Async process document @@ -36,16 +36,17 @@ def document_indexing_task(dataset_id: str, document_ids: list): if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads has exceeded the limit of " + "your subscription." 
+ ) except Exception as e: for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) @@ -53,15 +54,14 @@ def document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) @@ -71,8 +71,8 @@ def document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) - except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index f129d93de8da5a..cb38bc668d059b 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -6,13 +6,13 @@ from celery import shared_task from werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def document_indexing_update_task(dataset_id: str, document_id: str): """ Async update document @@ -21,18 +21,15 @@ def document_indexing_update_task(dataset_id: str, document_id: str): Usage: document_indexing_update_task.delay(dataset_id, document_id) """ - logging.info(click.style('Start update document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start update document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - document.indexing_status = 'parsing' + document.indexing_status = 
"parsing" document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() @@ -40,7 +37,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - raise Exception('Dataset not found') + raise Exception("Dataset not found") index_type = document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() @@ -57,7 +54,13 @@ def document_indexing_update_task(dataset_id: str, document_id: str): db.session.commit() end_at = time.perf_counter() logging.info( - click.style('Cleaned document when document update data source or process rule: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + click.style( + "Cleaned document when document update data source or process rule: {} latency: {}".format( + document_id, end_at - start_at + ), + fg="green", + ) + ) except Exception: logging.exception("Cleaned document when document update data source or process rule failed") @@ -65,8 +68,8 @@ def document_indexing_update_task(dataset_id: str, document_id: str): indexing_runner = IndexingRunner() indexing_runner.run([document]) end_at = time.perf_counter() - logging.info(click.style('update document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) - except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 884e222d1b3eef..f4c3dbd2e2860c 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -6,14 +6,14 @@ from celery import shared_task from configs import dify_config -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def duplicate_document_indexing_task(dataset_id: str, document_ids: list): """ Async process document @@ -37,16 +37,17 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) except Exception as e: for document_id in document_ids: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -54,12 +55,11 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): return for document_id in document_ids: - logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + logging.info(click.style("Start process document: {}".format(document_id), fg="green")) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: # clean old data @@ -77,7 +77,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() documents.append(document) db.session.add(document) @@ -87,8 +87,8 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): indexing_runner = IndexingRunner() indexing_runner.run(documents) end_at = time.perf_counter() - logging.info(click.style('Processed dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) - except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index e37c06855d47de..1412ad9ec74c06 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -13,7 +13,7 @@ from models.dataset import DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def enable_segment_to_index_task(segment_id: str): """ Async enable segment to index @@ -21,17 +21,17 @@ def enable_segment_to_index_task(segment_id: str): Usage: enable_segment_to_index_task.delay(segment_id) """ - logging.info(click.style('Start enable segment to index: {}'.format(segment_id), fg='green')) + logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green")) start_at = time.perf_counter() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() if not segment: - raise NotFound('Segment not found') + raise NotFound("Segment not found") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable action is not allowed.') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable action is not allowed.") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) try: document = Document( @@ -41,23 +41,23 @@ def enable_segment_to_index_task(segment_id: str): "doc_hash": segment.index_node_hash, "document_id": segment.document_id, 
"dataset_id": segment.dataset_id, - } + }, ) dataset = segment.dataset if not dataset: - logging.info(click.style('Segment {} has no dataset, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no dataset, pass.".format(segment.id), fg="cyan")) return dataset_document = segment.document if not dataset_document: - logging.info(click.style('Segment {} has no document, pass.'.format(segment.id), fg='cyan')) + logging.info(click.style("Segment {} has no document, pass.".format(segment.id), fg="cyan")) return - if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed': - logging.info(click.style('Segment {} document status is invalid, pass.'.format(segment.id), fg='cyan')) + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Segment {} document status is invalid, pass.".format(segment.id), fg="cyan")) return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() @@ -65,12 +65,14 @@ def enable_segment_to_index_task(segment_id: str): index_processor.load(dataset, [document]) end_at = time.perf_counter() - logging.info(click.style('Segment enabled to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + logging.info( + click.style("Segment enabled to index: {} latency: {}".format(segment.id, end_at - start_at), fg="green") + ) except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() finally: diff --git a/api/tasks/external_document_indexing_task.py b/api/tasks/external_document_indexing_task.py new file mode 100644 index 00000000000000..6fc719ae8d085a --- /dev/null +++ b/api/tasks/external_document_indexing_task.py @@ -0,0 +1,93 @@ +import json +import logging +import time + +import click +from celery import shared_task + +from core.indexing_runner import DocumentIsPausedException +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.dataset import Dataset, ExternalKnowledgeApis +from models.model import UploadFile +from services.external_knowledge_service import ExternalDatasetService + + +@shared_task(queue="dataset") +def external_document_indexing_task( + dataset_id: str, external_knowledge_api_id: str, data_source: dict, process_parameter: dict +): + """ + Async process document + :param dataset_id: + :param external_knowledge_api_id: + :param data_source: + :param process_parameter: + Usage: external_document_indexing_task.delay(dataset_id, document_id) + """ + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info( + click.style("Processed external dataset: {} failed, dataset not exit.".format(dataset_id), fg="red") + ) + return + + # get external api template + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis) + .filter( + ExternalKnowledgeApis.id == external_knowledge_api_id, ExternalKnowledgeApis.tenant_id == dataset.tenant_id + ) + .first() + ) + + if not external_knowledge_api: + logging.info( + click.style( + "Processed external dataset: {} failed, api template: {} not exit.".format( + dataset_id, external_knowledge_api_id + ), + fg="red", + ) + ) + return + files = {} + if 
data_source["type"] == "upload_file": + upload_file_list = data_source["info_list"]["file_info_list"]["file_ids"] + for file_id in upload_file_list: + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) + if file: + files[file.id] = (file.name, storage.load_once(file.key), file.mime_type) + try: + settings = ExternalDatasetService.get_external_knowledge_api_settings( + json.loads(external_knowledge_api.settings) + ) + # assemble headers + headers = ExternalDatasetService.assembling_headers(settings.authorization, settings.headers) + + # do http request + response = ExternalDatasetService.process_external_api(settings, headers, process_parameter, files) + job_id = response.json().get("job_id") + if job_id: + # save job_id to dataset + dataset.job_id = job_id + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + "Processed external dataset: {} successful, latency: {}".format(dataset.id, end_at - start_at), + fg="green", + ) + ) + except DocumentIsPausedException as ex: + logging.info(click.style(str(ex), fg="yellow")) + + except Exception: + pass diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py new file mode 100644 index 00000000000000..d78fc2b8915520 --- /dev/null +++ b/api/tasks/mail_email_code_login.py @@ -0,0 +1,41 @@ +import logging +import time + +import click +from celery import shared_task +from flask import render_template + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_email_code_login_mail_task(language: str, to: str, code: str): + """ + Async Send email code login mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Email code to be included in the email + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start email code login mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + # send email code login mail using different languages + try: + if language == "zh-Hans": + html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code) + mail.send(to=to, subject="邮箱验证码", html=html_content) + else: + html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Email Code", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) + except Exception: + logging.exception("Send email code login mail to {} failed".format(to)) diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index a46eafa797907c..c7dfb9bf6063ff 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -9,7 +9,7 @@ from extensions.ext_mail import mail -@shared_task(queue='mail') +@shared_task(queue="mail") def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str): """ Async Send invite member mail @@ -19,36 +19,43 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam :param inviter_name :param workspace_name - Usage: send_invite_member_mail_task.delay(langauge, to, token, inviter_name, workspace_name) + Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name) """ if not mail.is_inited(): return - 
logging.info(click.style('Start send invite member mail to {} in workspace {}'.format(to, workspace_name), - fg='green')) + logging.info( + click.style("Start send invite member mail to {} in workspace {}".format(to, workspace_name), fg="green") + ) start_at = time.perf_counter() # send invite member mail using different languages try: - url = f'{dify_config.CONSOLE_WEB_URL}/activate?token={token}' - if language == 'zh-Hans': - html_content = render_template('invite_member_mail_template_zh-CN.html', - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url) + url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}" + if language == "zh-Hans": + html_content = render_template( + "invite_member_mail_template_zh-CN.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) else: - html_content = render_template('invite_member_mail_template_en-US.html', - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url) + html_content = render_template( + "invite_member_mail_template_en-US.html", + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + ) mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) end_at = time.perf_counter() logging.info( - click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at), - fg='green')) + click.style( + "Send invite member mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) except Exception: - logging.exception("Send invite member mail to {} failed".format(to)) \ No newline at end of file + logging.exception("Send invite member mail to {} failed".format(to)) diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index 4e1b8a89135e7f..8596ca07cfcee3 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -5,41 +5,37 @@ from celery import shared_task from flask import render_template -from configs import dify_config from extensions.ext_mail import mail -@shared_task(queue='mail') -def send_reset_password_mail_task(language: str, to: str, token: str): +@shared_task(queue="mail") +def send_reset_password_mail_task(language: str, to: str, code: str): """ Async Send reset password mail :param language: Language in which the email should be sent (e.g., 'en', 'zh') :param to: Recipient email address - :param token: Reset password token to be included in the email + :param code: Reset password code """ if not mail.is_inited(): return - logging.info(click.style('Start password reset mail to {}'.format(to), fg='green')) + logging.info(click.style("Start password reset mail to {}".format(to), fg="green")) start_at = time.perf_counter() # send reset password mail using different languages try: - url = f'{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}' - if language == 'zh-Hans': - html_content = render_template('reset_password_mail_template_zh-CN.html', - to=to, - url=url) - mail.send(to=to, subject="重置您的 Dify 密码", html=html_content) + if language == "zh-Hans": + html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, code=code) + mail.send(to=to, subject="设置您的 Dify 密码", html=html_content) else: - html_content = render_template('reset_password_mail_template_en-US.html', - to=to, - url=url) - mail.send(to=to, subject="Reset Your Dify Password", html=html_content) + html_content = 
render_template("reset_password_mail_template_en-US.html", to=to, code=code) + mail.send(to=to, subject="Set Your Dify Password", html=html_content) end_at = time.perf_counter() logging.info( - click.style('Send password reset mail to {} succeeded: latency: {}'.format(to, end_at - start_at), - fg='green')) + click.style( + "Send password reset mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("Send password reset mail to {} failed".format(to)) diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 6b4cab55b399ae..34c62dc9237fc0 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -1,17 +1,20 @@ +import json import logging -import time from celery import shared_task from flask import current_app +from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY from core.ops.entities.trace_entity import trace_info_info_map from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage from models.model import Message from models.workflow import WorkflowRun -@shared_task(queue='ops_trace') -def process_trace_tasks(tasks_data): +@shared_task(queue="ops_trace") +def process_trace_tasks(file_info): """ Async process trace tasks :param tasks_data: List of dictionaries containing task data @@ -20,17 +23,20 @@ def process_trace_tasks(tasks_data): """ from core.ops.ops_trace_manager import OpsTraceManager - trace_info = tasks_data.get('trace_info') - app_id = tasks_data.get('app_id') - trace_info_type = tasks_data.get('trace_info_type') + app_id = file_info.get("app_id") + file_id = file_info.get("file_id") + file_path = f"{OPS_FILE_PATH}{app_id}/{file_id}.json" + file_data = json.loads(storage.load(file_path)) + trace_info = file_data.get("trace_info") + trace_info_type = file_data.get("trace_info_type") trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) - if trace_info.get('message_data'): - trace_info['message_data'] = Message.from_dict(data=trace_info['message_data']) - if trace_info.get('workflow_data'): - trace_info['workflow_data'] = WorkflowRun.from_dict(data=trace_info['workflow_data']) - if trace_info.get('documents'): - trace_info['documents'] = [Document(**doc) for doc in trace_info['documents']] + if trace_info.get("message_data"): + trace_info["message_data"] = Message.from_dict(data=trace_info["message_data"]) + if trace_info.get("workflow_data"): + trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"]) + if trace_info.get("documents"): + trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]] try: if trace_instance: @@ -39,6 +45,10 @@ def process_trace_tasks(tasks_data): if trace_type: trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) - end_at = time.perf_counter() + logging.info(f"Processing trace tasks success, app_id: {app_id}") except Exception: - logging.exception("Processing trace tasks failed") + failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" + redis_client.incr(failed_key) + logging.info(f"Processing trace tasks failed, app_id: {app_id}") + finally: + storage.delete(file_path) diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 02278f512b74b6..934eb7430c90c3 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -5,12 +5,12 @@ from celery import shared_task from 
werkzeug.exceptions import NotFound -from core.indexing_runner import DocumentIsPausedException, IndexingRunner +from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db from models.dataset import Document -@shared_task(queue='dataset') +@shared_task(queue="dataset") def recover_document_indexing_task(dataset_id: str, document_id: str): """ Async recover document @@ -19,28 +19,27 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): Usage: recover_document_indexing_task.delay(dataset_id, document_id) """ - logging.info(click.style('Recover document: {}'.format(document_id), fg='green')) + logging.info(click.style("Recover document: {}".format(document_id), fg="green")) start_at = time.perf_counter() - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") try: indexing_runner = IndexingRunner() - if document.indexing_status in ["waiting", "parsing", "cleaning"]: + if document.indexing_status in {"waiting", "parsing", "cleaning"}: indexing_runner.run([document]) elif document.indexing_status == "splitting": indexing_runner.run_in_splitting_status(document) elif document.indexing_status == "indexing": indexing_runner.run_in_indexing_status(document) end_at = time.perf_counter() - logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) - except DocumentIsPausedException as ex: - logging.info(click.style(str(ex), fg='yellow')) + logging.info( + click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") + ) + except DocumentIsPausedError as ex: + logging.info(click.style(str(ex), fg="yellow")) except Exception: pass diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4efe7ee38c0b32..66f78636ecca60 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -33,9 +33,9 @@ from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun -@shared_task(queue='app_deletion', bind=True, max_retries=3) +@shared_task(queue="app_deletion", bind=True, max_retries=3) def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): - logging.info(click.style(f'Start deleting app and related data: {tenant_id}:{app_id}', fg='green')) + logging.info(click.style(f"Start deleting app and related data: {tenant_id}:{app_id}", fg="green")) start_at = time.perf_counter() try: # Delete related data @@ -59,13 +59,14 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str): _delete_conversation_variables(app_id=app_id) end_at = time.perf_counter() - logging.info(click.style(f'App and related data deleted: {app_id} latency: {end_at - start_at}', fg='green')) + logging.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green")) except SQLAlchemyError as e: logging.exception( - click.style(f"Database error occurred while deleting app {app_id} and related data", fg='red')) + click.style(f"Database error occurred while deleting app {app_id} and related data", fg="red") + ) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds except Exception 
as e: - logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg='red')) + logging.exception(click.style(f"Error occurred while deleting app {app_id} and related data", fg="red")) raise self.retry(exc=e, countdown=60) # Retry after 60 seconds @@ -77,7 +78,7 @@ def del_model_config(model_config_id: str): """select id from app_model_configs where app_id=:app_id limit 1000""", {"app_id": app_id}, del_model_config, - "app model config" + "app model config", ) @@ -85,12 +86,7 @@ def _delete_app_site(tenant_id: str, app_id: str): def del_site(site_id: str): db.session.query(Site).filter(Site.id == site_id).delete(synchronize_session=False) - _delete_records( - """select id from sites where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_site, - "site" - ) + _delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site") def _delete_app_api_tokens(tenant_id: str, app_id: str): @@ -98,10 +94,7 @@ def del_api_token(api_token_id: str): db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False) _delete_records( - """select id from api_tokens where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_api_token, - "api token" + """select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token" ) @@ -113,44 +106,47 @@ def del_installed_app(installed_app_id: str): """select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_installed_app, - "installed app" + "installed app", ) def _delete_recommended_apps(tenant_id: str, app_id: str): def del_recommended_app(recommended_app_id: str): db.session.query(RecommendedApp).filter(RecommendedApp.id == recommended_app_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from recommended_apps where app_id=:app_id limit 1000""", {"app_id": app_id}, del_recommended_app, - "recommended app" + "recommended app", ) def _delete_app_annotation_data(tenant_id: str, app_id: str): def del_annotation_hit_history(annotation_hit_history_id: str): db.session.query(AppAnnotationHitHistory).filter( - AppAnnotationHitHistory.id == annotation_hit_history_id).delete(synchronize_session=False) + AppAnnotationHitHistory.id == annotation_hit_history_id + ).delete(synchronize_session=False) _delete_records( """select id from app_annotation_hit_histories where app_id=:app_id limit 1000""", {"app_id": app_id}, del_annotation_hit_history, - "annotation hit history" + "annotation hit history", ) def del_annotation_setting(annotation_setting_id: str): db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.id == annotation_setting_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from app_annotation_settings where app_id=:app_id limit 1000""", {"app_id": app_id}, del_annotation_setting, - "annotation setting" + "annotation setting", ) @@ -162,7 +158,7 @@ def del_dataset_join(dataset_join_id: str): """select id from app_dataset_joins where app_id=:app_id limit 1000""", {"app_id": app_id}, del_dataset_join, - "dataset join" + "dataset join", ) @@ -174,7 +170,7 @@ def del_workflow(workflow_id: str): """select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow, - "workflow" + "workflow", ) @@ -186,89 +182,93 @@ def 
del_workflow_run(workflow_run_id: str): """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_run, - "workflow run" + "workflow run", ) def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id == workflow_node_execution_id).delete(synchronize_session=False) + db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete( + synchronize_session=False + ) _delete_records( """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_node_execution, - "workflow node execution" + "workflow node execution", ) def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): - db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False) + db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete( + synchronize_session=False + ) _delete_records( """select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_workflow_app_log, - "workflow app log" + "workflow app log", ) def _delete_app_conversations(tenant_id: str, app_id: str): def del_conversation(conversation_id: str): db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False) _delete_records( """select id from conversations where app_id=:app_id limit 1000""", {"app_id": app_id}, del_conversation, - "conversation" + "conversation", ) + def _delete_conversation_variables(*, app_id: str): stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id) with db.engine.connect() as conn: conn.execute(stmt) conn.commit() - logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg='green')) + logging.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green")) def _delete_app_messages(tenant_id: str, app_id: str): def del_message(message_id: str): db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message_id).delete( - synchronize_session=False) - db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) + db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete(synchronize_session=False) db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message_id).delete( - synchronize_session=False) + synchronize_session=False + ) db.session.query(MessageFile).filter(MessageFile.message_id == message_id).delete(synchronize_session=False) - db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete( - synchronize_session=False) + db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete(synchronize_session=False) 
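# Note (reading of the code above): del_message removes a message's children
# first (feedbacks, annotations, chains, agent thoughts, files, saved
# messages) and only then, on the next line, the Message row itself; each
# message is one _delete_records callback, committed record by record.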
db.session.query(Message).filter(Message.id == message_id).delete() _delete_records( - """select id from messages where app_id=:app_id limit 1000""", - {"app_id": app_id}, - del_message, - "message" + """select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message" ) def _delete_workflow_tool_providers(tenant_id: str, app_id: str): def del_tool_provider(tool_provider_id: str): db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.id == tool_provider_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from tool_workflow_providers where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_tool_provider, - "tool workflow provider" + "tool workflow provider", ) @@ -280,7 +280,7 @@ def del_tag_binding(tag_binding_id: str): """select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_tag_binding, - "tag binding" + "tag binding", ) @@ -292,20 +292,21 @@ def del_end_user(end_user_id: str): """select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""", {"tenant_id": tenant_id, "app_id": app_id}, del_end_user, - "end user" + "end user", ) def _delete_trace_app_configs(tenant_id: str, app_id: str): def del_trace_app_config(trace_app_config_id: str): db.session.query(TraceAppConfig).filter(TraceAppConfig.id == trace_app_config_id).delete( - synchronize_session=False) + synchronize_session=False + ) _delete_records( """select id from trace_app_config where app_id=:app_id limit 1000""", {"app_id": app_id}, del_trace_app_config, - "trace app config" + "trace app config", ) @@ -321,7 +322,7 @@ def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: s try: delete_func(record_id) db.session.commit() - logging.info(click.style(f"Deleted {name} {record_id}", fg='green')) + logging.info(click.style(f"Deleted {name} {record_id}", fg="green")) except Exception: logging.exception(f"Error occurred while deleting {name} {record_id}") continue diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index cff8dddc53c630..1909eaf3418517 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -11,7 +11,7 @@ from models.dataset import Document, DocumentSegment -@shared_task(queue='dataset') +@shared_task(queue="dataset") def remove_document_from_index_task(document_id: str): """ Async Remove document from index @@ -19,23 +19,23 @@ def remove_document_from_index_task(document_id: str): Usage: remove_document_from_index.delay(document_id) """ - logging.info(click.style('Start remove document segments from index: {}'.format(document_id), fg='green')) + logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green")) start_at = time.perf_counter() document = db.session.query(Document).filter(Document.id == document_id).first() if not document: - raise NotFound('Document not found') + raise NotFound("Document not found") - if document.indexing_status != 'completed': + if document.indexing_status != "completed": return - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = "document_{}_indexing".format(document.id) try: dataset = document.dataset if not dataset: - raise Exception('Document has no dataset') + raise Exception("Document has no dataset") index_processor = 
IndexProcessorFactory(document.doc_form).init_index_processor() @@ -49,7 +49,10 @@ def remove_document_from_index_task(document_id: str): end_at = time.perf_counter() logging.info( - click.style('Document removed from index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + click.style( + "Document removed from index: {} latency: {}".format(document.id, end_at - start_at), fg="green" + ) + ) except Exception: logging.exception("remove document from index failed") if not document.archived: diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 1114809b3019a1..73471fd6e77c9b 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -13,7 +13,7 @@ from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): """ Async process document @@ -27,22 +27,23 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() for document_id in document_ids: - retry_indexing_cache_key = 'document_{}_is_retried'.format(document_id) + retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit features = FeatureService.get_features(dataset.tenant_id) try: if features.billing.enabled: vector_space = features.vector_space if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have over the limit of " + "your subscription." 
+ ) except Exception as e: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -50,11 +51,10 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): redis_client.delete(retry_indexing_cache_key) return - logging.info(click.style('Start retry document: {}'.format(document_id), fg='green')) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + logging.info(click.style("Start retry document: {}".format(document_id), fg="green")) + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) try: if document: # clean old data @@ -70,7 +70,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() @@ -79,13 +79,13 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): indexing_runner.run([document]) redis_client.delete(retry_indexing_cache_key) except Exception as ex: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(ex) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(retry_indexing_cache_key) pass end_at = time.perf_counter() - logging.info(click.style('Retry dataset: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green')) + logging.info(click.style("Retry dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 320da8718a12cb..1d2a338c831764 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -13,35 +13,36 @@ from services.feature_service import FeatureService -@shared_task(queue='dataset') +@shared_task(queue="dataset") def sync_website_document_indexing_task(dataset_id: str, document_id: str): """ Async process document :param dataset_id: :param document_id: - Usage: sunc_website_document_indexing_task.delay(dataset_id, document_id) + Usage: sync_website_document_indexing_task.delay(dataset_id, document_id) """ start_at = time.perf_counter() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() - sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id) + sync_indexing_cache_key = "document_{}_is_sync".format(document_id) # check document limit features = FeatureService.get_features(dataset.tenant_id) try: if features.billing.enabled: vector_space = features.vector_space if 0 < vector_space.limit <= vector_space.size: - raise ValueError("Your total number of documents plus the number of uploads have over the limit of " - "your subscription.") + raise ValueError( + "Your total number of documents plus the number of uploads have 
over the limit of " + "your subscription." + ) except Exception as e: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) if document: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(e) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) @@ -49,11 +50,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): redis_client.delete(sync_indexing_cache_key) return - logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green')) - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + logging.info(click.style("Start sync website document: {}".format(document_id), fg="green")) + document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() try: if document: # clean old data @@ -69,7 +67,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): db.session.delete(segment) db.session.commit() - document.indexing_status = 'parsing' + document.indexing_status = "parsing" document.processing_started_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() @@ -78,13 +76,13 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): indexing_runner.run([document]) redis_client.delete(sync_indexing_cache_key) except Exception as ex: - document.indexing_status = 'error' + document.indexing_status = "error" document.error = str(ex) document.stopped_at = datetime.datetime.utcnow() db.session.add(document) db.session.commit() - logging.info(click.style(str(ex), fg='yellow')) + logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(sync_indexing_cache_key) pass end_at = time.perf_counter() - logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + logging.info(click.style("Sync document: {} latency: {}".format(document_id, end_at - start_at), fg="green")) diff --git a/api/templates/email_code_login_mail_template_en-US.html b/api/templates/email_code_login_mail_template_en-US.html new file mode 100644 index 00000000000000..066818d10c5a11 --- /dev/null +++ b/api/templates/email_code_login_mail_template_en-US.html @@ -0,0 +1,74 @@ + + + + + + +
+ [new 74-line HTML mail template; markup garbled in extraction. Recoverable copy: the Dify logo, the heading "Your login code for Dify", the line "Copy and paste this code, this code will only be valid for the next 5 minutes.", the {{code}} placeholder, and "If you didn't request a login, don't worry. You can safely ignore this email."]
+ + diff --git a/api/templates/email_code_login_mail_template_zh-CN.html b/api/templates/email_code_login_mail_template_zh-CN.html new file mode 100644 index 00000000000000..0c2b63a1f1a694 --- /dev/null +++ b/api/templates/email_code_login_mail_template_zh-CN.html @@ -0,0 +1,74 @@ + + + + + + +
+ [new 74-line HTML mail template; markup garbled in extraction. Recoverable copy: the Dify logo, the heading "Dify 的登录验证码", the line "复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。", the {{code}} placeholder, and "如果您没有请求登录,请不要担心。您可以安全地忽略此电子邮件。"]
+ + diff --git a/api/templates/invite_member_mail_template_en-US.html b/api/templates/invite_member_mail_template_en-US.html index 80f7d42c202f18..e8bf7f5a52a689 100644 --- a/api/templates/invite_member_mail_template_en-US.html +++ b/api/templates/invite_member_mail_template_en-US.html @@ -59,7 +59,7 @@

Dear {{ to }},

{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.

-

You can now log in to Dify using the GitHub or Google account associated with this email.

+

Click the button below to log in to Dify and join the workspace.

Login Here

diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx index 120fe29dff0f01..6e5046ecf80e91 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/provider-panel.tsx @@ -6,6 +6,7 @@ import { TracingProvider } from './type' import cn from '@/utils/classnames' import { LangfuseIconBig, LangsmithIconBig } from '@/app/components/base/icons/src/public/tracing' import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' +import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general' const I18N_PREFIX = 'app.tracing' @@ -13,6 +14,7 @@ type Props = { type: TracingProvider readOnly: boolean isChosen: boolean + config: any onChoose: () => void hasConfigured: boolean onConfig: () => void @@ -29,6 +31,7 @@ const ProviderPanel: FC = ({ type, readOnly, isChosen, + config, onChoose, hasConfigured, onConfig, @@ -41,6 +44,15 @@ const ProviderPanel: FC = ({ onConfig() }, [onConfig]) + const viewBtnClick = useCallback((e: React.MouseEvent) => { + e.preventDefault() + e.stopPropagation() + + const url = config?.project_url + if (url) + window.open(url, '_blank', 'noopener,noreferrer') + }, [config?.project_url]) + const handleChosen = useCallback((e: React.MouseEvent) => { e.stopPropagation() if (isChosen || !hasConfigured || readOnly) @@ -58,12 +70,20 @@ const ProviderPanel: FC = ({ {isChosen &&
{t(`${I18N_PREFIX}.inUse`)}
} {!readOnly && ( -
- -
{t(`${I18N_PREFIX}.config`)}
+
+ {hasConfigured && ( +
+ +
{t(`${I18N_PREFIX}.view`)}
+
+ )} +
+ +
{t(`${I18N_PREFIX}.config`)}
+
)} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx index 9119deede879b2..934eb681b9c85e 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/toggle-fold-btn.tsx @@ -3,7 +3,7 @@ import { ChevronDoubleDownIcon } from '@heroicons/react/20/solid' import type { FC } from 'react' import { useTranslation } from 'react-i18next' import React, { useCallback } from 'react' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const I18N_PREFIX = 'app.tracing' @@ -25,9 +25,8 @@ const ToggleFoldBtn: FC = ({ return ( // text-[0px] to hide spacing between tooltip elements
- {isFold && (
@@ -39,7 +38,7 @@ const ToggleFoldBtn: FC = ({
)} -
+
) } diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index 1387099a627e1c..1ffb132cf8c186 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -21,7 +21,7 @@ import Divider from '@/app/components/base/divider' import { getRedirection } from '@/utils/app-redirection' import { useProviderContext } from '@/context/provider-context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' -import { AiText, ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import EditAppModal from '@/app/components/explore/create-app-modal' @@ -75,17 +75,21 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, + icon_type, icon, icon_background, description, + use_icon_as_answer_icon, }) => { try { await updateAppInfo({ appID: app.id, name, + icon_type, icon, icon_background, description, + use_icon_as_answer_icon, }) setShowEditModal(false) notify({ @@ -101,11 +105,12 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { } }, [app.id, mutateApps, notify, onRefresh, t]) - const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon, icon_background }) => { + const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { try { const newApp = await copyApp({ appID: app.id, name, + icon_type, icon, icon_background, mode: app.mode, @@ -252,21 +257,23 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { e.preventDefault() getRedirection(isCurrentWorkspaceEditor, app, push) }} - className='group flex col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm min-h-[160px] flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' + className='relative group col-span-1 bg-white border-2 border-solid border-transparent rounded-xl shadow-sm flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg' >
{app.mode === 'advanced-chat' && ( )} {app.mode === 'agent-chat' && ( - + )} {app.mode === 'chat' && ( @@ -292,17 +299,16 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
-
- {app.description} +
+
+ {app.description} +
{isCurrentWorkspaceEditor && ( @@ -360,10 +366,14 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { {showEditModal && ( setShowEditModal(false)} @@ -372,8 +382,10 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { {showDuplicateModal && ( setShowDuplicateModal(false)} diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index c16512bd50db1f..9d6345aa6c3de1 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -21,7 +21,7 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { CheckModal } from '@/hooks/use-pay' import TabSliderNew from '@/app/components/base/tab-slider-new' import { useTabSearchParams } from '@/hooks/use-tab-searchparams' -import SearchInput from '@/app/components/base/search-input' +import Input from '@/app/components/base/input' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' import TagManagementModal from '@/app/components/base/tag-management' import TagFilter from '@/app/components/base/tag-management/filter' @@ -87,15 +87,15 @@ const Apps = () => { localStorage.removeItem(NEED_REFRESH_APP_LIST_KEY) mutate() } - }, []) + }, [mutate, t]) useEffect(() => { if (isCurrentWorkspaceDatasetOperator) return router.replace('/datasets') - }, [isCurrentWorkspaceDatasetOperator]) + }, [router, isCurrentWorkspaceDatasetOperator]) - const hasMore = data?.at(-1)?.has_more ?? true useEffect(() => { + const hasMore = data?.at(-1)?.has_more ?? true let observer: IntersectionObserver | undefined if (anchorRef.current) { observer = new IntersectionObserver((entries) => { @@ -105,7 +105,7 @@ const Apps = () => { observer.observe(anchorRef.current) } return () => observer?.disconnect() - }, [isLoading, setSize, anchorRef, mutate, hasMore]) + }, [isLoading, setSize, anchorRef, mutate, data]) const { run: handleSearch } = useDebounceFn(() => { setSearchKeywords(keywords) @@ -133,13 +133,20 @@ const Apps = () => { />
- + handleKeywordsChange(e.target.value)} + onClear={() => handleKeywordsChange('')} + />
---- +
@@ -50,7 +50,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 索引方式 - high_quality 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - - economy 经济:使用 Keyword Table Index 的倒排索引进行构建 + - economy 经济:使用 keyword table index 的倒排索引进行构建 处理规则 @@ -64,7 +64,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -72,11 +72,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_text' \ + curl --location --request --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -123,13 +123,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
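A minimal Python equivalent of the cURL sample above for the renamed create-by-text route, assuming the `requests` library and the same `{api_key}`/`{dataset_id}` placeholders:

```python
import requests

API_BASE = "https://api.dify.ai/v1"  # placeholder, stands in for ${props.apiBaseUrl}
API_KEY = "{api_key}"
DATASET_ID = "{dataset_id}"

# Create a document from raw text, mirroring the cURL sample.
resp = requests.post(
    f"{API_BASE}/datasets/{DATASET_ID}/document/create-by-text",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "name": "text",
        "text": "text",
        "indexing_technique": "high_quality",
        "process_rule": {"mode": "automatic"},
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```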
@@ -142,20 +142,20 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - ### Request Bodys + ### Request Body - - original_document_id 源文档 ID (选填) + - original_document_id 源文档 ID(选填) - 用于重新上传文档或修改文档清洗、分段配置,缺失的信息从源文档复制 - 源文档不可为归档的文档 - 当传入 original_document_id 时,代表文档进行更新操作,process_rule 为可填项目,不填默认使用源文档的分段方式 - 未传入 original_document_id 时,代表文档进行新增操作,process_rule 为必填 - - indexing_technique 索引方式 + - indexing_technique 索引方式 - high_quality 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - - economy 经济:使用 Keyword Table Index 的倒排索引进行构建 + - economy 经济:使用 keyword table index 的倒排索引进行构建 - - process_rule 处理规则 + - process_rule 处理规则 - mode (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 - rules (object) 自定义规则(自动模式下,该字段为空) - pre_processing_rules (array[object]) 预处理规则 @@ -166,7 +166,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 需要上传的文件。 @@ -177,11 +177,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -221,7 +221,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
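The create-by-file route wraps the JSON payload in a text/plain form field named `data` next to the file part; a Python sketch of the same multipart call, under the same placeholder assumptions:

```python
import json
import requests

API_BASE = "https://api.dify.ai/v1"  # placeholder
API_KEY = "{api_key}"
DATASET_ID = "{dataset_id}"

process_rule = {
    "mode": "custom",
    "rules": {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": True},
        ],
        "segmentation": {"separator": "###", "max_tokens": 500},
    },
}
data_part = {"name": "Dify", "indexing_technique": "high_quality", "process_rule": process_rule}

with open("/path/to/file", "rb") as f:
    resp = requests.post(
        f"{API_BASE}/datasets/{DATASET_ID}/document/create-by-file",
        headers={"Authorization": f"Bearer {API_KEY}"},
        # the JSON payload rides as a text/plain form field named "data"
        files={
            "data": (None, json.dumps(data_part), "text/plain"),
            "file": f,
        },
        timeout=60,
    )
resp.raise_for_status()
print(resp.json())
```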
- 知识库名称 + 知识库名称(必填) + + + 知识库描述(选填) + + + 索引模式(选填,建议填写) + - high_quality 高质量 + - economy 经济 + + + 权限(选填,默认 only_me) + - only_me 仅自己 + - all_team_members 所有团队成员 + - partial_members 部分团队成员 + + + Provider(选填,默认 vendor) + - vendor 上传文件 + - external 外部知识库 + + + 外部知识库 API_ID(选填) + + + 外部知识库 ID(选填) - ```bash {{ title: 'cURL' }} curl --location --request POST '${props.apiBaseUrl}/datasets' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ - "name": "name" + "name": "name", + "permission": "only_me" }' ``` @@ -280,7 +306,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
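The newly documented optional fields travel in the same JSON body as `name`; a Python sketch creating a dataset with an explicit permission, again with placeholder credentials:

```python
import requests

API_BASE = "https://api.dify.ai/v1"  # placeholder
API_KEY = "{api_key}"

resp = requests.post(
    f"{API_BASE}/datasets",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "name": "name",
        "description": "optional description",
        "indexing_technique": "high_quality",  # or "economy"
        "permission": "only_me",               # all_team_members / partial_members
        "provider": "vendor",                  # "external" for external knowledge bases
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```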
---- +
---- +
@@ -405,7 +431,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 文档名称 (选填) + 文档名称(选填) 文档内容(选填) @@ -422,7 +448,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -430,11 +456,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_text' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -477,13 +503,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -502,7 +528,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 文档名称 (选填) + 文档名称(选填) 需要上传的文件 @@ -519,7 +545,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -527,11 +553,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -571,7 +597,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
---- +
---- +
---- +
- content (text) 文本内容/问题内容,必填 - - answer (text) 答案内容,非必填,如果知识库的模式为qa模式则传值 + - answer (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - keywords (list) 关键字,非必填 @@ -829,7 +855,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
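The segment-creation route itself is not visible in this excerpt; the sketch below assumes the `/datasets/{dataset_id}/documents/{document_id}/segments` path of Dify's published dataset API and uses only the fields listed above:

```python
import requests

API_BASE = "https://api.dify.ai/v1"  # placeholder
API_KEY = "{api_key}"
DATASET_ID = "{dataset_id}"
DOCUMENT_ID = "{document_id}"

# Path assumed from the other document routes in this file.
resp = requests.post(
    f"{API_BASE}/datasets/{DATASET_ID}/documents/{DOCUMENT_ID}/segments",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "segments": [
            {
                "content": "What is Dify?",        # required
                "answer": "An LLM app platform.",  # only for Q&A-mode datasets
                "keywords": ["dify"],              # optional
            }
        ]
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```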
---- +
---- +
- + - content (text) 文本内容/问题内容,必填 - - answer (text) 答案内容,非必填,如果知识库的模式为qa模式则传值 + - answer (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - keywords (list) 关键字,非必填 - enabled (bool) false/true,非必填 @@ -1042,7 +1068,153 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
+ + + + + ### Path + + + 知识库 ID + + + + ### Request Body + + + 检索关键词 + + + 检索参数(选填,如不填,按照默认方式召回) + - search_method (text) 检索方法:以下三个关键字之一,必填 + - keyword_search 关键字检索 + - semantic_search 语义检索 + - full_text_search 全文检索 + - hybrid_search 混合检索 + - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为 semantic_search 模式或者 hybrid_search 则传值 + - reranking_mode (object) Rerank模型配置,非必填,如果启用了 reranking 则传值 + - reranking_provider_name (string) Rerank 模型提供商 + - reranking_model_name (string) Rerank 模型名称 + - weights (double) 混合检索模式下语意检索的权重设置 + - top_k (integer) 返回结果数量,非必填 + - score_threshold_enabled (bool) 是否开启 score 阈值 + - score_threshold (double) Score 阈值 + + + 未启用字段 + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "query": "test", + "retrieval_model": { + "search_method": "keyword_search", + "reranking_enable": false, + "reranking_mode": null, + "reranking_model": { + "reranking_provider_name": "", + "reranking_model_name": "" + }, + "weights": null, + "top_k": 2, + "score_threshold_enabled": false, + "score_threshold": null + } + }' + ``` + + + ```json {{ title: 'Response' }} + { + "query": { + "content": "test" + }, + "records": [ + { + "segment": { + "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", + "position": 1, + "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "content": "Operation guide", + "answer": null, + "word_count": 847, + "tokens": 280, + "keywords": [ + "install", + "java", + "base", + "scripts", + "jdk", + "manual", + "internal", + "opens", + "add", + "vmoptions" + ], + "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", + "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", + "hit_count": 0, + "enabled": true, + "disabled_at": null, + "disabled_by": null, + "status": "completed", + "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", + "created_at": 1728734540, + "indexing_at": 1728734552, + "completed_at": 1728734584, + "error": null, + "stopped_at": null, + "document": { + "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "data_source_type": "upload_file", + "name": "readme.txt", + "doc_type": null + } + }, + "score": 3.730463140527718e-05, + "tsne_position": null + } + ] + } + ``` + + + + + +
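A Python sketch of the new retrieval call, mirroring the documented keyword_search payload and walking the documented `records` array; placeholders as before:

```python
import requests

API_BASE = "https://api.dify.ai/v1"  # placeholder
API_KEY = "{api_key}"
DATASET_ID = "{dataset_id}"

resp = requests.post(
    f"{API_BASE}/datasets/{DATASET_ID}/retrieve",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "query": "test",
        "retrieval_model": {
            "search_method": "keyword_search",
            "reranking_enable": False,
            "reranking_mode": None,
            "reranking_model": {
                "reranking_provider_name": "",
                "reranking_model_name": "",
            },
            "weights": None,
            "top_k": 2,
            "score_threshold_enabled": False,
            "score_threshold": None,
        },
    },
    timeout=30,
)
resp.raise_for_status()
# Each record carries a matched segment, its source document, and a score.
for record in resp.json()["records"]:
    seg = record["segment"]
    print(seg["document"]["name"], seg["position"], record["score"])
```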
diff --git a/web/app/(shareLayout)/layout.tsx b/web/app/(shareLayout)/layout.tsx index 9c4632cd450dfa..259af2bc2dc845 100644 --- a/web/app/(shareLayout)/layout.tsx +++ b/web/app/(shareLayout)/layout.tsx @@ -1,7 +1,12 @@ import React from 'react' import type { FC } from 'react' +import type { Metadata } from 'next' import GA, { GaType } from '@/app/components/base/ga' +export const metadata: Metadata = { + icons: 'data:,', // prevent browser from using default favicon +} + const Layout: FC<{ children: React.ReactNode }> = ({ children }) => { diff --git a/web/app/account/account-page/index.module.css b/web/app/account/account-page/index.module.css new file mode 100644 index 00000000000000..949d1257e9820c --- /dev/null +++ b/web/app/account/account-page/index.module.css @@ -0,0 +1,9 @@ +.modal { + padding: 24px 32px !important; + width: 400px !important; +} + +.bg { + background: linear-gradient(180deg, rgba(217, 45, 32, 0.05) 0%, rgba(217, 45, 32, 0.00) 24.02%), #F9FAFB; +} + diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx new file mode 100644 index 00000000000000..71540ce3b1265a --- /dev/null +++ b/web/app/account/account-page/index.tsx @@ -0,0 +1,335 @@ +'use client' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' + +import { useContext } from 'use-context-selector' +import s from './index.module.css' +import Collapse from '@/app/components/header/account-setting/collapse' +import type { IItem } from '@/app/components/header/account-setting/collapse' +import Modal from '@/app/components/base/modal' +import Confirm from '@/app/components/base/confirm' +import Button from '@/app/components/base/button' +import { updateUserProfile } from '@/service/common' +import { useAppContext } from '@/context/app-context' +import { ToastContext } from '@/app/components/base/toast' +import AppIcon from '@/app/components/base/app-icon' +import Avatar from '@/app/components/base/avatar' +import { IS_CE_EDITION } from '@/config' +import Input from '@/app/components/base/input' + +const titleClassName = ` + text-sm font-medium text-gray-900 +` +const descriptionClassName = ` + mt-1 text-xs font-normal text-gray-500 +` + +const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ + +export default function AccountPage() { + const { t } = useTranslation() + const { systemFeatures } = useAppContext() + const { mutateUserProfile, userProfile, apps } = useAppContext() + const { notify } = useContext(ToastContext) + const [editNameModalVisible, setEditNameModalVisible] = useState(false) + const [editName, setEditName] = useState('') + const [editing, setEditing] = useState(false) + const [editPasswordModalVisible, setEditPasswordModalVisible] = useState(false) + const [currentPassword, setCurrentPassword] = useState('') + const [password, setPassword] = useState('') + const [confirmPassword, setConfirmPassword] = useState('') + const [showDeleteAccountModal, setShowDeleteAccountModal] = useState(false) + const [showCurrentPassword, setShowCurrentPassword] = useState(false) + const [showPassword, setShowPassword] = useState(false) + const [showConfirmPassword, setShowConfirmPassword] = useState(false) + + const handleEditName = () => { + setEditNameModalVisible(true) + setEditName(userProfile.name) + } + const handleSaveName = async () => { + try { + setEditing(true) + await updateUserProfile({ url: 'account/name', body: { name: editName } }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + 
mutateUserProfile() + setEditNameModalVisible(false) + setEditing(false) + } + catch (e) { + notify({ type: 'error', message: (e as Error).message }) + setEditNameModalVisible(false) + setEditing(false) + } + } + + const showErrorMessage = (message: string) => { + notify({ + type: 'error', + message, + }) + } + const valid = () => { + if (!password.trim()) { + showErrorMessage(t('login.error.passwordEmpty')) + return false + } + if (!validPassword.test(password)) { + showErrorMessage(t('login.error.passwordInvalid')) + return false + } + if (password !== confirmPassword) { + showErrorMessage(t('common.account.notEqual')) + return false + } + + return true + } + const resetPasswordForm = () => { + setCurrentPassword('') + setPassword('') + setConfirmPassword('') + } + const handleSavePassword = async () => { + if (!valid()) + return + try { + setEditing(true) + await updateUserProfile({ + url: 'account/password', + body: { + password: currentPassword, + new_password: password, + repeat_new_password: confirmPassword, + }, + }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + mutateUserProfile() + setEditPasswordModalVisible(false) + resetPasswordForm() + setEditing(false) + } + catch (e) { + notify({ type: 'error', message: (e as Error).message }) + setEditPasswordModalVisible(false) + setEditing(false) + } + } + + const renderAppItem = (item: IItem) => { + return ( +
+
+ +
+
{item.name}
+
+ ) + } + + return ( + <> +
+

{t('common.account.myAccount')}

+
+
+ +
+

{userProfile.name}

+

{userProfile.email}

+
+
+
+
{t('common.account.name')}
+
+
+ {userProfile.name} +
+
+ {t('common.operation.edit')} +
+
+
+
+
{t('common.account.email')}
+
+
+ {userProfile.email} +
+
+
+ { + systemFeatures.enable_email_password_login && ( +
+
+
{t('common.account.password')}
+
{t('common.account.passwordTip')}
+
+ +
+ ) + } +
+
+
{t('common.account.langGeniusAccount')}
+
{t('common.account.langGeniusAccountTip')}
+ {!!apps.length && ( + ({ key: app.id, name: app.name }))} + renderItem={renderAppItem} + wrapperClassName='mt-2' + /> + )} + {!IS_CE_EDITION && } +
+ { + editNameModalVisible && ( + setEditNameModalVisible(false)} + className={s.modal} + > +
{t('common.account.editName')}
+
{t('common.account.name')}
+ setEditName(e.target.value)} + /> +
+ + +
+
+ ) + } + { + editPasswordModalVisible && ( + { + setEditPasswordModalVisible(false) + resetPasswordForm() + }} + className={s.modal} + > +
{userProfile.is_password_set ? t('common.account.resetPassword') : t('common.account.setPassword')}
+ {userProfile.is_password_set && ( + <> +
{t('common.account.currentPassword')}
+
+ setCurrentPassword(e.target.value)} + /> + +
+ +
+
+ + )} +
+ {userProfile.is_password_set ? t('common.account.newPassword') : t('common.account.password')} +
+
+ setPassword(e.target.value)} + /> +
+ +
+
+
{t('common.account.confirmPassword')}
+
+ setConfirmPassword(e.target.value)} + /> +
+ +
+
+
+ + +
+
+ ) + } + { + showDeleteAccountModal && ( + setShowDeleteAccountModal(false)} + onConfirm={() => setShowDeleteAccountModal(false)} + showCancel={false} + type='warning' + title={t('common.account.delete')} + content={ + <> +
+ {t('common.account.deleteTip')} +
+ +
{`${t('common.account.delete')}: ${userProfile.email}`}
+ + } + confirmText={t('common.operation.ok') as string} + /> + ) + } + + ) +} diff --git a/web/app/account/avatar.tsx b/web/app/account/avatar.tsx new file mode 100644 index 00000000000000..544e43ab27f99f --- /dev/null +++ b/web/app/account/avatar.tsx @@ -0,0 +1,95 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { Fragment } from 'react' +import { useRouter } from 'next/navigation' +import { Menu, Transition } from '@headlessui/react' +import Avatar from '@/app/components/base/avatar' +import { logout } from '@/service/common' +import { useAppContext } from '@/context/app-context' +import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general' + +export type IAppSelector = { + isMobile: boolean +} + +export default function AppSelector() { + const router = useRouter() + const { t } = useTranslation() + const { userProfile } = useAppContext() + + const handleLogout = async () => { + await logout({ + url: '/logout', + params: {}, + }) + + localStorage.removeItem('setup_status') + localStorage.removeItem('console_token') + localStorage.removeItem('refresh_token') + + router.push('/signin') + } + + return ( + + { + ({ open }) => ( + <> +
+ + + +
+ + + +
+
+
+
{userProfile.name}
+
{userProfile.email}
+
+ +
+
+
+ +
handleLogout()}> +
+ +
{t('common.userProfile.logout')}
+
+
+
+
+
+ + ) + } +
+ ) +} diff --git a/web/app/account/header.tsx b/web/app/account/header.tsx new file mode 100644 index 00000000000000..694533e5ab7cb6 --- /dev/null +++ b/web/app/account/header.tsx @@ -0,0 +1,37 @@ +'use client' +import { useTranslation } from 'react-i18next' +import { RiArrowRightUpLine, RiRobot2Line } from '@remixicon/react' +import { useRouter } from 'next/navigation' +import Button from '../components/base/button' +import Avatar from './avatar' +import LogoSite from '@/app/components/base/logo/logo-site' + +const Header = () => { + const { t } = useTranslation() + const router = useRouter() + + const back = () => { + router.back() + } + return ( +
+
+
+ +
+
+

{t('common.account.account')}

+
+
+ +
+ +
+
+ ) +} +export default Header diff --git a/web/app/account/layout.tsx b/web/app/account/layout.tsx new file mode 100644 index 00000000000000..5aa8b05cbfd07b --- /dev/null +++ b/web/app/account/layout.tsx @@ -0,0 +1,40 @@ +import React from 'react' +import type { ReactNode } from 'react' +import Header from './header' +import SwrInitor from '@/app/components/swr-initor' +import { AppContextProvider } from '@/context/app-context' +import GA, { GaType } from '@/app/components/base/ga' +import HeaderWrapper from '@/app/components/header/header-wrapper' +import { EventEmitterContextProvider } from '@/context/event-emitter' +import { ProviderContextProvider } from '@/context/provider-context' +import { ModalContextProvider } from '@/context/modal-context' + +const Layout = ({ children }: { children: ReactNode }) => { + return ( + <> + + + + + + + +
+ +
+ {children} +
+ + + + + + + ) +} + +export const metadata = { + title: 'Dify', +} + +export default Layout diff --git a/web/app/account/page.tsx b/web/app/account/page.tsx new file mode 100644 index 00000000000000..bb7e7f7feb1840 --- /dev/null +++ b/web/app/account/page.tsx @@ -0,0 +1,7 @@ +import AccountPage from './account-page' + +export default function Account() { + return
+ +
+} diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index 3b1eed6f09a5d2..9a32a76a73c20b 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -1,27 +1,16 @@ 'use client' -import { useCallback, useState } from 'react' -import { useContext } from 'use-context-selector' import { useTranslation } from 'react-i18next' import useSWR from 'swr' -import { useSearchParams } from 'next/navigation' -import Link from 'next/link' -import { CheckCircleIcon } from '@heroicons/react/24/solid' -import style from './style.module.css' +import { useRouter, useSearchParams } from 'next/navigation' import cn from '@/utils/classnames' import Button from '@/app/components/base/button' -import { SimpleSelect } from '@/app/components/base/select' -import { timezones } from '@/utils/timezone' -import { LanguagesSupported, languages } from '@/i18n/language' -import { activateMember, invitationCheck } from '@/service/common' -import Toast from '@/app/components/base/toast' +import { invitationCheck } from '@/service/common' import Loading from '@/app/components/base/loading' -import I18n from '@/context/i18n' -const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ const ActivateForm = () => { + const router = useRouter() const { t } = useTranslation() - const { locale, setLocaleOnClient } = useContext(I18n) const searchParams = useSearchParams() const workspaceID = searchParams.get('workspace_id') const email = searchParams.get('email') @@ -35,64 +24,20 @@ const ActivateForm = () => { token, }, } - const { data: checkRes, mutate: recheck } = useSWR(checkParams, invitationCheck, { + const { data: checkRes } = useSWR(checkParams, invitationCheck, { revalidateOnFocus: false, + onSuccess(data) { + if (data.is_valid) { + const params = new URLSearchParams(searchParams) + const { email, workspace_id } = data.data + params.set('email', encodeURIComponent(email)) + params.set('workspace_id', encodeURIComponent(workspace_id)) + params.set('invite_token', encodeURIComponent(token as string)) + router.replace(`/signin?${params.toString()}`) + } + }, }) - const [name, setName] = useState('') - const [password, setPassword] = useState('') - const [timezone, setTimezone] = useState(Intl.DateTimeFormat().resolvedOptions().timeZone) - const [language, setLanguage] = useState(locale) - const [showSuccess, setShowSuccess] = useState(false) - - const showErrorMessage = useCallback((message: string) => { - Toast.notify({ - type: 'error', - message, - }) - }, []) - - const valid = useCallback(() => { - if (!name.trim()) { - showErrorMessage(t('login.error.nameEmpty')) - return false - } - if (!password.trim()) { - showErrorMessage(t('login.error.passwordEmpty')) - return false - } - if (!validPassword.test(password)) { - showErrorMessage(t('login.error.passwordInvalid')) - return false - } - - return true - }, [name, password, showErrorMessage, t]) - - const handleActivate = useCallback(async () => { - if (!valid()) - return - try { - await activateMember({ - url: '/activate', - body: { - workspace_id: workspaceID, - email, - token, - name, - password, - interface_language: language, - timezone, - }, - }) - setLocaleOnClient(language, false) - setShowSuccess(true) - } - catch { - recheck() - } - }, [email, language, name, password, recheck, setLocaleOnClient, timezone, token, valid, workspaceID]) - return (
{
)} - {checkRes && checkRes.is_valid && !showSuccess && ( -
-
-
-
-

- {`${t('login.join')} ${checkRes.workspace_name}`} -

-

- {`${t('login.joinTipStart')} ${checkRes.workspace_name} ${t('login.joinTipEnd')}`} -

-
- -
-
- {/* username */} -
- -
- setName(e.target.value)} - placeholder={t('login.namePlaceholder') || ''} - className={'appearance-none block w-full rounded-lg pl-[14px] px-3 py-2 border border-gray-200 hover:border-gray-300 hover:shadow-sm focus:outline-none focus:ring-primary-500 focus:border-primary-500 placeholder-gray-400 caret-primary-600 sm:text-sm pr-10'} - /> -
-
- {/* password */} -
- -
- setPassword(e.target.value)} - placeholder={t('login.passwordPlaceholder') || ''} - className={'appearance-none block w-full rounded-lg pl-[14px] px-3 py-2 border border-gray-200 hover:border-gray-300 hover:shadow-sm focus:outline-none focus:ring-primary-500 focus:border-primary-500 placeholder-gray-400 caret-primary-600 sm:text-sm pr-10'} - /> -
-
{t('login.error.passwordInvalid')}
-
- {/* language */} -
- -
- item.supported)} - onSelect={(item) => { - setLanguage(item.value as string) - }} - /> -
-
- {/* timezone */} -
- -
- { - setTimezone(item.value as string) - }} - /> -
-
-
- -
-
- {t('login.license.tip')} -   - {t('login.license.link')} -
-
-
-
- )} - {checkRes && checkRes.is_valid && showSuccess && ( -
-
-
- -
-

- {`${t('login.activatedTipStart')} ${checkRes.workspace_name} ${t('login.activatedTipEnd')}`} -

-
- -
- )}
) } diff --git a/web/app/activate/page.tsx b/web/app/activate/page.tsx index 90874f50cefe0a..0f1854433552db 100644 --- a/web/app/activate/page.tsx +++ b/web/app/activate/page.tsx @@ -22,7 +22,7 @@ const Activate = () => {
- © {new Date().getFullYear()} Dify, Inc. All rights reserved. + © {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 698846cae5144b..12fe5cba468df3 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -22,7 +22,7 @@ import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/ap import DuplicateAppModal from '@/app/components/app/duplicate-modal' import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' import CreateAppModal from '@/app/components/explore/create-app-modal' -import { AiText, ChatBot, CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { AiText, ChatBot, CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Route } from '@/app/components/base/icons/src/vender/solid/mapsAndTravel' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' @@ -59,9 +59,11 @@ const AppInfo = ({ expand }: IAppInfoProps) => { const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, + icon_type, icon, icon_background, description, + use_icon_as_answer_icon, }) => { if (!appDetail) return @@ -69,9 +71,11 @@ const AppInfo = ({ expand }: IAppInfoProps) => { const app = await updateAppInfo({ appID: appDetail.id, name, + icon_type, icon, icon_background, description, + use_icon_as_answer_icon, }) setShowEditModal(false) notify({ @@ -86,13 +90,14 @@ const AppInfo = ({ expand }: IAppInfoProps) => { } }, [appDetail, mutateApps, notify, setAppDetail, t]) - const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon, icon_background }) => { + const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { if (!appDetail) return try { const newApp = await copyApp({ appID: appDetail.id, name, + icon_type, icon, icon_background, mode: appDetail.mode, @@ -194,7 +199,13 @@ const AppInfo = ({ expand }: IAppInfoProps) => { >
- + { )} {appDetail.mode === 'agent-chat' && ( - + )} {appDetail.mode === 'chat' && ( @@ -257,13 +268,19 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {/* header */}
- + {appDetail.mode === 'advanced-chat' && ( )} {appDetail.mode === 'agent-chat' && ( - + )} {appDetail.mode === 'chat' && ( @@ -306,7 +323,7 @@ const AppInfo = ({ expand }: IAppInfoProps) => {
- {/* desscription */} + {/* description */} {appDetail.description && (
{appDetail.description}
)} @@ -402,10 +419,14 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {showEditModal && ( setShowEditModal(false)} @@ -414,8 +435,10 @@ const AppInfo = ({ expand }: IAppInfoProps) => { {showDuplicateModal && ( setShowDuplicateModal(false)} diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx index 09f978b04b4f01..51fc10721eb7a4 100644 --- a/web/app/components/app-sidebar/basic.tsx +++ b/web/app/components/app-sidebar/basic.tsx @@ -1,15 +1,13 @@ import React from 'react' -import { - InformationCircleIcon, -} from '@heroicons/react/24/outline' -import Tooltip from '../base/tooltip' +import { useTranslation } from 'react-i18next' import AppIcon from '../base/app-icon' -import { randomString } from '@/utils' +import Tooltip from '@/app/components/base/tooltip' export type IAppBasicProps = { iconType?: 'app' | 'api' | 'dataset' | 'webapp' | 'notion' icon?: string - icon_background?: string + icon_background?: string | null + isExternal?: boolean name: string type: string | React.ReactNode hoverTip?: string @@ -56,7 +54,9 @@ const ICON_MAP = { notion: , } -export default function AppBasic({ icon, icon_background, name, type, hoverTip, textStyle, mode = 'expand', iconType = 'app' }: IAppBasicProps) { +export default function AppBasic({ icon, icon_background, name, isExternal, type, hoverTip, textStyle, mode = 'expand', iconType = 'app' }: IAppBasicProps) { + const { t } = useTranslation() + return (
{icon && icon_background && iconType === 'app' && ( @@ -74,11 +74,20 @@ export default function AppBasic({ icon, icon_background, name, type, hoverTip,
{name} {hoverTip - && - - } + && + {hoverTip} +
+ } + popupClassName='ml-1' + triggerClassName='w-4 h-4 ml-1' + position='top' + /> + }
{type}
+
{isExternal ? t('dataset.externalTag') : ''}
} ) diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index 5d5d407dc0f900..5ee063ad646f3d 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -15,6 +15,7 @@ export type IAppDetailNavProps = { iconType?: 'app' | 'dataset' | 'notion' title: string desc: string + isExternal?: boolean icon: string icon_background: string navigation: Array<{ @@ -26,7 +27,7 @@ export type IAppDetailNavProps = { extraInfo?: (modeState: string) => React.ReactNode } -const AppDetailNav = ({ title, desc, icon, icon_background, navigation, extraInfo, iconType = 'app' }: IAppDetailNavProps) => { +const AppDetailNav = ({ title, desc, isExternal, icon, icon_background, navigation, extraInfo, iconType = 'app' }: IAppDetailNavProps) => { const { appSidebarExpand, setAppSiderbarExpand } = useAppStore(useShallow(state => ({ appSidebarExpand: state.appSidebarExpand, setAppSiderbarExpand: state.setAppSiderbarExpand, @@ -70,6 +71,7 @@ const AppDetailNav = ({ title, desc, icon, icon_background, navigation, extraInf icon_background={icon_background} name={title} type={desc} + isExternal={isExternal} /> )} diff --git a/web/app/components/app/annotation/empty-element.tsx b/web/app/components/app/annotation/empty-element.tsx index 2498e3853414de..9ba31ce11e9e56 100644 --- a/web/app/components/app/annotation/empty-element.tsx +++ b/web/app/components/app/annotation/empty-element.tsx @@ -14,9 +14,9 @@ const EmptyElement: FC = () => { return (
-
- {t('appAnnotation.noData.title')} -
+
+ {t('appAnnotation.noData.title')} +
{t('appAnnotation.noData.description')}
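The next two hunks (annotation/filter.tsx and annotation/index.tsx) replace the hand-rolled search box with the shared base Input component and debounce the keyword before it reaches the list query. A minimal sketch of the debounce half, assuming ahooks' useDebounce (the import the hunk adds); the wrapper hook name useDebouncedKeyword is illustrative:

```tsx
import { useState } from 'react'
import { useDebounce } from 'ahooks'

// queryParams updates on every keystroke, but the debounced copy only
// settles 500 ms after typing pauses, so the list request fires once per
// pause instead of once per character.
const useDebouncedKeyword = () => {
  const [queryParams, setQueryParams] = useState<{ keyword?: string }>({})
  const debouncedQueryParams = useDebounce(queryParams, { wait: 500 })
  const query = { keyword: debouncedQueryParams.keyword || '' }
  return { query, setQueryParams }
}
```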
diff --git a/web/app/components/app/annotation/filter.tsx b/web/app/components/app/annotation/filter.tsx index 9a2a75058296de..d741f6de123980 100644 --- a/web/app/components/app/annotation/filter.tsx +++ b/web/app/components/app/annotation/filter.tsx @@ -2,10 +2,8 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - MagnifyingGlassIcon, -} from '@heroicons/react/24/solid' import useSWR from 'swr' +import Input from '@/app/components/base/input' import { fetchAnnotationsCount } from '@/service/log' export type QueryParam = { @@ -31,22 +29,18 @@ const Filter: FC = ({ if (!data) return null return ( -
-
-
-
- { - setQueryParams({ ...queryParams, keyword: e.target.value }) - }} - /> -
+
+ { + setQueryParams({ ...queryParams, keyword: e.target.value }) + }} + onClear={() => setQueryParams({ ...queryParams, keyword: '' })} + /> {children}
) diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 1e65d7a94f4f05..0783c3fa661b67 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -3,6 +3,7 @@ import type { FC } from 'react' import React, { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import { Pagination } from 'react-headless-pagination' +import { useDebounce } from 'ahooks' import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline' import Toast from '../../base/toast' import Filter from './filter' @@ -18,7 +19,7 @@ import Switch from '@/app/components/base/switch' import { addAnnotation, delAnnotation, fetchAnnotationConfig as doFetchAnnotationConfig, editAnnotation, fetchAnnotationList, queryAnnotationJobStatus, updateAnnotationScore, updateAnnotationStatus } from '@/service/annotation' import Loading from '@/app/components/base/loading' import { APP_PAGE_LIMIT } from '@/config' -import ConfigParamModal from '@/app/components/app/configuration/toolbox/annotation/config-param-modal' +import ConfigParamModal from '@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal' import type { AnnotationReplyConfig } from '@/models/debug' import { sleep } from '@/utils' import { useProviderContext } from '@/context/provider-context' @@ -67,10 +68,11 @@ const Annotation: FC = ({ const [queryParams, setQueryParams] = useState({}) const [currPage, setCurrPage] = React.useState(0) + const debouncedQueryParams = useDebounce(queryParams, { wait: 500 }) const query = { page: currPage + 1, limit: APP_PAGE_LIMIT, - keyword: queryParams.keyword || '', + keyword: debouncedQueryParams.keyword || '', } const [controlUpdateList, setControlUpdateList] = useState(Date.now()) @@ -150,8 +152,8 @@ const Annotation: FC = ({ return (
-

{t('appLog.description')}

-
+

{t('appLog.description')}

+
{isChatApp && ( @@ -280,7 +282,7 @@ const Annotation: FC = ({ onSave={async (embeddingModel, score) => { if ( embeddingModel.embedding_model_name !== annotationConfig?.embedding_model?.embedding_model_name - && embeddingModel.embedding_provider_name !== annotationConfig?.embedding_model?.embedding_provider_name + || embeddingModel.embedding_provider_name !== annotationConfig?.embedding_model?.embedding_provider_name ) { const { job_id: jobId }: any = await updateAnnotationStatus(appDetail.id, AnnotationEnableStatus.enable, embeddingModel, score) await ensureJobCompleted(jobId, AnnotationEnableStatus.enable) diff --git a/web/app/components/app/annotation/list.tsx b/web/app/components/app/annotation/list.tsx index bc3a35158ff13a..49de5a2e7f9e93 100644 --- a/web/app/components/app/annotation/list.tsx +++ b/web/app/components/app/annotation/list.tsx @@ -4,7 +4,6 @@ import React from 'react' import { useTranslation } from 'react-i18next' import { RiDeleteBinLine } from '@remixicon/react' import { Edit02 } from '../../base/icons/src/vender/line/general' -import s from './style.module.css' import type { AnnotationItem } from './type' import RemoveAnnotationConfirmModal from './remove-annotation-confirm-modal' import cn from '@/utils/classnames' @@ -27,21 +26,21 @@ const List: FC = ({ const [showConfirmDelete, setShowConfirmDelete] = React.useState(false) return (
- - - - - - - - +
{t('appAnnotation.table.header.question')}{t('appAnnotation.table.header.answer')}{t('appAnnotation.table.header.createdAt')}{t('appAnnotation.table.header.hits')}{t('appAnnotation.table.header.actions')}
+ + + + + + + - + {list.map(item => ( { onView(item) @@ -49,16 +48,16 @@ const List: FC = ({ } > - - - + + {localDocs.map((doc) => { const isFile = doc.data_source_type === DataSourceType.FILE - const fileType = isFile ? doc.data_source_detail_dict?.upload_file.extension : '' + const fileType = isFile ? doc.data_source_detail_dict?.upload_file?.extension : '' return = ({ embeddingAvailable, documents = }
- +
{ @@ -436,7 +444,7 @@ const DocumentList: FC = ({ embeddingAvailable, documents = >
-
+
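The documents list hunk above adds a second optional-chaining hop (`upload_file?.extension`): with only `data_source_detail_dict?.upload_file.extension`, a document whose detail dict exists but carries no upload_file still throws. A minimal sketch of the difference, with an illustrative `Doc` type:

```ts
type Doc = { data_source_detail_dict?: { upload_file?: { extension: string } } }

// The extra `?.` short-circuits to undefined instead of dereferencing a
// missing upload_file; `?? ''` stands in for the ternary used in the hunk.
const fileTypeOf = (doc: Doc) =>
  doc.data_source_detail_dict?.upload_file?.extension ?? ''

console.log(fileTypeOf({ data_source_detail_dict: {} })) // '' rather than a TypeError
```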
diff --git a/web/app/components/datasets/documents/rename-modal.tsx b/web/app/components/datasets/documents/rename-modal.tsx index 401115b7b9015e..883897b510b129 100644 --- a/web/app/components/datasets/documents/rename-modal.tsx +++ b/web/app/components/datasets/documents/rename-modal.tsx @@ -6,6 +6,7 @@ import { useBoolean } from 'ahooks' import Toast from '../../base/toast' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' import { renameDocumentName } from '@/service/datasets' type Props = { @@ -59,7 +60,8 @@ const RenameModal: FC = ({ onClose={onClose} >
{t('datasetDocuments.list.table.name')}
- setNewName(e.target.value)} /> diff --git a/web/app/components/datasets/external-api/declarations.ts b/web/app/components/datasets/external-api/declarations.ts new file mode 100644 index 00000000000000..ded736d1677dad --- /dev/null +++ b/web/app/components/datasets/external-api/declarations.ts @@ -0,0 +1,16 @@ +export type CreateExternalAPIReq = { + name: string + settings: { + endpoint: string + api_key: string + } +} + +export type FormSchema = { + variable: string + type: 'text' | 'secret' + label: { + [key: string]: string + } + required: boolean +} diff --git a/web/app/components/datasets/external-api/external-api-modal/Form.tsx b/web/app/components/datasets/external-api/external-api-modal/Form.tsx new file mode 100644 index 00000000000000..ada01493fe8e43 --- /dev/null +++ b/web/app/components/datasets/external-api/external-api-modal/Form.tsx @@ -0,0 +1,90 @@ +import React, { useState } from 'react' +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import { RiBookOpenLine } from '@remixicon/react' +import type { CreateExternalAPIReq, FormSchema } from '../declarations' +import Input from '@/app/components/base/input' +import cn from '@/utils/classnames' + +type FormProps = { + className?: string + itemClassName?: string + fieldLabelClassName?: string + value: CreateExternalAPIReq + onChange: (val: CreateExternalAPIReq) => void + formSchemas: FormSchema[] + inputClassName?: string +} + +const Form: FC = React.memo(({ + className, + itemClassName, + fieldLabelClassName, + value, + onChange, + formSchemas, + inputClassName, +}) => { + const { t, i18n } = useTranslation() + const [changeKey, setChangeKey] = useState('') + + const handleFormChange = (key: string, val: string) => { + setChangeKey(key) + if (key === 'name') { + onChange({ ...value, [key]: val }) + } + else { + onChange({ + ...value, + settings: { + ...value.settings, + [key]: val, + }, + }) + } + } + + const renderField = (formSchema: FormSchema) => { + const { variable, type, label, required } = formSchema + const fieldValue = variable === 'name' ? value[variable] : (value.settings[variable as keyof typeof value.settings] || '') + + return ( +
+
+ + {variable === 'endpoint' && ( + + + {t('dataset.externalAPIPanelDocumentation')} + + )} +
+ handleFormChange(variable, val.target.value)} + required={required} + className={cn(inputClassName)} + /> +
+ ) + } + + return ( + + {formSchemas.map(formSchema => renderField(formSchema))} + + ) +}) + +export default Form diff --git a/web/app/components/datasets/external-api/external-api-modal/index.tsx b/web/app/components/datasets/external-api/external-api-modal/index.tsx new file mode 100644 index 00000000000000..340d147a505bd9 --- /dev/null +++ b/web/app/components/datasets/external-api/external-api-modal/index.tsx @@ -0,0 +1,218 @@ +import type { FC } from 'react' +import { + memo, + useEffect, + useState, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + RiBook2Line, + RiCloseLine, + RiInformation2Line, + RiLock2Fill, +} from '@remixicon/react' +import type { CreateExternalAPIReq, FormSchema } from '../declarations' +import Form from './Form' +import ActionButton from '@/app/components/base/action-button' +import Confirm from '@/app/components/base/confirm' +import { + PortalToFollowElem, + PortalToFollowElemContent, +} from '@/app/components/base/portal-to-follow-elem' +import { createExternalAPI } from '@/service/datasets' +import { useToastContext } from '@/app/components/base/toast' +import Button from '@/app/components/base/button' +import Tooltip from '@/app/components/base/tooltip' + +type AddExternalAPIModalProps = { + data?: CreateExternalAPIReq + onSave: (formValue: CreateExternalAPIReq) => void + onCancel: () => void + onEdit?: (formValue: CreateExternalAPIReq) => Promise + datasetBindings?: { id: string; name: string }[] + isEditMode: boolean +} + +const formSchemas: FormSchema[] = [ + { + variable: 'name', + type: 'text', + label: { + en_US: 'Name', + }, + required: true, + }, + { + variable: 'endpoint', + type: 'text', + label: { + en_US: 'API Endpoint', + }, + required: true, + }, + { + variable: 'api_key', + type: 'secret', + label: { + en_US: 'API Key', + }, + required: true, + }, +] + +const AddExternalAPIModal: FC = ({ data, onSave, onCancel, datasetBindings, isEditMode, onEdit }) => { + const { t } = useTranslation() + const { notify } = useToastContext() + const [loading, setLoading] = useState(false) + const [showConfirm, setShowConfirm] = useState(false) + const [formData, setFormData] = useState({ name: '', settings: { endpoint: '', api_key: '' } }) + + useEffect(() => { + if (isEditMode && data) + setFormData(data) + }, [isEditMode, data]) + + const hasEmptyInputs = Object.values(formData).some(value => + typeof value === 'string' ? value.trim() === '' : Object.values(value).some(v => v.trim() === ''), + ) + const handleDataChange = (val: CreateExternalAPIReq) => { + setFormData(val) + } + + const handleSave = async () => { + if (formData && formData.settings.api_key && formData.settings.api_key?.length < 5) { + notify({ type: 'error', message: t('common.apiBasedExtension.modal.apiKey.lengthError') }) + setLoading(false) + return + } + try { + setLoading(true) + if (isEditMode && onEdit) { + await onEdit( + { + ...formData, + settings: { ...formData.settings, api_key: formData.settings.api_key ? 
'[__HIDDEN__]' : formData.settings.api_key }, + }, + ) + notify({ type: 'success', message: 'External API updated successfully' }) + } + else { + const res = await createExternalAPI({ body: formData }) + if (res && res.id) { + notify({ type: 'success', message: 'External API saved successfully' }) + onSave(res) + } + } + onCancel() + } + catch (error) { + console.error('Error saving/updating external API:', error) + notify({ type: 'error', message: 'Failed to save/update External API' }) + } + finally { + setLoading(false) + } + } + + return ( + + +
+
+
+
+ { + isEditMode ? t('dataset.editExternalAPIFormTitle') : t('dataset.createExternalAPI') + } +
+ {isEditMode && (datasetBindings?.length ?? 0) > 0 && ( +
+ {t('dataset.editExternalAPIFormWarning.front')} + +  {datasetBindings?.length} {t('dataset.editExternalAPIFormWarning.end')}  + +
+
{`${datasetBindings?.length} ${t('dataset.editExternalAPITooltipTitle')}`}
+
+ {datasetBindings?.map(binding => ( +
+ +
{binding.name}
+
+ ))} +
+ } + asChild={false} + position='bottom' + > + + + +
+ )} +
+ + + +
+
+ + +
+
+ + {t('dataset.externalAPIForm.encrypted.front')} + + PKCS1_OAEP + + {t('dataset.externalAPIForm.encrypted.end')} +
+
+ {showConfirm && (datasetBindings?.length ?? 0) > 0 && (
+ setShowConfirm(false)}
+ onConfirm={handleSave}
+ />
+ )}
+
+
+ ) +} + +export default memo(AddExternalAPIModal) diff --git a/web/app/components/datasets/external-api/external-api-panel/index.tsx b/web/app/components/datasets/external-api/external-api-panel/index.tsx new file mode 100644 index 00000000000000..044c008b12a609 --- /dev/null +++ b/web/app/components/datasets/external-api/external-api-panel/index.tsx @@ -0,0 +1,90 @@ +import React from 'react' +import { + RiAddLine, + RiBookOpenLine, + RiCloseLine, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import ExternalKnowledgeAPICard from '../external-knowledge-api-card' +import cn from '@/utils/classnames' +import { useExternalKnowledgeApi } from '@/context/external-knowledge-api-context' +import ActionButton from '@/app/components/base/action-button' +import Button from '@/app/components/base/button' +import Loading from '@/app/components/base/loading' +import { useModalContext } from '@/context/modal-context' + +type ExternalAPIPanelProps = { + onClose: () => void +} + +const ExternalAPIPanel: React.FC = ({ onClose }) => { + const { t } = useTranslation() + const { setShowExternalKnowledgeAPIModal } = useModalContext() + const { externalKnowledgeApiList, mutateExternalKnowledgeApis, isLoading } = useExternalKnowledgeApi() + + const handleOpenExternalAPIModal = () => { + setShowExternalKnowledgeAPIModal({ + payload: { name: '', settings: { endpoint: '', api_key: '' } }, + datasetBindings: [], + onSaveCallback: () => { + mutateExternalKnowledgeApis() + }, + onCancelCallback: () => { + mutateExternalKnowledgeApis() + }, + isEditMode: false, + }) + } + + return ( +
+
+
+
+
{t('dataset.externalAPIPanelTitle')}
+
{t('dataset.externalAPIPanelDescription')}
+ + +
{t('dataset.externalAPIPanelDocumentation')}
+
+
+
+ onClose()}> + + +
+
+
+ +
+
+ {isLoading + ? ( + + ) + : ( + externalKnowledgeApiList.map(api => ( + + )) + )} +
+
+
+ ) +} + +export default ExternalAPIPanel diff --git a/web/app/components/datasets/external-api/external-knowledge-api-card/index.tsx b/web/app/components/datasets/external-api/external-knowledge-api-card/index.tsx new file mode 100644 index 00000000000000..603b4fe7cb2284 --- /dev/null +++ b/web/app/components/datasets/external-api/external-knowledge-api-card/index.tsx @@ -0,0 +1,151 @@ +import React, { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { + RiDeleteBinLine, + RiEditLine, +} from '@remixicon/react' +import type { CreateExternalAPIReq } from '../declarations' +import type { ExternalAPIItem } from '@/models/datasets' +import { checkUsageExternalAPI, deleteExternalAPI, fetchExternalAPI, updateExternalAPI } from '@/service/datasets' +import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' +import { useExternalKnowledgeApi } from '@/context/external-knowledge-api-context' +import { useModalContext } from '@/context/modal-context' +import ActionButton from '@/app/components/base/action-button' +import Confirm from '@/app/components/base/confirm' + +type ExternalKnowledgeAPICardProps = { + api: ExternalAPIItem +} + +const ExternalKnowledgeAPICard: React.FC = ({ api }) => { + const { setShowExternalKnowledgeAPIModal } = useModalContext() + const [showConfirm, setShowConfirm] = useState(false) + const [isHovered, setIsHovered] = useState(false) + const [usageCount, setUsageCount] = useState(0) + const { mutateExternalKnowledgeApis } = useExternalKnowledgeApi() + + const { t } = useTranslation() + + const handleEditClick = async () => { + try { + const response = await fetchExternalAPI({ apiTemplateId: api.id }) + const formValue: CreateExternalAPIReq = { + name: response.name, + settings: { + endpoint: response.settings.endpoint, + api_key: response.settings.api_key, + }, + } + + setShowExternalKnowledgeAPIModal({ + payload: formValue, + onSaveCallback: () => { + mutateExternalKnowledgeApis() + }, + onCancelCallback: () => { + mutateExternalKnowledgeApis() + }, + isEditMode: true, + datasetBindings: response.dataset_bindings, + onEditCallback: async (updatedData: CreateExternalAPIReq) => { + try { + await updateExternalAPI({ + apiTemplateId: api.id, + body: { + ...response, + name: updatedData.name, + settings: { + ...response.settings, + endpoint: updatedData.settings.endpoint, + api_key: updatedData.settings.api_key, + }, + }, + }) + mutateExternalKnowledgeApis() + } + catch (error) { + console.error('Error updating external knowledge API:', error) + } + }, + }) + } + catch (error) { + console.error('Error fetching external knowledge API data:', error) + } + } + + const handleDeleteClick = async () => { + try { + const usage = await checkUsageExternalAPI({ apiTemplateId: api.id }) + if (usage.is_using) + setUsageCount(usage.count) + + setShowConfirm(true) + } + catch (error) { + console.error('Error checking external API usage:', error) + } + } + + const handleConfirmDelete = async () => { + try { + const response = await deleteExternalAPI({ apiTemplateId: api.id }) + if (response && response.result === 'success') { + setShowConfirm(false) + mutateExternalKnowledgeApis() + } + else { + console.error('Failed to delete external API') + } + } + catch (error) { + console.error('Error deleting external knowledge API:', error) + } + } + + return ( + <> +
+
+
+ +
{api.name}
+
+
{api.settings.endpoint}
+
+
+ + + + setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + > + + +
+
+ {showConfirm && ( + 0 + ? `${t('dataset.deleteExternalAPIConfirmWarningContent.content.front')} ${usageCount} ${t('dataset.deleteExternalAPIConfirmWarningContent.content.end')}` + : t('dataset.deleteExternalAPIConfirmWarningContent.noConnectionContent') + } + type='warning' + onConfirm={handleConfirmDelete} + onCancel={() => setShowConfirm(false)} + /> + )} + + ) +} + +export default ExternalKnowledgeAPICard diff --git a/web/app/components/datasets/external-knowledge-base/connector/index.tsx b/web/app/components/datasets/external-knowledge-base/connector/index.tsx new file mode 100644 index 00000000000000..33f57d0b47dbfe --- /dev/null +++ b/web/app/components/datasets/external-knowledge-base/connector/index.tsx @@ -0,0 +1,36 @@ +'use client' + +import React, { useState } from 'react' +import { useRouter } from 'next/navigation' +import { useToastContext } from '@/app/components/base/toast' +import ExternalKnowledgeBaseCreate from '@/app/components/datasets/external-knowledge-base/create' +import type { CreateKnowledgeBaseReq } from '@/app/components/datasets/external-knowledge-base/create/declarations' +import { createExternalKnowledgeBase } from '@/service/datasets' + +const ExternalKnowledgeBaseConnector = () => { + const { notify } = useToastContext() + const [loading, setLoading] = useState(false) + const router = useRouter() + + const handleConnect = async (formValue: CreateKnowledgeBaseReq) => { + try { + setLoading(true) + const result = await createExternalKnowledgeBase({ body: formValue }) + if (result && result.id) { + notify({ type: 'success', message: 'External Knowledge Base Connected Successfully' }) + router.back() + } + else { throw new Error('Failed to create external knowledge base') } + } + catch (error) { + console.error('Error creating external knowledge base:', error) + notify({ type: 'error', message: 'Failed to connect External Knowledge Base' }) + } + setLoading(false) + } + return ( + + ) +} + +export default ExternalKnowledgeBaseConnector diff --git a/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelect.tsx b/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelect.tsx new file mode 100644 index 00000000000000..a6a46479a4eb43 --- /dev/null +++ b/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelect.tsx @@ -0,0 +1,110 @@ +import React, { useEffect, useState } from 'react' +import { + RiAddLine, + RiArrowDownSLine, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useRouter } from 'next/navigation' +import { ApiConnectionMod } from '@/app/components/base/icons/src/vender/solid/development' +import { useModalContext } from '@/context/modal-context' +import { useExternalKnowledgeApi } from '@/context/external-knowledge-api-context' + +type ApiItem = { + value: string + name: string + url: string +} + +type ExternalApiSelectProps = { + items: ApiItem[] + value?: string + onSelect: (item: ApiItem) => void +} + +const ExternalApiSelect: React.FC = ({ items, value, onSelect }) => { + const { t } = useTranslation() + const [isOpen, setIsOpen] = useState(false) + const [selectedItem, setSelectedItem] = useState( + items.find(item => item.value === value) || null, + ) + const { setShowExternalKnowledgeAPIModal } = useModalContext() + const { mutateExternalKnowledgeApis } = useExternalKnowledgeApi() + const router = useRouter() + + useEffect(() => { + const newSelectedItem = items.find(item => item.value === value) || null + setSelectedItem(newSelectedItem) + }, [value, 
items]) + + const handleAddNewAPI = () => { + setShowExternalKnowledgeAPIModal({ + payload: { name: '', settings: { endpoint: '', api_key: '' } }, + onSaveCallback: async () => { + mutateExternalKnowledgeApis() + router.refresh() + }, + onCancelCallback: () => { + mutateExternalKnowledgeApis() + }, + isEditMode: false, + }) + } + + const handleSelect = (item: ApiItem) => { + setSelectedItem(item) + onSelect(item) + setIsOpen(false) + } + + return ( +
+
setIsOpen(!isOpen)} + > + {selectedItem + ? ( +
+ +
+ {selectedItem.name} +
+
+ ) + : ( + {t('dataset.selectExternalKnowledgeAPI.placeholder')} + )} + +
+ {isOpen && ( +
+ {items.map(item => ( +
handleSelect(item)} + > +
+ + {item.name} + {item.url} +
+
+ ))} +
+
+ + {t('dataset.createNewExternalAPI')} +
+
+
+ )} +
+ ) +} + +export default ExternalApiSelect diff --git a/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelection.tsx b/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelection.tsx new file mode 100644 index 00000000000000..c910d9b2a7e7eb --- /dev/null +++ b/web/app/components/datasets/external-knowledge-base/create/ExternalApiSelection.tsx @@ -0,0 +1,96 @@ +'use client' + +import React, { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { RiAddLine } from '@remixicon/react' +import { useRouter } from 'next/navigation' +import ExternalApiSelect from './ExternalApiSelect' +import Input from '@/app/components/base/input' +import Button from '@/app/components/base/button' +import { useModalContext } from '@/context/modal-context' +import { useExternalKnowledgeApi } from '@/context/external-knowledge-api-context' + +type ExternalApiSelectionProps = { + external_knowledge_api_id: string + external_knowledge_id: string + onChange: (data: { external_knowledge_api_id?: string; external_knowledge_id?: string }) => void +} + +const ExternalApiSelection: React.FC = ({ external_knowledge_api_id, external_knowledge_id, onChange }) => { + const { t } = useTranslation() + const router = useRouter() + const { externalKnowledgeApiList } = useExternalKnowledgeApi() + const [selectedApiId, setSelectedApiId] = useState(external_knowledge_api_id) + const { setShowExternalKnowledgeAPIModal } = useModalContext() + const { mutateExternalKnowledgeApis } = useExternalKnowledgeApi() + + const apiItems = externalKnowledgeApiList.map(api => ({ + value: api.id, + name: api.name, + url: api.settings.endpoint, + })) + + useEffect(() => { + if (apiItems.length > 0) { + const newSelectedId = external_knowledge_api_id || apiItems[0].value + setSelectedApiId(newSelectedId) + if (newSelectedId !== external_knowledge_api_id) + onChange({ external_knowledge_api_id: newSelectedId, external_knowledge_id }) + } + }, [apiItems, external_knowledge_api_id, external_knowledge_id, onChange]) + + const handleAddNewAPI = () => { + setShowExternalKnowledgeAPIModal({ + payload: { name: '', settings: { endpoint: '', api_key: '' } }, + onSaveCallback: async () => { + mutateExternalKnowledgeApis() + router.refresh() + }, + onCancelCallback: () => { + mutateExternalKnowledgeApis() + }, + isEditMode: false, + }) + } + + useEffect(() => { + if (!external_knowledge_api_id && apiItems.length > 0) + onChange({ external_knowledge_api_id: apiItems[0].value, external_knowledge_id }) + }, []) + + return ( + +
+
+ +
+ {apiItems.length > 0 + ? { + setSelectedApiId(e.value) + onChange({ external_knowledge_api_id: e.value, external_knowledge_id }) + }} + /> + : + } +
+
+
+ +
+ onChange({ external_knowledge_id: e.target.value, external_knowledge_api_id })} + placeholder={t('dataset.externalKnowledgeIdPlaceholder') ?? ''} + /> +
+ + ) +} + +export default ExternalApiSelection diff --git a/web/app/components/datasets/external-knowledge-base/create/InfoPanel.tsx b/web/app/components/datasets/external-knowledge-base/create/InfoPanel.tsx new file mode 100644 index 00000000000000..bd32683c8579c1 --- /dev/null +++ b/web/app/components/datasets/external-knowledge-base/create/InfoPanel.tsx @@ -0,0 +1,33 @@ +import { RiBookOpenLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' + +const InfoPanel = () => { + const { t } = useTranslation() + + return ( +
+
+
+ +
+

+ + {t('dataset.connectDatasetIntro.title')} + + + {t('dataset.connectDatasetIntro.content.front')} + + {t('dataset.connectDatasetIntro.content.link')} + + {t('dataset.connectDatasetIntro.content.end')} + + + {t('dataset.connectDatasetIntro.learnMore')} + +

+
+
+ ) +} + +export default InfoPanel diff --git a/web/app/components/datasets/external-knowledge-base/create/KnowledgeBaseInfo.tsx b/web/app/components/datasets/external-knowledge-base/create/KnowledgeBaseInfo.tsx new file mode 100644 index 00000000000000..fec526b8811f78 --- /dev/null +++ b/web/app/components/datasets/external-knowledge-base/create/KnowledgeBaseInfo.tsx @@ -0,0 +1,53 @@ +import React from 'react' +import { useTranslation } from 'react-i18next' +import Input from '@/app/components/base/input' + +type KnowledgeBaseInfoProps = { + name: string + description?: string + onChange: (data: { name?: string; description?: string }) => void +} + +const KnowledgeBaseInfo: React.FC = ({ name, description, onChange }) => { + const { t } = useTranslation() + + const handleNameChange = (e: React.ChangeEvent) => { + onChange({ name: e.target.value }) + } + + const handleDescriptionChange = (e: React.ChangeEvent) => { + onChange({ description: e.target.value }) + } + + return ( +
+
+
+
+ +
+ +
+
+
+ +
+
+ + />
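KnowledgeBaseInfo above reports each field through a single `onChange` callback that carries only the key that changed, and the parent merges the patch. A minimal sketch of that contract, assuming the parent holds `{ name, description? }`:

```ts
type KBInfo = { name: string; description?: string }

// The child calls onChange({ name: e.target.value }) or
// onChange({ description: e.target.value }); the parent just spreads the patch.
const applyChange = (prev: KBInfo, patch: Partial<KBInfo>): KBInfo => ({
  ...prev,
  ...patch,
})

console.log(applyChange({ name: 'kb' }, { description: 'external source' }))
// -> { name: 'kb', description: 'external source' }
```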
{/* Available Tools */} @@ -291,7 +293,7 @@ const EditCustomCollectionModal: FC = ({ {/* Privacy Policy */}
{t('tools.createTool.privacyPolicy')}
- { const newCollection = produce(customCollection, (draft) => { @@ -299,12 +301,12 @@ const EditCustomCollectionModal: FC = ({ }) setCustomCollection(newCollection) }} - className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' placeholder={t('tools.createTool.privacyPolicyPlaceholder') || ''} /> + className='h-10 grow' placeholder={t('tools.createTool.privacyPolicyPlaceholder') || ''} />
{t('tools.createTool.customDisclaimer')}
- { const newCollection = produce(customCollection, (draft) => { @@ -312,7 +314,7 @@ const EditCustomCollectionModal: FC = ({ }) setCustomCollection(newCollection) }} - className='w-full h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg grow' placeholder={t('tools.createTool.customDisclaimerPlaceholder') || ''} /> + className='h-10 grow' placeholder={t('tools.createTool.customDisclaimerPlaceholder') || ''} />
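The labels/filter.tsx, labels/selector.tsx, and provider-list.tsx hunks that follow all make the same swap: SearchInput and bespoke styled inputs give way to the shared base Input, which bundles the magnifier icon and the clear affordance. A sketch of the resulting call site, using only props visible in these hunks (showLeftIcon, showClearIcon, onClear); SearchBox is an illustrative name:

```tsx
import Input from '@/app/components/base/input'

type SearchBoxProps = { keywords: string; onChange: (v: string) => void }

// One shared component now owns the icon, clear button, and styling, so
// each caller shrinks to value/onChange wiring.
const SearchBox = ({ keywords, onChange }: SearchBoxProps) => (
  <Input
    showLeftIcon
    showClearIcon
    value={keywords}
    onChange={e => onChange(e.target.value)}
    onClear={() => onChange('')}
  />
)

export default SearchBox
```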
diff --git a/web/app/components/tools/labels/filter.tsx b/web/app/components/tools/labels/filter.tsx index 1223f918460a72..20db687e79b872 100644 --- a/web/app/components/tools/labels/filter.tsx +++ b/web/app/components/tools/labels/filter.tsx @@ -11,7 +11,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import SearchInput from '@/app/components/base/search-input' +import Input from '@/app/components/base/input' import { Tag01, Tag03 } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' import { Check } from '@/app/components/base/icons/src/vender/line/general' import { XCircle } from '@/app/components/base/icons/src/vender/solid/general' @@ -113,7 +113,13 @@ const LabelFilter: FC = ({
- + handleKeywordsChange(e.target.value)} + onClear={() => handleKeywordsChange('')} + />
{filteredLabelList.map(label => ( diff --git a/web/app/components/tools/labels/selector.tsx b/web/app/components/tools/labels/selector.tsx index 2cc430d9569d06..3f33e45b9112bf 100644 --- a/web/app/components/tools/labels/selector.tsx +++ b/web/app/components/tools/labels/selector.tsx @@ -11,7 +11,7 @@ import { PortalToFollowElemContent, PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' -import SearchInput from '@/app/components/base/search-input' +import Input from '@/app/components/base/input' import { Tag03 } from '@/app/components/base/icons/src/vender/line/financeAndECommerce' import Checkbox from '@/app/components/base/checkbox' import type { Label } from '@/app/components/tools/labels/constant' @@ -79,7 +79,7 @@ const LabelSelector: FC = ({ className='block' >
0 ? selectedLabels : ''} className={cn('grow text-[13px] leading-[18px] text-gray-700 truncate', !value.length && '!text-gray-400')}> @@ -94,7 +94,13 @@ const LabelSelector: FC = ({
- + handleKeywordsChange(e.target.value)} + onClear={() => handleKeywordsChange('')} + />
{filteredLabelList.map(label => ( diff --git a/web/app/components/tools/provider-list.tsx b/web/app/components/tools/provider-list.tsx index f429a6ec8daf07..6f17835589be0e 100644 --- a/web/app/components/tools/provider-list.tsx +++ b/web/app/components/tools/provider-list.tsx @@ -7,7 +7,7 @@ import cn from '@/utils/classnames' import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import TabSliderNew from '@/app/components/base/tab-slider-new' import LabelFilter from '@/app/components/tools/labels/filter' -import SearchInput from '@/app/components/base/search-input' +import Input from '@/app/components/base/input' import { DotsGrid } from '@/app/components/base/icons/src/vender/line/general' import { Colors } from '@/app/components/base/icons/src/vender/line/others' import { Route } from '@/app/components/base/icons/src/vender/line/mapsAndTravel' @@ -84,7 +84,14 @@ const ProviderList = () => { />
- + handleKeywordsChange(e.target.value)} + onClear={() => handleKeywordsChange('')} + />
{ const linkUrl = useMemo(() => { if (language.startsWith('zh_')) - return 'https://docs.dify.ai/v/zh-hans/guides/gong-ju/quick-tool-integration' - return 'https://docs.dify.ai/tutorials/quick-tool-integration' + return 'https://docs.dify.ai/zh-hans/guides/tools#ru-he-chuang-jian-zi-ding-yi-gong-ju' + return 'https://docs.dify.ai/guides/tools#how-to-create-custom-tools' }, [language]) const [isShowEditCollectionToolModal, setIsShowEditCustomCollectionModal] = useState(false) diff --git a/web/app/components/tools/utils/index.ts b/web/app/components/tools/utils/index.ts index 0c462aa6fc98ff..ced9ca1879367f 100644 --- a/web/app/components/tools/utils/index.ts +++ b/web/app/components/tools/utils/index.ts @@ -1,4 +1,5 @@ import type { ThoughtItem } from '@/app/components/base/chat/chat/type' +import type { FileEntity } from '@/app/components/base/file-uploader/types' import type { VisionFile } from '@/types/app' export const sortAgentSorts = (list: ThoughtItem[]) => { @@ -11,14 +12,14 @@ export const sortAgentSorts = (list: ThoughtItem[]) => { return temp } -export const addFileInfos = (list: ThoughtItem[], messageFiles: VisionFile[]) => { +export const addFileInfos = (list: ThoughtItem[], messageFiles: (FileEntity | VisionFile)[]) => { if (!list || !messageFiles) return list return list.map((item) => { if (item.files && item.files?.length > 0) { return { ...item, - message_files: item.files.map(fileId => messageFiles.find(file => file.id === fileId)) as VisionFile[], + message_files: item.files.map(fileId => messageFiles.find(file => file.id === fileId)) as FileEntity[], } } return item diff --git a/web/app/components/tools/workflow-tool/configure-button.tsx b/web/app/components/tools/workflow-tool/configure-button.tsx index d2c5142f53c10e..6521410daea017 100644 --- a/web/app/components/tools/workflow-tool/configure-button.tsx +++ b/web/app/components/tools/workflow-tool/configure-button.tsx @@ -65,7 +65,7 @@ const WorkflowToolConfigureButton = ({ else { if (item.type === 'paragraph' && param.type !== 'string') return true - if (param.type !== item.type && !(param.type === 'string' && item.type === 'paragraph')) + if (item.type === 'text-input' && param.type !== 'string') return true } } diff --git a/web/app/components/tools/workflow-tool/index.tsx b/web/app/components/tools/workflow-tool/index.tsx index 0f9fe4c4c1a491..c4d7424538eff1 100644 --- a/web/app/components/tools/workflow-tool/index.tsx +++ b/web/app/components/tools/workflow-tool/index.tsx @@ -2,13 +2,12 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import produce from 'immer' import type { Emoji, WorkflowToolProviderParameter, WorkflowToolProviderRequest } from '../types' import cn from '@/utils/classnames' import Drawer from '@/app/components/base/drawer-plus' +import Input from '@/app/components/base/input' +import Textarea from '@/app/components/base/textarea' import Button from '@/app/components/base/button' import Toast from '@/app/components/base/toast' import EmojiPicker from '@/app/components/base/emoji-picker' @@ -133,10 +132,9 @@ const WorkflowToolAsModal: FC = ({
{t('tools.createTool.name')} *
- { setShowEmojiPicker(true) }} className='cursor-pointer' icon={emoji.content} background={emoji.background} /> - { setShowEmojiPicker(true) }} className='cursor-pointer' iconType='emoji' icon={emoji.content} background={emoji.background} /> + setLabel(e.target.value)} @@ -148,19 +146,15 @@ const WorkflowToolAsModal: FC = ({
{t('tools.createTool.nameForToolCall')} * {t('tools.createTool.nameForToolCallPlaceHolder')}
} - selector='workflow-tool-modal-tooltip' - > - - + />
- setName(e.target.value)} @@ -172,8 +166,7 @@ const WorkflowToolAsModal: FC = ({ {/* description */}
{t('tools.createTool.description')}
-
{t('appAnnotation.table.header.question')}{t('appAnnotation.table.header.answer')}{t('appAnnotation.table.header.createdAt')}{t('appAnnotation.table.header.hits')}{t('appAnnotation.table.header.actions')}
{item.question} {item.answer}{formatTime(item.created_at, t('appLog.dateTimeFormat') as string)}{item.hit_count} e.stopPropagation()}> + {formatTime(item.created_at, t('appLog.dateTimeFormat') as string)}{item.hit_count} e.stopPropagation()}> {/* Actions */}
= ({
middlePagesSiblingCount={1}
setCurrentPage={setCurrPage}
totalPages={Math.ceil(total / APP_PAGE_LIMIT)}
truncableClassName="w-8 px-0.5 text-center"
truncableText="..."
>
& {
+ onPublish?: (modelAndParameter?: ModelAndParameter, features?: any) => Promise | any
+ publishedConfig?: any
+ resetAppConfig?: () => void
+}
+
+const FeaturesWrappedAppPublisher = (props: Props) => {
+ const { t } = useTranslation()
+ const features = useFeatures(s => s.features)
+ const featuresStore = useFeaturesStore()
+ const [restoreConfirmOpen, setRestoreConfirmOpen] = useState(false)
+ const handleConfirm = useCallback(() => {
+ props.resetAppConfig?.()
+ const {
+ features,
+ setFeatures,
+ } = featuresStore!.getState()
+ const newFeatures = produce(features, (draft) => {
+ draft.moreLikeThis = props.publishedConfig.modelConfig.more_like_this || { enabled: false }
+ draft.opening = {
+ enabled: !!props.publishedConfig.modelConfig.opening_statement,
+ opening_statement: props.publishedConfig.modelConfig.opening_statement || '',
+ suggested_questions: props.publishedConfig.modelConfig.suggested_questions || [],
+ }
+ draft.moderation = props.publishedConfig.modelConfig.sensitive_word_avoidance || { enabled: false }
+ draft.speech2text = props.publishedConfig.modelConfig.speech_to_text || { enabled: false }
+ draft.text2speech = props.publishedConfig.modelConfig.text_to_speech || { enabled: false }
+ draft.suggested = props.publishedConfig.modelConfig.suggested_questions_after_answer || { enabled: false }
+ draft.citation = props.publishedConfig.modelConfig.retriever_resource || { enabled: false }
+ draft.annotationReply = props.publishedConfig.modelConfig.annotation_reply || { enabled: false }
+ draft.file = {
+ image: {
+ detail: props.publishedConfig.modelConfig.file_upload?.image?.detail || Resolution.high,
+ enabled: !!props.publishedConfig.modelConfig.file_upload?.image?.enabled,
+ number_limits: props.publishedConfig.modelConfig.file_upload?.image?.number_limits || 3,
+ transfer_methods: props.publishedConfig.modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
+ },
+ enabled: !!(props.publishedConfig.modelConfig.file_upload?.enabled || props.publishedConfig.modelConfig.file_upload?.image?.enabled),
+ allowed_file_types: props.publishedConfig.modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image],
+ allowed_file_extensions: props.publishedConfig.modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`),
+ allowed_file_upload_methods: props.publishedConfig.modelConfig.file_upload?.allowed_file_upload_methods || props.publishedConfig.modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
+ number_limits: props.publishedConfig.modelConfig.file_upload?.number_limits || props.publishedConfig.modelConfig.file_upload?.image?.number_limits || 3,
+ } as FileUpload
+ })
+ setFeatures(newFeatures)
+ setRestoreConfirmOpen(false)
+ }, [featuresStore, props])
+
+ const handlePublish = useCallback((modelAndParameter?: ModelAndParameter) => {
+ return props.onPublish?.(modelAndParameter, features)
+ }, [features, props])
+
+ return (
+ <>
+ setRestoreConfirmOpen(true),
+ }}/>
+ {restoreConfirmOpen && (
+ setRestoreConfirmOpen(false)}
+ />
+ )}
+
+ )
+}
+
+export default FeaturesWrappedAppPublisher
diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx
index e971274a71b3bc..0558e299560396 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -24,6 +24,7 @@ import { LeftIndent02 } from '@/app/components/base/icons/src/vender/line/editor import { FileText } from '@/app/components/base/icons/src/vender/line/files' import WorkflowToolConfigureButton from '@/app/components/tools/workflow-tool/configure-button' import type { InputVar } from '@/app/components/workflow/types' +import { appDefaultIconBackground } from '@/config' export type AppPublisherProps = { disabled?: boolean @@ -212,8 +213,8 @@ const AppPublisher = ({ detailNeedUpdate={!!toolPublished && published} workflowAppId={appDetail?.id} icon={{ - content: appDetail?.icon, - background: appDetail?.icon_background, + content: (appDetail.icon_type === 'image' ? '🤖' : appDetail?.icon) || '🤖', + background: (appDetail.icon_type === 'image' ? appDefaultIconBackground : appDetail?.icon_background) || appDefaultIconBackground, }} name={appDetail?.name} description={appDetail?.description} diff --git a/web/app/components/app/configuration/base/feature-panel/index.tsx b/web/app/components/app/configuration/base/feature-panel/index.tsx index 1f6db9dee6d64b..9c4adbdd2de2eb 100644 --- a/web/app/components/app/configuration/base/feature-panel/index.tsx +++ b/web/app/components/app/configuration/base/feature-panel/index.tsx @@ -2,7 +2,6 @@ import type { FC, ReactNode } from 'react' import React from 'react' import cn from '@/utils/classnames' -import ParamsConfig from '@/app/components/app/configuration/config-voice/param-config' export type IFeaturePanelProps = { className?: string @@ -10,10 +9,8 @@ export type IFeaturePanelProps = { title: ReactNode headerRight?: ReactNode hasHeaderBottomBorder?: boolean - isFocus?: boolean noBodySpacing?: boolean children?: ReactNode - isShowTextToSpeech?: boolean } const FeaturePanel: FC = ({ @@ -22,32 +19,20 @@ const FeaturePanel: FC = ({ title, headerRight, hasHeaderBottomBorder, - isFocus, noBodySpacing, children, - isShowTextToSpeech, }) => { return ( -
+
{/* Header */} -
+
{headerIcon &&
{headerIcon}
} -
{title}
+
{title}
{headerRight &&
{headerRight}
} - {isShowTextToSpeech &&
- -
}
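The prompt-input hunks below drop the RiQuestionLine icon plus `selector`-keyed Tooltip in favor of passing `popupContent` directly. A sketch of the new call shape, assuming the refactored base Tooltip renders its own default trigger when given no children (as these hunks suggest); PromptTipHint is an illustrative name:

```tsx
import Tooltip from '@/app/components/base/tooltip'

// The popup body is supplied inline; no selector id or manual icon child.
const PromptTipHint = ({ tip }: { tip: string }) => (
  <Tooltip
    popupContent={<div className='w-[180px]'>{tip}</div>}
  />
)

export default PromptTipHint
```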
diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index fb3ceadc0d8bcf..afa2bf8e277d0e 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -9,7 +9,6 @@ import produce from 'immer' import { RiDeleteBinLine, RiErrorWarningFill, - RiQuestionLine, } from '@remixicon/react' import s from './style.module.css' import MessageTypeSelector from './message-type-selector' @@ -174,12 +173,12 @@ const AdvancedPromptInput: FC = ({
{t('appDebug.pageTitle.line1')}
- {t('appDebug.promptTip')} -
} - selector='config-prompt-tooltip'> - - + popupContent={ +
+ {t('appDebug.promptTip')} +
+ } + />
)}
{canDelete && ( @@ -263,7 +262,7 @@ const AdvancedPromptInput: FC = ({ {isShowConfirmAddVar && ( v.name)} - onConfrim={handleAutoAdd(true)} + onConfirm={handleAutoAdd(true)} onCancel={handleAutoAdd(false)} onHide={hideConfirmAddVar} /> diff --git a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx index f08f2ffc6968e7..922f8bb36afad2 100644 --- a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx +++ b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx @@ -7,7 +7,7 @@ import Button from '@/app/components/base/button' export type IConfirmAddVarProps = { varNameArr: string[] - onConfrim: () => void + onConfirm: () => void onCancel: () => void onHide: () => void } @@ -22,7 +22,7 @@ const VarIcon = ( const ConfirmAddVar: FC = ({ varNameArr, - onConfrim, + onConfirm, onCancel, // onHide, }) => { @@ -63,7 +63,7 @@ const ConfirmAddVar: FC = ({
- +
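The confirm-add-var hunks above rename the misspelled `onConfrim` prop; because it lives on the exported props type, the compiler then forces every call site (advanced-prompt-input above, simple-prompt-input below) to move in the same commit. The corrected shape:

```ts
export type IConfirmAddVarProps = {
  varNameArr: string[]
  onConfirm: () => void // was onConfrim; call sites updated in the sibling hunks
  onCancel: () => void
  onHide: () => void
}
```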
diff --git a/web/app/components/app/configuration/config-prompt/conversation-histroy/edit-modal.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.tsx similarity index 100% rename from web/app/components/app/configuration/config-prompt/conversation-histroy/edit-modal.tsx rename to web/app/components/app/configuration/config-prompt/conversation-history/edit-modal.tsx diff --git a/web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.tsx similarity index 98% rename from web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx rename to web/app/components/app/configuration/config-prompt/conversation-history/history-panel.tsx index f40bd4b733a2d5..199f9598a4484b 100644 --- a/web/app/components/app/configuration/config-prompt/conversation-histroy/history-panel.tsx +++ b/web/app/components/app/configuration/config-prompt/conversation-history/history-panel.tsx @@ -23,7 +23,7 @@ const HistoryPanel: FC = ({ return (
{t('appDebug.feature.conversationHistory.title')}
diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index adcfcdd1261f9f..d7bfe8534e6b05 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -3,9 +3,6 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' import { useBoolean } from 'ahooks' -import { - RiQuestionLine, -} from '@remixicon/react' import produce from 'immer' import { useContext } from 'use-context-selector' import ConfirmAddVar from './confirm-add-var' @@ -36,7 +33,7 @@ export type ISimplePromptInput = { promptTemplate: string promptVariables: PromptVariable[] readonly?: boolean - onChange?: (promp: string, promptVariables: PromptVariable[]) => void + onChange?: (prompt: string, promptVariables: PromptVariable[]) => void noTitle?: boolean gradientBorder?: boolean editorHeight?: number @@ -156,12 +153,12 @@ const Prompt: FC = ({
{mode !== AppType.completion ? t('appDebug.chatSubTitle') : t('appDebug.completionSubTitle')}
{!readonly && ( - {t('appDebug.promptTip')} -
} - selector='config-prompt-tooltip'> - - + popupContent={ +
+ {t('appDebug.promptTip')} +
+ } + /> )}
@@ -242,7 +239,7 @@ const Prompt: FC = ({ {isShowConfirmAddVar && ( v.name)} - onConfrim={handleAutoAdd(true)} + onConfirm={handleAutoAdd(true)} onCancel={handleAutoAdd(false)} onHide={hideConfirmAddVar} /> diff --git a/web/app/components/app/configuration/config-var/config-modal/field.tsx b/web/app/components/app/configuration/config-var/config-modal/field.tsx index 82531784863e6c..5052f988d75df2 100644 --- a/web/app/components/app/configuration/config-var/config-modal/field.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/field.tsx @@ -1,19 +1,22 @@ 'use client' import type { FC } from 'react' import React from 'react' +import cn from '@/utils/classnames' type Props = { + className?: string title: string children: JSX.Element } const Field: FC = ({ + className, title, children, }) => { return ( -
-
{title}
+
+
{title}
{children}
) diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 20fcf49de18df5..85e241a203b99c 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -1,20 +1,25 @@ 'use client' import type { FC } from 'react' -import React, { useCallback, useState } from 'react' +import React, { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' +import produce from 'immer' import ModalFoot from '../modal-foot' import ConfigSelect from '../config-select' import ConfigString from '../config-string' import SelectTypeItem from '../select-type-item' import Field from './field' +import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' import { checkKeys, getNewVarInWorkflow } from '@/utils/var' import ConfigContext from '@/context/debug-configuration' -import type { InputVar, MoreInfo } from '@/app/components/workflow/types' +import type { InputVar, MoreInfo, UploadFileSetting } from '@/app/components/workflow/types' import Modal from '@/app/components/base/modal' -import Switch from '@/app/components/base/switch' -import { ChangeType, InputVarType } from '@/app/components/workflow/types' +import { ChangeType, InputVarType, SupportUploadFileTypes } from '@/app/components/workflow/types' +import FileUploadSetting from '@/app/components/workflow/nodes/_base/components/file-upload-setting' +import Checkbox from '@/app/components/base/checkbox' +import { DEFAULT_FILE_UPLOAD_SETTING } from '@/app/components/workflow/constants' +import { DEFAULT_VALUE_MAX_LEN } from '@/config' const TEXT_MAX_LENGTH = 256 @@ -25,35 +30,42 @@ export type IConfigModalProps = { varKeys?: string[] onClose: () => void onConfirm: (newValue: InputVar, moreInfo?: MoreInfo) => void + supportFile?: boolean } -const inputClassName = 'w-full px-3 text-sm leading-9 text-gray-900 border-0 rounded-lg grow h-9 bg-gray-100 focus:outline-none focus:ring-1 focus:ring-inset focus:ring-gray-200' - const ConfigModal: FC = ({ isCreate, payload, isShow, onClose, onConfirm, + supportFile, }) => { const { modelConfig } = useContext(ConfigContext) const { t } = useTranslation() const [tempPayload, setTempPayload] = useState(payload || getNewVarInWorkflow('') as any) const { type, label, variable, options, max_length } = tempPayload + const modalRef = useRef(null) + useEffect(() => { + // To fix the first input element auto focus, then directly close modal will raise error + if (isShow) + modalRef.current?.focus() + }, [isShow]) const isStringInput = type === InputVarType.textInput || type === InputVarType.paragraph + const checkVariableName = useCallback((value: string, canBeEmpty?: boolean) => { + const { isValid, errorMessageKey } = checkKeys([value], canBeEmpty) + if (!isValid) { + Toast.notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('appDebug.variableConfig.varName') }), + }) + return false + } + return true + }, [t]) const handlePayloadChange = useCallback((key: string) => { return (value: any) => { - if (key === 'variable') { - const { isValid, errorKey, errorMessageKey } = checkKeys([value], true) - if (!isValid) { - Toast.notify({ - type: 'error', - message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: errorKey }), - }) - return - } - } setTempPayload((prev) 
=> { const newPayload = { ...prev, @@ -63,19 +75,39 @@ const ConfigModal: FC = ({ return newPayload }) } - }, [t]) + }, []) + + const handleTypeChange = useCallback((type: InputVarType) => { + return () => { + const newPayload = produce(tempPayload, (draft) => { + draft.type = type + if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) { + (Object.keys(DEFAULT_FILE_UPLOAD_SETTING)).forEach((key) => { + if (key !== 'max_length') + (draft as any)[key] = (DEFAULT_FILE_UPLOAD_SETTING as any)[key] + }) + if (type === InputVarType.multiFiles) + draft.max_length = DEFAULT_FILE_UPLOAD_SETTING.max_length + } + if (type === InputVarType.paragraph) + draft.max_length = DEFAULT_VALUE_MAX_LEN + }) + setTempPayload(newPayload) + } + }, [tempPayload]) const handleVarKeyBlur = useCallback((e: any) => { - if (tempPayload.label) + const varName = e.target.value + if (!checkVariableName(varName, true) || tempPayload.label) return setTempPayload((prev) => { return { ...prev, - label: e.target.value, + label: varName, } }) - }, [tempPayload]) + }, [checkVariableName, tempPayload.label]) const handleConfirm = () => { const moreInfo = tempPayload.variable === payload?.variable @@ -84,10 +116,11 @@ const ConfigModal: FC = ({ type: ChangeType.changeVarName, payload: { beforeKey: payload?.variable || '', afterKey: tempPayload.variable }, } - if (!tempPayload.variable) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.varNameRequired') }) + + const isVariableNameValid = checkVariableName(tempPayload.variable) + if (!isVariableNameValid) return - } + // TODO: check if key already exists. should the consider the edit case // if (varKeys.map(key => key?.trim()).includes(tempPayload.variable.trim())) { // Toast.notify({ @@ -98,15 +131,15 @@ const ConfigModal: FC = ({ // } if (!tempPayload.label) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.labelNameRequired') }) + Toast.notify({ type: 'error', message: t('appDebug.variableConfig.errorMsg.labelNameRequired') }) return } if (isStringInput || type === InputVarType.number) { onConfirm(tempPayload, moreInfo) } - else { + else if (type === InputVarType.select) { if (options?.length === 0) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.atLeastOneOption') }) + Toast.notify({ type: 'error', message: t('appDebug.variableConfig.errorMsg.atLeastOneOption') }) return } const obj: Record = {} @@ -119,66 +152,91 @@ const ConfigModal: FC = ({ obj[o] = true }) if (hasRepeatedItem) { - Toast.notify({ type: 'error', message: t('appDebug.variableConig.errorMsg.optionRepeat') }) + Toast.notify({ type: 'error', message: t('appDebug.variableConfig.errorMsg.optionRepeat') }) + return + } + onConfirm(tempPayload, moreInfo) + } + else if ([InputVarType.singleFile, InputVarType.multiFiles].includes(type)) { + if (tempPayload.allowed_file_types?.length === 0) { + const errorMessages = t('workflow.errorMsg.fieldRequired', { field: t('appDebug.variableConfig.file.supportFileTypes') }) + Toast.notify({ type: 'error', message: errorMessages }) + return + } + if (tempPayload.allowed_file_types?.includes(SupportUploadFileTypes.custom) && !tempPayload.allowed_file_extensions?.length) { + const errorMessages = t('workflow.errorMsg.fieldRequired', { field: t('appDebug.variableConfig.file.custom.name') }) + Toast.notify({ type: 'error', message: errorMessages }) return } onConfirm(tempPayload, moreInfo) } + else { + onConfirm(tempPayload, moreInfo) + } } return ( -
+
- -
- handlePayloadChange('type')(InputVarType.textInput)} /> - handlePayloadChange('type')(InputVarType.paragraph)} /> - handlePayloadChange('type')(InputVarType.select)} /> - handlePayloadChange('type')(InputVarType.number)} /> + +
+ + + + + {supportFile && <> + + + }
- - + handlePayloadChange('variable')(e.target.value)} onBlur={handleVarKeyBlur} - placeholder={t('appDebug.variableConig.inputPlaceholder')!} + placeholder={t('appDebug.variableConfig.inputPlaceholder')!} /> - - + handlePayloadChange('label')(e.target.value)} - placeholder={t('appDebug.variableConig.inputPlaceholder')!} + placeholder={t('appDebug.variableConfig.inputPlaceholder')!} /> {isStringInput && ( - + )} {type === InputVarType.select && ( - + )} - - - + {[InputVarType.singleFile, InputVarType.multiFiles].includes(type) && ( + setTempPayload(p as InputVar)} + isMultiple={type === InputVarType.multiFiles} + /> + )} + +
+ handlePayloadChange('required')(!tempPayload.required)} /> + {t('appDebug.variableConfig.required')} +
= ({ onClick={() => { onChange([...options, '']) }} className='flex items-center h-9 px-3 gap-2 rounded-lg cursor-pointer text-gray-400 bg-gray-100'> -
{t('appDebug.variableConig.addOption')}
+
{t('appDebug.variableConfig.addOption')}
) diff --git a/web/app/components/app/configuration/config-var/config-string/index.tsx b/web/app/components/app/configuration/config-var/config-string/index.tsx index 2c941cfa4760ab..719ad8ee1313ec 100644 --- a/web/app/components/app/configuration/config-var/config-string/index.tsx +++ b/web/app/components/app/configuration/config-var/config-string/index.tsx @@ -1,6 +1,7 @@ 'use client' import type { FC } from 'react' import React, { useEffect } from 'react' +import Input from '@/app/components/base/input' export type IConfigStringProps = { value: number | undefined @@ -21,7 +22,7 @@ const ConfigString: FC = ({ return (
- = ({ onChange(value) }} - className="w-full px-3 text-sm leading-9 text-gray-900 border-0 rounded-lg grow h-9 bg-gray-100 focus:outline-none focus:ring-1 focus:ring-inset focus:ring-gray-200" />
) diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx index 82a220c6db9bbd..67bc37385e187f 100644 --- a/web/app/components/app/configuration/config-var/index.tsx +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -8,7 +8,6 @@ import { useContext } from 'use-context-selector' import produce from 'immer' import { RiDeleteBinLine, - RiQuestionLine, } from '@remixicon/react' import Panel from '../base/feature-panel' import EditModal from './config-modal' @@ -89,7 +88,6 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar } as InputVar })() const updatePromptVariableItem = (payload: InputVar) => { - console.log(payload) const newPromptVariables = produce(promptVariables, (draft) => { const { variable, label, type, ...rest } = payload draft[currIndex] = { @@ -274,7 +272,7 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar } return ( } @@ -282,11 +280,13 @@ const ConfigVar: FC = ({ promptVariables, readonly, onPromptVar
{t('appDebug.variableTitle')}
{!readonly && ( - - {t('appDebug.variableTip')} -
} selector='config-var-tooltip'> - - + + {t('appDebug.variableTip')} +
+ } + /> )}
} diff --git a/web/app/components/app/configuration/config-var/select-type-item/index.tsx b/web/app/components/app/configuration/config-var/select-type-item/index.tsx index bb5e700d119fc1..b71486b4eb99f6 100644 --- a/web/app/components/app/configuration/config-var/select-type-item/index.tsx +++ b/web/app/components/app/configuration/config-var/select-type-item/index.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import s from './style.module.css' import cn from '@/utils/classnames' import type { InputVarType } from '@/app/components/workflow/types' import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon' @@ -12,23 +11,30 @@ export type ISelectTypeItemProps = { onClick: () => void } +const i18nFileTypeMap: Record = { + 'file': 'single-file', + 'file-list': 'multi-files', +} + const SelectTypeItem: FC = ({ type, selected, onClick, }) => { const { t } = useTranslation() - const typeName = t(`appDebug.variableConig.${type}`) + const typeName = t(`appDebug.variableConfig.${i18nFileTypeMap[type] || type}`) return (
- {typeName} + {typeName}
) } diff --git a/web/app/components/app/configuration/config-var/select-type-item/style.module.css b/web/app/components/app/configuration/config-var/select-type-item/style.module.css deleted file mode 100644 index 8ff716d58b8be6..00000000000000 --- a/web/app/components/app/configuration/config-var/select-type-item/style.module.css +++ /dev/null @@ -1,40 +0,0 @@ -.item { - display: flex; - flex-direction: column; - justify-content: center; - align-items: center; - height: 58px; - width: 98px; - border-radius: 8px; - border: 1px solid #EAECF0; - box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); - background-color: #fff; - cursor: pointer; -} - -.item:not(.selected):hover { - border-color: #B2CCFF; - background-color: #F5F8FF; - box-shadow: 0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06); -} - -.item.selected { - color: #155EEF; - border-color: #528BFF; - background-color: #F5F8FF; - box-shadow: 0px 1px 3px rgba(16, 24, 40, 0.1), 0px 1px 2px rgba(16, 24, 40, 0.06); -} - -.text { - font-size: 13px; - color: #667085; - font-weight: 500; -} - -.item.selected .text { - color: #155EEF; -} - -.item:not(.selected):hover { - color: #344054; -} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-var/select-var-type.tsx b/web/app/components/app/configuration/config-var/select-var-type.tsx index 137f62b2bbb9ca..14d4f926edb74e 100644 --- a/web/app/components/app/configuration/config-var/select-var-type.tsx +++ b/web/app/components/app/configuration/config-var/select-var-type.tsx @@ -62,14 +62,14 @@ const SelectVarType: FC = ({
- - - - + + + +
- +
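Reviewer note on the config-var cluster above (config-modal, config-string, index, select-type-item, select-var-type): variable-name validation is now centralized in a single checkVariableName callback that both the blur handler and handleConfirm call, and the long-standing i18n namespace typo variableConig is corrected to variableConfig throughout, with the two new file types mapped to distinct display keys via i18nFileTypeMap. A minimal sketch of the extracted check, assuming checkKeys keeps the { isValid, errorMessageKey } shape shown in the hunk; the standalone function name is hypothetical and this is not part of the diff:

// --- illustrative sketch, not part of the diff ---
import { checkKeys } from '@/utils/var'
import Toast from '@/app/components/base/toast'

// Returns true when `value` is a usable variable name; otherwise raises a toast.
// `canBeEmpty` mirrors the blur-time call, where an empty draft name is still allowed.
function validateVarName(
  t: (key: string, opts?: Record<string, string>) => string,
  value: string,
  canBeEmpty?: boolean,
): boolean {
  const { isValid, errorMessageKey } = checkKeys([value], canBeEmpty)
  if (!isValid) {
    Toast.notify({
      type: 'error',
      message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: t('appDebug.variableConfig.varName') }),
    })
    return false
  }
  return true
}
// --- end sketch ---

Centralizing the check means the same error toast fires whether the name fails on blur or on confirm, which the old per-keystroke check in handlePayloadChange did not guarantee.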
diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index 9b12e059b57cfb..23f00d46d837d6 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -1,62 +1,103 @@ 'use client' import type { FC } from 'react' -import React from 'react' +import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' +import produce from 'immer' import { useContext } from 'use-context-selector' -import Panel from '../base/feature-panel' import ParamConfig from './param-config' +import { Vision } from '@/app/components/base/icons/src/vender/features' import Tooltip from '@/app/components/base/tooltip' -import Switch from '@/app/components/base/switch' -import { Eye } from '@/app/components/base/icons/src/vender/solid/general' +// import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card' import ConfigContext from '@/context/debug-configuration' +// import { Resolution } from '@/types/app' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import Switch from '@/app/components/base/switch' +import type { FileUpload } from '@/app/components/base/features/types' const ConfigVision: FC = () => { const { t } = useTranslation() - const { - isShowVisionConfig, - visionConfig, - setVisionConfig, - } = useContext(ConfigContext) + const { isShowVisionConfig } = useContext(ConfigContext) + const file = useFeatures(s => s.features.file) + const featuresStore = useFeaturesStore() + + const handleChange = useCallback((data: FileUpload) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + + const newFeatures = produce(features, (draft) => { + draft.file = { + ...draft.file, + enabled: data.enabled, + image: { + enabled: data.enabled, + detail: data.image?.detail, + transfer_methods: data.image?.transfer_methods, + number_limits: data.image?.number_limits, + }, + } + }) + setFeatures(newFeatures) + }, [featuresStore]) if (!isShowVisionConfig) return null - return (<> - - } - title={ -
-
{t('appDebug.vision.name')}
- - {t('appDebug.vision.description')} -
} selector='config-vision-tooltip'> - - + return ( +
+
+
+
- } - headerRight={ -
- -
- setVisionConfig({ - ...visionConfig, - enabled: value, - })} - size='md' +
+
+
{t('appDebug.vision.name')}
+ + {t('appDebug.vision.description')} +
+ } + /> +
+
+ {/*
+
{t('appDebug.vision.visionSettings.resolution')}
+ + {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} +
+ } /> -
- } - noBodySpacing - /> - +
*/} + {/*
+ handleChange(Resolution.high)} + /> + handleChange(Resolution.low)} + /> +
*/} + +
+ handleChange({ + ...(file || {}), + enabled: value, + })} + size='md' + /> + + ) } export default React.memo(ConfigVision) diff --git a/web/app/components/app/configuration/config-vision/param-config-content.tsx b/web/app/components/app/configuration/config-vision/param-config-content.tsx index 89fad411e70f47..fe6d1cd7679a2e 100644 --- a/web/app/components/app/configuration/config-vision/param-config-content.tsx +++ b/web/app/components/app/configuration/config-vision/param-config-content.tsx @@ -1,131 +1,139 @@ 'use client' import type { FC } from 'react' -import React from 'react' -import { useContext } from 'use-context-selector' +import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' -import RadioGroup from './radio-group' -import ConfigContext from '@/context/debug-configuration' +import produce from 'immer' +import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card' import { Resolution, TransferMethod } from '@/types/app' import ParamItem from '@/app/components/base/param-item' import Tooltip from '@/app/components/base/tooltip' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import type { FileUpload } from '@/app/components/base/features/types' const MIN = 1 const MAX = 6 const ParamConfigContent: FC = () => { const { t } = useTranslation() + const file = useFeatures(s => s.features.file) + const featuresStore = useFeaturesStore() - const { - visionConfig, - setVisionConfig, - } = useContext(ConfigContext) + const handleChange = useCallback((data: FileUpload) => { + const { + features, + setFeatures, + } = featuresStore!.getState() - const transferMethod = (() => { - if (!visionConfig.transfer_methods || visionConfig.transfer_methods.length === 2) - return TransferMethod.all - - return visionConfig.transfer_methods[0] - })() + const newFeatures = produce(features, (draft) => { + draft.file = { + ...draft.file, + allowed_file_upload_methods: data.allowed_file_upload_methods, + number_limits: data.number_limits, + image: { + enabled: data.enabled, + detail: data.image?.detail, + transfer_methods: data.allowed_file_upload_methods, + number_limits: data.number_limits, + }, + } + }) + setFeatures(newFeatures) + }, [featuresStore]) return (
-
-
{t('appDebug.vision.visionSettings.title')}
-
-
-
-
{t('appDebug.vision.visionSettings.resolution')}
- - {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - - -
- { - setVisionConfig({ - ...visionConfig, - detail: value, - }) - }} +
{t('appDebug.vision.visionSettings.title')}
+
+
+
+
{t('appDebug.vision.visionSettings.resolution')}
+ + {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => ( +
{item}
+ ))} +
+ } />
-
-
{t('appDebug.vision.visionSettings.uploadMethod')}
- { - if (value === TransferMethod.all) { - setVisionConfig({ - ...visionConfig, - transfer_methods: [TransferMethod.remote_url, TransferMethod.local_file], - }) - return - } - setVisionConfig({ - ...visionConfig, - transfer_methods: [value], - }) - }} +
+ handleChange({ + ...file, + image: { detail: Resolution.high }, + })} + /> + handleChange({ + ...file, + image: { detail: Resolution.low }, + })} />
-
- { - if (!value) - return - - setVisionConfig({ - ...visionConfig, - number_limits: value, - }) - }} +
+
+
{t('appDebug.vision.visionSettings.uploadMethod')}
+
+ handleChange({ + ...file, + allowed_file_upload_methods: [TransferMethod.local_file, TransferMethod.remote_url], + })} + /> + handleChange({ + ...file, + allowed_file_upload_methods: [TransferMethod.local_file], + })} + /> + handleChange({ + ...file, + allowed_file_upload_methods: [TransferMethod.remote_url], + })} />
+
+ { + if (!value) + return + + handleChange({ + ...file, + number_limits: value, + }) + }} + /> +
) diff --git a/web/app/components/app/configuration/config-vision/param-config.tsx b/web/app/components/app/configuration/config-vision/param-config.tsx index f1e2475495c8ed..8c638793915d20 100644 --- a/web/app/components/app/configuration/config-vision/param-config.tsx +++ b/web/app/components/app/configuration/config-vision/param-config.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import { memo, useState } from 'react' import { useTranslation } from 'react-i18next' -import VoiceParamConfig from './param-config-content' +import ParamConfigContent from './param-config-content' import cn from '@/utils/classnames' import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' import { @@ -25,14 +25,14 @@ const ParamsConfig: FC = () => { }} > setOpen(v => !v)}> -
+
{t('appDebug.voice.settings')}
- +
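Reviewer note on the config-vision rework above (index.tsx, param-config-content.tsx, param-config.tsx): vision settings no longer flow through ConfigContext's visionConfig/setVisionConfig; both components now read the `file` slice from the shared features store and write back a complete new features object built with immer. A condensed sketch of the recurring pattern, with simplified stand-in types for the real FileUpload/Features definitions (assumed shapes; not part of the diff):

// --- illustrative sketch, not part of the diff ---
import produce from 'immer'

// Simplified stand-ins for the store types referenced in the hunks.
type FileUpload = {
  enabled?: boolean
  number_limits?: number
  allowed_file_upload_methods?: string[]
  image?: { enabled?: boolean; detail?: string; transfer_methods?: string[]; number_limits?: number }
}
type Features = { file?: FileUpload }
type FeaturesStore = {
  getState: () => {
    features: Features
    setFeatures: (f: Features) => void
  }
}

// Mirrors the handleChange callbacks above: read the state once, produce an
// immutable copy with the file fields merged in, then write the copy back.
function patchFileFeature(store: FeaturesStore, patch: FileUpload) {
  const { features, setFeatures } = store.getState()
  const next = produce(features, (draft) => {
    draft.file = { ...draft.file, ...patch }
  })
  setFeatures(next)
}
// --- end sketch ---

Going through getState() inside the callback (rather than capturing `features` from the hook) keeps the callbacks stable across renders; only `featuresStore` appears in the dependency arrays.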
diff --git a/web/app/components/app/configuration/config-vision/radio-group/index.tsx b/web/app/components/app/configuration/config-vision/radio-group/index.tsx deleted file mode 100644 index a1cfb06e6afdee..00000000000000 --- a/web/app/components/app/configuration/config-vision/radio-group/index.tsx +++ /dev/null @@ -1,40 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import s from './style.module.css' -import cn from '@/utils/classnames' - -type OPTION = { - label: string - value: any -} - -type Props = { - className?: string - options: OPTION[] - value: any - onChange: (value: any) => void -} - -const RadioGroup: FC = ({ - className = '', - options, - value, - onChange, -}) => { - return ( -
- {options.map(item => ( -
onChange(item.value)} - > -
-
{item.label}
-
- ))} -
- ) -} -export default React.memo(RadioGroup) diff --git a/web/app/components/app/configuration/config-vision/radio-group/style.module.css b/web/app/components/app/configuration/config-vision/radio-group/style.module.css deleted file mode 100644 index 22c29c6a423ee7..00000000000000 --- a/web/app/components/app/configuration/config-vision/radio-group/style.module.css +++ /dev/null @@ -1,24 +0,0 @@ -.item { - @apply grow flex items-center h-8 px-2.5 rounded-lg bg-gray-25 border border-gray-100 cursor-pointer space-x-2; -} - -.item:hover { - background-color: #ffffff; - border-color: #B2CCFF; - box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); -} - -.item.checked { - background-color: #ffffff; - border-color: #528BFF; - box-shadow: 0px 1px 2px 0px rgba(16, 24, 40, 0.06), 0px 1px 3px 0px rgba(16, 24, 40, 0.10); -} - -.radio { - @apply w-4 h-4 border-[2px] border-gray-200 rounded-full; -} - -.item.checked .radio { - border-width: 5px; - border-color: #155eef; -} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-voice/param-config-content.tsx b/web/app/components/app/configuration/config-voice/param-config-content.tsx deleted file mode 100644 index 9b0d5bbb69e230..00000000000000 --- a/web/app/components/app/configuration/config-voice/param-config-content.tsx +++ /dev/null @@ -1,221 +0,0 @@ -'use client' -import useSWR from 'swr' -import type { FC } from 'react' -import { useContext } from 'use-context-selector' -import React, { Fragment } from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' -import { usePathname } from 'next/navigation' -import { useTranslation } from 'react-i18next' -import { Listbox, Transition } from '@headlessui/react' -import { CheckIcon, ChevronDownIcon } from '@heroicons/react/20/solid' -import classNames from '@/utils/classnames' -import RadioGroup from '@/app/components/app/configuration/config-vision/radio-group' -import type { Item } from '@/app/components/base/select' -import ConfigContext from '@/context/debug-configuration' -import { fetchAppVoices } from '@/service/apps' -import Tooltip from '@/app/components/base/tooltip' -import { languages } from '@/i18n/language' -import { TtsAutoPlay } from '@/types/app' -const VoiceParamConfig: FC = () => { - const { t } = useTranslation() - const pathname = usePathname() - const matched = pathname.match(/\/app\/([^/]+)/) - const appId = (matched?.length && matched[1]) ? matched[1] : '' - - const { - textToSpeechConfig, - setTextToSpeechConfig, - } = useContext(ConfigContext) - - let languageItem = languages.find(item => item.value === textToSpeechConfig.language) - const localLanguagePlaceholder = languageItem?.name || t('common.placeholder.select') - if (languages && !languageItem && languages.length > 0) - languageItem = languages[0] - const language = languageItem?.value - const voiceItems = useSWR({ appId, language }, fetchAppVoices).data - let voiceItem = voiceItems?.find(item => item.value === textToSpeechConfig.voice) - if (voiceItems && !voiceItem && voiceItems.length > 0) - voiceItem = voiceItems[0] - - const localVoicePlaceholder = voiceItem?.name || t('common.placeholder.select') - - return ( -
-
-
{t('appDebug.voice.voiceSettings.title')}
-
-
-
-
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - - -
- { - setTextToSpeechConfig({ - ...textToSpeechConfig, - language: String(value.value), - }) - }} - > -
- - - {languageItem?.name ? t(`common.voice.language.${languageItem?.value.replace('-', '')}`) : localLanguagePlaceholder} - - - - - - - - {languages.map((item: Item) => ( - - `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' - }` - } - value={item} - disabled={false} - > - {({ /* active, */ selected }) => ( - <> - {t(`common.voice.language.${(item.value).toString().replace('-', '')}`)} - {(selected || item.value === textToSpeechConfig.language) && ( - - - )} - - )} - - ))} - - -
-
-
-
-
{t('appDebug.voice.voiceSettings.voice')}
- { - if (!value.value) - return - setTextToSpeechConfig({ - ...textToSpeechConfig, - voice: String(value.value), - }) - }} - > -
- - {voiceItem?.name ?? localVoicePlaceholder} - - - - - - - {voiceItems?.map((item: Item) => ( - - `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' - }` - } - value={item} - disabled={false} - > - {({ /* active, */ selected }) => ( - <> - {item.name} - {(selected || item.value === textToSpeechConfig.voice) && ( - - - )} - - )} - - ))} - - -
-
-
-
-
{t('appDebug.voice.voiceSettings.autoPlay')}
- { - setTextToSpeechConfig({ - ...textToSpeechConfig, - autoPlay: value, - }) - }} - /> -
-
-
-
- ) -} - -export default React.memo(VoiceParamConfig) diff --git a/web/app/components/app/configuration/config-voice/param-config.tsx b/web/app/components/app/configuration/config-voice/param-config.tsx deleted file mode 100644 index f1e2475495c8ed..00000000000000 --- a/web/app/components/app/configuration/config-voice/param-config.tsx +++ /dev/null @@ -1,41 +0,0 @@ -'use client' -import type { FC } from 'react' -import { memo, useState } from 'react' -import { useTranslation } from 'react-i18next' -import VoiceParamConfig from './param-config-content' -import cn from '@/utils/classnames' -import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' - -const ParamsConfig: FC = () => { - const { t } = useTranslation() - const [open, setOpen] = useState(false) - - return ( - - setOpen(v => !v)}> -
- -
{t('appDebug.voice.settings')}
-
-
- -
- -
-
-
- ) -} -export default memo(ParamsConfig) diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx index b295a4e709bac9..959336457fc2ae 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx @@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next' import { RiCloseLine } from '@remixicon/react' import ItemPanel from './item-panel' import Button from '@/app/components/base/button' -import { CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Unblur } from '@/app/components/base/icons/src/vender/solid/education' import Slider from '@/app/components/base/slider' import type { AgentConfig } from '@/models/debug' @@ -65,7 +65,7 @@ const AgentSetting: FC = ({ + } name={t('appDebug.agent.agentMode')} description={t('appDebug.agent.agentModeDes')} diff --git a/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx b/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx index 299dcb151db7c3..99c2478b0647ed 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/item-panel.tsx @@ -1,7 +1,6 @@ 'use client' import type { FC } from 'react' import React from 'react' -import { RiQuestionLine } from '@remixicon/react' import cn from '@/utils/classnames' import Tooltip from '@/app/components/base/tooltip' type Props = { @@ -25,14 +24,12 @@ const ItemPanel: FC = ({ {icon}
{name}
{description}
} - selector={`agent-setting-tooltip-${name}`} > -
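Reviewer note on the tooltip changes running through item-panel and the neighboring files: the old base Tooltip took an htmlContent node, a selector id, and an explicit RiQuestionLine trigger child; the consolidated component (the deleted tooltip-plus at the end of this diff shows the popupContent prop it inherits) takes popupContent and, judging by the self-closing usages in these hunks, renders its own default trigger when no children are passed. A hypothetical usage under that assumption (not part of the diff):

// --- illustrative sketch, not part of the diff ---
import Tooltip from '@/app/components/base/tooltip'

// Hypothetical label with an info tooltip using the new API:
// no selector id, no hand-rolled question-mark icon child.
const LabelWithTip = ({ title, tip }: { title: string; tip: string }) => (
  <div className='flex items-center'>
    <div className='mr-0.5'>{title}</div>
    <Tooltip
      popupContent={
        <div className='w-[180px]'>{tip}</div>
      }
    />
  </div>
)
// --- end sketch ---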
diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 16f2257c38d175..52e5d5d906d0bd 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -7,13 +7,11 @@ import produce from 'immer' import { RiDeleteBinLine, RiHammerFill, - RiQuestionLine, } from '@remixicon/react' import { useFormattingChangedDispatcher } from '../../../debug/hooks' import SettingBuiltInTool from './setting-built-in-tool' import cn from '@/utils/classnames' import Panel from '@/app/components/app/configuration/base/feature-panel' -import Tooltip from '@/app/components/base/tooltip' import { InfoCircle } from '@/app/components/base/icons/src/vender/line/general' import OperationBtn from '@/app/components/app/configuration/base/operation-btn' import AppIcon from '@/app/components/base/app-icon' @@ -23,7 +21,7 @@ import type { AgentTool } from '@/types/app' import { type Collection, CollectionType } from '@/app/components/tools/types' import { MAX_TOOLS_NUM } from '@/config' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other' import AddToolModal from '@/app/components/tools/add-tool-modal' @@ -60,7 +58,7 @@ const AgentTools: FC = () => { return ( <> @@ -68,11 +66,13 @@ const AgentTools: FC = () => { title={
{t('appDebug.agent.tools.name')}
- - {t('appDebug.agent.tools.description')} -
} selector='config-tools-tooltip'> - - + + {t('appDebug.agent.tools.description')} +
+ } + />
} headerRight={ @@ -119,19 +119,20 @@ const AgentTools: FC = () => { className={cn((item.isDeleted || item.notAuthor) ? 'line-through opacity-50' : '', 'grow w-0 ml-2 leading-[18px] text-[13px] font-medium text-gray-800 truncate')} > {item.provider_type === CollectionType.builtIn ? item.provider_name : item.tool_label} - {item.tool_name} - +
{(item.isDeleted || item.notAuthor) ? (
-
{ if (item.notAuthor) @@ -139,7 +140,7 @@ const AgentTools: FC = () => { }}>
-
+
{ const newModelConfig = produce(modelConfig, (draft) => { @@ -155,16 +156,17 @@ const AgentTools: FC = () => { ) : (
- -
{ +
{ setCurrentTool(item) setIsShowSettingTool(true) }}>
- +
{ const newModelConfig = produce(modelConfig, (draft) => { diff --git a/web/app/components/app/configuration/config/assistant-type-picker/index.tsx b/web/app/components/app/configuration/config/assistant-type-picker/index.tsx index 6bdf678f85f1e7..336d736e3b411b 100644 --- a/web/app/components/app/configuration/config/assistant-type-picker/index.tsx +++ b/web/app/components/app/configuration/config/assistant-type-picker/index.tsx @@ -12,7 +12,7 @@ import { } from '@/app/components/base/portal-to-follow-elem' import { BubbleText } from '@/app/components/base/icons/src/vender/solid/education' import Radio from '@/app/components/base/radio/ui' -import { CuteRobote } from '@/app/components/base/icons/src/vender/solid/communication' +import { CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication' import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' import { ArrowUpRight } from '@/app/components/base/icons/src/vender/line/arrows' import type { AgentConfig } from '@/models/debug' @@ -117,7 +117,7 @@ const AssistantTypePicker: FC = ({ > setOpen(v => !v)}>
- {isAgent ? : } + {isAgent ? : }
{t(`appDebug.assistantType.${isAgent ? 'agentAssistant' : 'chatAssistant'}.name`)}
@@ -135,7 +135,7 @@ const AssistantTypePicker: FC = ({ onClick={handleChange} /> = ({ onFinished, }) => { const { t } = useTranslation() - + const { + currentProvider, + currentModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const tryList = [ { icon: RiTerminalBoxLine, @@ -190,6 +198,19 @@ const GetAutomaticRes: FC = ({
{t('appDebug.generate.title')}
{t('appDebug.generate.description')}
+
+ + +
{t('appDebug.generate.tryIt')}
@@ -212,7 +233,11 @@ const GetAutomaticRes: FC = ({
{t('appDebug.generate.instruction')}
- -
- ) - : ( -
- )} - {renderQuestions()} - ) : ( -
{t('appDebug.openingStatement.noDataPlaceHolder')}
- )} - - {isShowConfirmAddVar && ( - - )} - -
- - ) -} -export default React.memo(OpeningStatement) diff --git a/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx b/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx deleted file mode 100644 index 2e08a991226097..00000000000000 --- a/web/app/components/base/features/feature-panel/score-slider/base-slider/index.tsx +++ /dev/null @@ -1,38 +0,0 @@ -import ReactSlider from 'react-slider' -import s from './style.module.css' -import cn from '@/utils/classnames' - -type ISliderProps = { - className?: string - value: number - max?: number - min?: number - step?: number - disabled?: boolean - onChange: (value: number) => void -} - -const Slider: React.FC = ({ className, max, min, step, value, disabled, onChange }) => { - return ( -
-
-
- {(state.valueNow / 100).toFixed(2)} -
-
-
- )} - /> -} - -export default Slider diff --git a/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css b/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css deleted file mode 100644 index 4e93b39563f40e..00000000000000 --- a/web/app/components/base/features/feature-panel/score-slider/base-slider/style.module.css +++ /dev/null @@ -1,20 +0,0 @@ -.slider { - position: relative; -} - -.slider.disabled { - opacity: 0.6; -} - -.slider-thumb:focus { - outline: none; -} - -.slider-track { - background-color: #528BFF; - height: 2px; -} - -.slider-track-1 { - background-color: #E5E7EB; -} \ No newline at end of file diff --git a/web/app/components/base/features/feature-panel/score-slider/index.tsx b/web/app/components/base/features/feature-panel/score-slider/index.tsx deleted file mode 100644 index 9826cbadcf5d6a..00000000000000 --- a/web/app/components/base/features/feature-panel/score-slider/index.tsx +++ /dev/null @@ -1,46 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import { useTranslation } from 'react-i18next' -import Slider from '@/app/components/app/configuration/toolbox/score-slider/base-slider' - -type Props = { - className?: string - value: number - onChange: (value: number) => void -} - -const ScoreSlider: FC = ({ - className, - value, - onChange, -}) => { - const { t } = useTranslation() - - return ( -
-
- -
-
-
-
0.8
-
·
-
{t('appDebug.feature.annotation.scoreThreshold.easyMatch')}
-
-
-
1.0
-
·
-
{t('appDebug.feature.annotation.scoreThreshold.accurateMatch')}
-
-
-
- ) -} -export default React.memo(ScoreSlider) diff --git a/web/app/components/base/features/feature-panel/speech-to-text/index.tsx b/web/app/components/base/features/feature-panel/speech-to-text/index.tsx deleted file mode 100644 index 2e5e3de439b8a7..00000000000000 --- a/web/app/components/base/features/feature-panel/speech-to-text/index.tsx +++ /dev/null @@ -1,22 +0,0 @@ -'use client' -import React, { type FC } from 'react' -import { useTranslation } from 'react-i18next' -import { Microphone01 } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' - -const SpeechToTextConfig: FC = () => { - const { t } = useTranslation() - - return ( -
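Reviewer note before the deleted speech-to-text panel: the ScoreSlider removed above (it reappears under new-feature-panel/annotation-reply later in this diff) renders the annotation score threshold as an integer slider whose base-slider prints (state.valueNow / 100).toFixed(2), with 0.8 and 1.0 as the legend anchors. The implied conversions between the slider's integer range and the stored threshold, sketched under that assumption (not spelled out in the diff):

// --- illustrative sketch, not part of the diff ---
// The slider works in integers (e.g. 80..100); the stored score_threshold is 0..1.
const toSliderValue = (threshold: number): number => Math.round(threshold * 100) // 0.85 -> 85
const toScoreThreshold = (sliderValue: number): number => Number((sliderValue / 100).toFixed(2)) // 85 -> 0.85
// --- end sketch ---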
-
- -
-
-
{t('appDebug.feature.speechToText.title')}
-
-
-
{t('appDebug.feature.speechToText.resDes')}
-
- ) -} -export default React.memo(SpeechToTextConfig) diff --git a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx b/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx deleted file mode 100644 index e424c4ead52e2e..00000000000000 --- a/web/app/components/base/features/feature-panel/suggested-questions-after-answer/index.tsx +++ /dev/null @@ -1,30 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' -import { MessageSmileSquare } from '@/app/components/base/icons/src/vender/solid/communication' -import TooltipPlus from '@/app/components/base/tooltip-plus' - -const SuggestedQuestionsAfterAnswer: FC = () => { - const { t } = useTranslation() - - return ( -
-
- -
-
-
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
- - - -
-
-
{t('appDebug.feature.suggestedQuestionsAfterAnswer.resDes')}
-
- ) -} -export default React.memo(SuggestedQuestionsAfterAnswer) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/index.tsx b/web/app/components/base/features/feature-panel/text-to-speech/index.tsx deleted file mode 100644 index 2480a19077134a..00000000000000 --- a/web/app/components/base/features/feature-panel/text-to-speech/index.tsx +++ /dev/null @@ -1,62 +0,0 @@ -'use client' -import useSWR from 'swr' -import React from 'react' -import { useTranslation } from 'react-i18next' -import { usePathname } from 'next/navigation' -import { useFeatures } from '../../hooks' -import type { OnFeaturesChange } from '../../types' -import ParamsConfig from './params-config' -import { Speaker } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' -import { languages } from '@/i18n/language' -import { fetchAppVoices } from '@/service/apps' -import AudioBtn from '@/app/components/base/audio-btn' - -type TextToSpeechProps = { - onChange?: OnFeaturesChange - disabled?: boolean -} -const TextToSpeech = ({ - onChange, - disabled, -}: TextToSpeechProps) => { - const { t } = useTranslation() - const textToSpeech = useFeatures(s => s.features.text2speech) - - const pathname = usePathname() - const matched = pathname.match(/\/app\/([^/]+)/) - const appId = (matched?.length && matched[1]) ? matched[1] : '' - const language = textToSpeech?.language - const languageInfo = languages.find(i => i.value === textToSpeech?.language) - - const voiceItems = useSWR({ appId, language }, fetchAppVoices).data - const voiceItem = voiceItems?.find(item => item.value === textToSpeech?.voice) - - return ( -
-
- -
-
- {t('appDebug.feature.textToSpeech.title')} -
-
-
-
- {languageInfo && (`${languageInfo?.name} - `)}{voiceItem?.name ?? t('appDebug.voice.defaultDisplay')} - { languageInfo?.example && ( - - )} -
-
- -
-
- ) -} -export default React.memo(TextToSpeech) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx b/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx deleted file mode 100644 index a5a2eb7bb70e34..00000000000000 --- a/web/app/components/base/features/feature-panel/text-to-speech/param-config-content.tsx +++ /dev/null @@ -1,241 +0,0 @@ -'use client' -import useSWR from 'swr' -import produce from 'immer' -import React, { Fragment } from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' -import { usePathname } from 'next/navigation' -import { useTranslation } from 'react-i18next' -import { Listbox, Transition } from '@headlessui/react' -import { CheckIcon, ChevronDownIcon } from '@heroicons/react/20/solid' -import { - useFeatures, - useFeaturesStore, -} from '../../hooks' -import type { OnFeaturesChange } from '../../types' -import classNames from '@/utils/classnames' -import type { Item } from '@/app/components/base/select' -import { fetchAppVoices } from '@/service/apps' -import Tooltip from '@/app/components/base/tooltip' -import { languages } from '@/i18n/language' -import RadioGroup from '@/app/components/app/configuration/config-vision/radio-group' -import { TtsAutoPlay } from '@/types/app' - -type VoiceParamConfigProps = { - onChange?: OnFeaturesChange -} -const VoiceParamConfig = ({ - onChange, -}: VoiceParamConfigProps) => { - const { t } = useTranslation() - const pathname = usePathname() - const matched = pathname.match(/\/app\/([^/]+)/) - const appId = (matched?.length && matched[1]) ? matched[1] : '' - const text2speech = useFeatures(state => state.features.text2speech) - const featuresStore = useFeaturesStore() - - let languageItem = languages.find(item => item.value === text2speech?.language) - if (languages && !languageItem) - languageItem = languages[0] - const localLanguagePlaceholder = languageItem?.name || t('common.placeholder.select') - - const language = languageItem?.value - const voiceItems = useSWR({ appId, language }, fetchAppVoices).data - let voiceItem = voiceItems?.find(item => item.value === text2speech?.voice) - if (voiceItems && !voiceItem) - voiceItem = voiceItems[0] - const localVoicePlaceholder = voiceItem?.name || t('common.placeholder.select') - - const handleChange = (value: Record) => { - const { - features, - setFeatures, - } = featuresStore!.getState() - - const newFeatures = produce(features, (draft) => { - draft.text2speech = { - ...draft.text2speech, - ...value, - } - }) - - setFeatures(newFeatures) - if (onChange) - onChange(newFeatures) - } - - return ( -
-
-
{t('appDebug.voice.voiceSettings.title')}
-
-
-
-
{t('appDebug.voice.voiceSettings.language')}
- - {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => ( -
{item}
- ))} -
} selector='config-resolution-tooltip'> - - -
- { - handleChange({ - language: String(value.value), - }) - }} - > -
- - - {languageItem?.name ? t(`common.voice.language.${languageItem?.value.replace('-', '')}`) : localLanguagePlaceholder} - - - - - - - - {languages.map((item: Item) => ( - - `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' - }` - } - value={item} - disabled={false} - > - {({ /* active, */ selected }) => ( - <> - {t(`common.voice.language.${(item.value).toString().replace('-', '')}`)} - {(selected || item.value === text2speech?.language) && ( - - - )} - - )} - - ))} - - -
-
-
- -
-
{t('appDebug.voice.voiceSettings.voice')}
- { - handleChange({ - voice: String(value.value), - }) - }} - > -
- - {voiceItem?.name ?? localVoicePlaceholder} - - - - - - - {voiceItems?.map((item: Item) => ( - - `relative cursor-pointer select-none py-2 pl-3 pr-9 rounded-lg hover:bg-gray-100 text-gray-700 ${active ? 'bg-gray-100' : '' - }` - } - value={item} - disabled={false} - > - {({ /* active, */ selected }) => ( - <> - {item.name} - {(selected || item.value === text2speech?.voice) && ( - - - )} - - )} - - ))} - - -
-
-
-
-
{t('appDebug.voice.voiceSettings.autoPlay')}
- { - handleChange({ - autoPlay: value, - }) - }} - /> -
-
-
-
- ) -} - -export default React.memo(VoiceParamConfig) diff --git a/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx b/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx deleted file mode 100644 index 095fd6cce86535..00000000000000 --- a/web/app/components/base/features/feature-panel/text-to-speech/params-config.tsx +++ /dev/null @@ -1,48 +0,0 @@ -'use client' -import { memo, useState } from 'react' -import { useTranslation } from 'react-i18next' -import type { OnFeaturesChange } from '../../types' -import ParamConfigContent from './param-config-content' -import cn from '@/utils/classnames' -import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' -import { - PortalToFollowElem, - PortalToFollowElemContent, - PortalToFollowElemTrigger, -} from '@/app/components/base/portal-to-follow-elem' - -type ParamsConfigProps = { - onChange?: OnFeaturesChange - disabled?: boolean -} -const ParamsConfig = ({ - onChange, - disabled, -}: ParamsConfigProps) => { - const { t } = useTranslation() - const [open, setOpen] = useState(false) - - return ( - - !disabled && setOpen(v => !v)}> -
- -
{t('appDebug.voice.settings')}
-
-
- -
- -
-
-
- ) -} -export default memo(ParamsConfig) diff --git a/web/app/components/base/features/index.tsx b/web/app/components/base/features/index.tsx index 13bffb3669fabc..daea711c075872 100644 --- a/web/app/components/base/features/index.tsx +++ b/web/app/components/base/features/index.tsx @@ -1,3 +1 @@ -export { default as FeaturesPanel } from './feature-panel' -export { default as FeaturesChoose } from './feature-choose' export { FeaturesProvider } from './context' diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-btn/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-btn/index.tsx new file mode 100644 index 00000000000000..809b907d627adf --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-btn/index.tsx @@ -0,0 +1,135 @@ +'use client' +import type { FC } from 'react' +import React, { useRef, useState } from 'react' +import { useHover } from 'ahooks' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import { MessageCheckRemove, MessageFastPlus } from '@/app/components/base/icons/src/vender/line/communication' +import { MessageFast } from '@/app/components/base/icons/src/vender/solid/communication' +import { Edit04 } from '@/app/components/base/icons/src/vender/line/general' +import RemoveAnnotationConfirmModal from '@/app/components/app/annotation/remove-annotation-confirm-modal' +import Tooltip from '@/app/components/base/tooltip' +import { addAnnotation, delAnnotation } from '@/service/annotation' +import Toast from '@/app/components/base/toast' +import { useProviderContext } from '@/context/provider-context' +import { useModalContext } from '@/context/modal-context' + +type Props = { + appId: string + messageId?: string + annotationId?: string + className?: string + cached: boolean + query: string + answer: string + onAdded: (annotationId: string, authorName: string) => void + onEdit: () => void + onRemoved: () => void +} + +const CacheCtrlBtn: FC = ({ + className, + cached, + query, + answer, + appId, + messageId, + annotationId, + onAdded, + onEdit, + onRemoved, +}) => { + const { t } = useTranslation() + const { plan, enableBilling } = useProviderContext() + const isAnnotationFull = (enableBilling && plan.usage.annotatedResponse >= plan.total.annotatedResponse) + const { setShowAnnotationFullModal } = useModalContext() + const [showModal, setShowModal] = useState(false) + const cachedBtnRef = useRef(null) + const isCachedBtnHovering = useHover(cachedBtnRef) + const handleAdd = async () => { + if (isAnnotationFull) { + setShowAnnotationFullModal() + return + } + const res: any = await addAnnotation(appId, { + message_id: messageId, + question: query, + answer, + }) + Toast.notify({ + message: t('common.api.actionSuccess') as string, + type: 'success', + }) + onAdded(res.id, res.account?.name) + } + + const handleRemove = async () => { + await delAnnotation(appId, annotationId!) + Toast.notify({ + message: t('common.api.actionSuccess') as string, + type: 'success', + }) + onRemoved() + setShowModal(false) + } + return ( +
+
+ {cached + ? ( +
+
setShowModal(true)} + > + {!isCachedBtnHovering + ? ( + <> + +
{t('appDebug.feature.annotation.cached')}
+ + ) + : <> + +
{t('appDebug.feature.annotation.remove')}
+ } +
+
+ ) + : answer + ? ( + +
+ +
+
+ ) + : null + } + +
+ +
+
+ +
+ setShowModal(false)} + onRemove={handleRemove} + /> +
+ ) +} +export default React.memo(CacheCtrlBtn) diff --git a/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx similarity index 99% rename from web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx rename to web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx index b660977d084156..801f1348ee240a 100644 --- a/web/app/components/app/configuration/toolbox/annotation/config-param-modal.tsx +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal.tsx @@ -2,7 +2,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' -import ScoreSlider from '../score-slider' +import ScoreSlider from './score-slider' import { Item } from './config-param' import Modal from '@/app/components/base/modal' import Button from '@/app/components/base/button' diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/config-param.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param.tsx new file mode 100644 index 00000000000000..8b3a0af2403810 --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/config-param.tsx @@ -0,0 +1,24 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import Tooltip from '@/app/components/base/tooltip' + +export const Item: FC<{ title: string; tooltip: string; children: JSX.Element }> = ({ + title, + tooltip, + children, +}) => { + return ( +
+
+
{title}
+ {tooltip}
+ } + /> +
+
{children}
+
+ ) +} diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx new file mode 100644 index 00000000000000..f44aab5b9cb1c3 --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/index.tsx @@ -0,0 +1,152 @@ +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { usePathname, useRouter } from 'next/navigation' +import produce from 'immer' +import { RiEqualizer2Line, RiExternalLinkLine } from '@remixicon/react' +import { MessageFast } from '@/app/components/base/icons/src/vender/features' +import FeatureCard from '@/app/components/base/features/new-feature-panel/feature-card' +import Button from '@/app/components/base/button' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import type { OnFeaturesChange } from '@/app/components/base/features/types' +import useAnnotationConfig from '@/app/components/base/features/new-feature-panel/annotation-reply/use-annotation-config' +import ConfigParamModal from '@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal' +import AnnotationFullModal from '@/app/components/billing/annotation-full/modal' +import { ANNOTATION_DEFAULT } from '@/config' + +type Props = { + disabled?: boolean + onChange?: OnFeaturesChange +} + +const AnnotationReply = ({ + disabled, + onChange, +}: Props) => { + const { t } = useTranslation() + const router = useRouter() + const pathname = usePathname() + const matched = pathname.match(/\/app\/([^/]+)/) + const appId = (matched?.length && matched[1]) ? matched[1] : '' + const featuresStore = useFeaturesStore() + const annotationReply = useFeatures(s => s.features.annotationReply) + + const updateAnnotationReply = useCallback((newConfig: any) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + const newFeatures = produce(features, (draft) => { + draft.annotationReply = newConfig + }) + setFeatures(newFeatures) + if (onChange) + onChange(newFeatures) + }, [featuresStore, onChange]) + + const { + handleEnableAnnotation, + handleDisableAnnotation, + isShowAnnotationConfigInit, + setIsShowAnnotationConfigInit, + isShowAnnotationFullModal, + setIsShowAnnotationFullModal, + } = useAnnotationConfig({ + appId, + annotationConfig: annotationReply as any || { + id: '', + enabled: false, + score_threshold: ANNOTATION_DEFAULT.score_threshold, + embedding_model: { + embedding_provider_name: '', + embedding_model_name: '', + }, + }, + setAnnotationConfig: updateAnnotationReply, + }) + + const handleSwitch = useCallback((enabled: boolean) => { + if (enabled) + setIsShowAnnotationConfigInit(true) + else + handleDisableAnnotation(annotationReply?.embedding_model as any) + }, [annotationReply?.embedding_model, handleDisableAnnotation, setIsShowAnnotationConfigInit]) + + const [isHovering, setIsHovering] = useState(false) + + return ( + <> + + +
+ } + title={t('appDebug.feature.annotation.title')} + value={!!annotationReply?.enabled} + onChange={state => handleSwitch(state)} + onMouseEnter={() => setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + disabled={disabled} + > + <> + {!annotationReply?.enabled && ( +
{t('appDebug.feature.annotation.description')}
+ )} + {!!annotationReply?.enabled && ( + <> + {!isHovering && ( +
+
+
{t('appDebug.feature.annotation.scoreThreshold.title')}
+
{annotationReply.score_threshold || '-'}
+
+
+
+
{t('common.modelProvider.embeddingModel.key')}
+
{annotationReply.embedding_model?.embedding_model_name}
+
+
+ )} + {isHovering && ( +
+ + +
+ )} + + )} + + + { + setIsShowAnnotationConfigInit(false) + // showChooseFeatureTrue() + }} + onSave={async (embeddingModel, score) => { + await handleEnableAnnotation(embeddingModel, score) + setIsShowAnnotationConfigInit(false) + }} + annotationConfig={annotationReply as any} + /> + {isShowAnnotationFullModal && ( + setIsShowAnnotationFullModal(false)} + /> + )} + + ) +} + +export default AnnotationReply diff --git a/web/app/components/app/configuration/toolbox/score-slider/base-slider/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx similarity index 100% rename from web/app/components/app/configuration/toolbox/score-slider/base-slider/index.tsx rename to web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/index.tsx diff --git a/web/app/components/app/configuration/toolbox/score-slider/base-slider/style.module.css b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css similarity index 100% rename from web/app/components/app/configuration/toolbox/score-slider/base-slider/style.module.css rename to web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider/style.module.css diff --git a/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx new file mode 100644 index 00000000000000..d68db9be736e88 --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/annotation-reply/score-slider/index.tsx @@ -0,0 +1,46 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import Slider from '@/app/components/base/features/new-feature-panel/annotation-reply/score-slider/base-slider' + +type Props = { + className?: string + value: number + onChange: (value: number) => void +} + +const ScoreSlider: FC = ({ + className, + value, + onChange, +}) => { + const { t } = useTranslation() + + return ( +
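Reviewer note on the new AnnotationReply card above (sketch placed between hunks): the switch is deliberately asymmetric. Turning it on does not enable the feature directly; it opens the embedding-model config modal, and only the modal's onSave path calls handleEnableAnnotation. Turning it off disables immediately with the current embedding model. A condensed sketch of that flow, with hook names taken from the hunk and the embedding-model type assumed (not part of the diff):

// --- illustrative sketch, not part of the diff ---
type EmbeddingModel = { embedding_provider_name: string; embedding_model_name: string }

function makeSwitchHandler(opts: {
  openConfigModal: () => void // stands in for setIsShowAnnotationConfigInit(true)
  disable: (model?: EmbeddingModel) => void // stands in for handleDisableAnnotation
  currentModel?: EmbeddingModel
}) {
  return (enabled: boolean) => {
    if (enabled)
      opts.openConfigModal() // enabling waits for the modal's onSave to confirm a model and threshold
    else
      opts.disable(opts.currentModel) // disabling takes effect right away
  }
}
// --- end sketch ---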
+
+ +
+
+
+
0.8
+
·
+
{t('appDebug.feature.annotation.scoreThreshold.easyMatch')}
+
+
+
1.0
+
·
+
{t('appDebug.feature.annotation.scoreThreshold.accurateMatch')}
+
+
+
+ ) +} +export default React.memo(ScoreSlider) diff --git a/web/app/components/app/configuration/toolbox/annotation/type.ts b/web/app/components/base/features/new-feature-panel/annotation-reply/type.ts similarity index 100% rename from web/app/components/app/configuration/toolbox/annotation/type.ts rename to web/app/components/base/features/new-feature-panel/annotation-reply/type.ts diff --git a/web/app/components/app/configuration/toolbox/annotation/use-annotation-config.ts b/web/app/components/base/features/new-feature-panel/annotation-reply/use-annotation-config.ts similarity index 100% rename from web/app/components/app/configuration/toolbox/annotation/use-annotation-config.ts rename to web/app/components/base/features/new-feature-panel/annotation-reply/use-annotation-config.ts diff --git a/web/app/components/base/features/new-feature-panel/citation.tsx b/web/app/components/base/features/new-feature-panel/citation.tsx new file mode 100644 index 00000000000000..a0b702e9f973f2 --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/citation.tsx @@ -0,0 +1,56 @@ +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import { Citations } from '@/app/components/base/icons/src/vender/features' +import FeatureCard from '@/app/components/base/features/new-feature-panel/feature-card' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import type { OnFeaturesChange } from '@/app/components/base/features/types' +import { FeatureEnum } from '@/app/components/base/features/types' + +type Props = { + disabled?: boolean + onChange?: OnFeaturesChange +} + +const Citation = ({ + disabled, + onChange, +}: Props) => { + const { t } = useTranslation() + const features = useFeatures(s => s.features) + const featuresStore = useFeaturesStore() + + const handleChange = useCallback((type: FeatureEnum, enabled: boolean) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + + const newFeatures = produce(features, (draft) => { + draft[type] = { + ...draft[type], + enabled, + } + }) + setFeatures(newFeatures) + if (onChange) + onChange(newFeatures) + }, [featuresStore, onChange]) + + return ( + + +
+ } + title={t('appDebug.feature.citation.title')} + value={!!features.citation?.enabled} + description={t('appDebug.feature.citation.description')!} + onChange={state => handleChange(FeatureEnum.citation, state)} + disabled={disabled} + /> + ) +} + +export default Citation diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx new file mode 100644 index 00000000000000..ab6b3ec6dbeb83 --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx @@ -0,0 +1,119 @@ +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import { RiEditLine } from '@remixicon/react' +import { LoveMessage } from '@/app/components/base/icons/src/vender/features' +import FeatureCard from '@/app/components/base/features/new-feature-panel/feature-card' +import Button from '@/app/components/base/button' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import type { OnFeaturesChange } from '@/app/components/base/features/types' +import { FeatureEnum } from '@/app/components/base/features/types' +import { useModalContext } from '@/context/modal-context' +import type { PromptVariable } from '@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' + +type Props = { + disabled?: boolean + onChange?: OnFeaturesChange + promptVariables?: PromptVariable[] + workflowVariables?: InputVar[] + onAutoAddPromptVariable?: (variable: PromptVariable[]) => void +} + +const ConversationOpener = ({ + disabled, + onChange, + promptVariables, + workflowVariables, + onAutoAddPromptVariable, +}: Props) => { + const { t } = useTranslation() + const { setShowOpeningModal } = useModalContext() + const opening = useFeatures(s => s.features.opening) + const featuresStore = useFeaturesStore() + const [isHovering, setIsHovering] = useState(false) + const handleOpenOpeningModal = useCallback(() => { + if (disabled) + return + const { + features, + setFeatures, + } = featuresStore!.getState() + setShowOpeningModal({ + payload: { + ...opening, + promptVariables, + workflowVariables, + onAutoAddPromptVariable, + }, + onSaveCallback: (newOpening) => { + const newFeatures = produce(features, (draft) => { + draft.opening = newOpening + }) + setFeatures(newFeatures) + if (onChange) + onChange() + }, + onCancelCallback: () => { + if (onChange) + onChange() + }, + }) + }, [disabled, featuresStore, onAutoAddPromptVariable, onChange, opening, promptVariables, setShowOpeningModal]) + + const handleChange = useCallback((type: FeatureEnum, enabled: boolean) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + + const newFeatures = produce(features, (draft) => { + draft[type] = { + ...draft[type], + enabled, + } + }) + setFeatures(newFeatures) + if (onChange) + onChange() + }, [featuresStore, onChange]) + + return ( + + +
+ } + title={t('appDebug.feature.conversationOpener.title')} + value={!!opening?.enabled} + onChange={state => handleChange(FeatureEnum.opening, state)} + onMouseEnter={() => setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + disabled={disabled} + > + <> + {!opening?.enabled && ( +
{t('appDebug.feature.conversationOpener.description')}
+ )} + {!!opening?.enabled && ( + <> + {!isHovering && ( +
+ {opening.opening_statement || t('appDebug.openingStatement.placeholder')} +
+ )} + {isHovering && ( + + )} + + )} + + + ) +} + +export default ConversationOpener diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx new file mode 100644 index 00000000000000..9f25d0fa11becf --- /dev/null +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -0,0 +1,206 @@ +import React, { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useBoolean } from 'ahooks' +import produce from 'immer' +import { ReactSortable } from 'react-sortablejs' +import { RiAddLine, RiAsterisk, RiCloseLine, RiDeleteBinLine, RiDraggable } from '@remixicon/react' +import Modal from '@/app/components/base/modal' +import Button from '@/app/components/base/button' +import ConfirmAddVar from '@/app/components/app/configuration/config-prompt/confirm-add-var' +import type { OpeningStatement } from '@/app/components/base/features/types' +import { getInputKeys } from '@/app/components/base/block-input' +import type { PromptVariable } from '@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' +import { getNewVar } from '@/utils/var' + +type OpeningSettingModalProps = { + data: OpeningStatement + onSave: (newState: OpeningStatement) => void + onCancel: () => void + promptVariables?: PromptVariable[] + workflowVariables?: InputVar[] + onAutoAddPromptVariable?: (variable: PromptVariable[]) => void +} + +const MAX_QUESTION_NUM = 5 + +const OpeningSettingModal = ({ + data, + onSave, + onCancel, + promptVariables = [], + workflowVariables = [], + onAutoAddPromptVariable, +}: OpeningSettingModalProps) => { + const { t } = useTranslation() + const [tempValue, setTempValue] = useState(data?.opening_statement || '') + useEffect(() => { + setTempValue(data.opening_statement || '') + }, [data.opening_statement]) + const [tempSuggestedQuestions, setTempSuggestedQuestions] = useState(data.suggested_questions || []) + const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) + const [notIncludeKeys, setNotIncludeKeys] = useState([]) + + const handleSave = useCallback((ignoreVariablesCheck?: boolean) => { + if (!ignoreVariablesCheck) { + const keys = getInputKeys(tempValue) + const promptKeys = promptVariables.map(item => item.key) + const workflowVariableKeys = workflowVariables.map(item => item.variable) + let notIncludeKeys: string[] = [] + + if (promptKeys.length === 0 && workflowVariables.length === 0) { + if (keys.length > 0) + notIncludeKeys = keys + } + else { + if (workflowVariables.length > 0) + notIncludeKeys = keys.filter(key => !workflowVariableKeys.includes(key)) + else notIncludeKeys = keys.filter(key => !promptKeys.includes(key)) + } + + if (notIncludeKeys.length > 0) { + setNotIncludeKeys(notIncludeKeys) + showConfirmAddVar() + return + } + } + const newOpening = produce(data, (draft) => { + if (draft) { + draft.opening_statement = tempValue + draft.suggested_questions = tempSuggestedQuestions + } + }) + onSave(newOpening) + }, [data, onSave, promptVariables, workflowVariables, showConfirmAddVar, tempSuggestedQuestions, tempValue]) + + const cancelAutoAddVar = useCallback(() => { + hideConfirmAddVar() + handleSave(true) + }, [handleSave, hideConfirmAddVar]) + + const autoAddVar = useCallback(() => { + onAutoAddPromptVariable?.([ + ...notIncludeKeys.map(key => getNewVar(key, 'string')), + ]) + hideConfirmAddVar() + 
handleSave(true) + }, [handleSave, hideConfirmAddVar, notIncludeKeys, onAutoAddPromptVariable]) + + const renderQuestions = () => { + return ( +
+
+
+
{t('appDebug.openingStatement.openingQuestion')}
+
·
+
{tempSuggestedQuestions.length}/{MAX_QUESTION_NUM}
+
+
+
+            <ReactSortable
+              list={tempSuggestedQuestions.map((name, index) => {
+                return {
+                  id: index,
+                  name,
+                }
+              })}
+              setList={list => setTempSuggestedQuestions(list.map(item => item.name))}
+              handle='.handle'
+              ghostClass="opacity-50"
+              animation={150}
+            >
+              {tempSuggestedQuestions.map((question, index) => {
+                return (
+
+ + { + const value = e.target.value + setTempSuggestedQuestions(tempSuggestedQuestions.map((item, i) => { + if (index === i) + return value + + return item + })) + }} + className={'w-full overflow-x-auto pl-1.5 pr-8 text-sm leading-9 text-gray-900 border-0 grow h-9 bg-transparent focus:outline-none cursor-pointer rounded-lg'} + /> + +
{ + setTempSuggestedQuestions(tempSuggestedQuestions.filter((_, i) => index !== i)) + }} + > + +
+
+ ) + })}
+ {tempSuggestedQuestions.length < MAX_QUESTION_NUM && ( +
+                <div
+                  onClick={() => { setTempSuggestedQuestions([...tempSuggestedQuestions, '']) }}
+                  className='mt-1 flex items-center h-9 px-3 gap-2 rounded-lg cursor-pointer text-gray-400 bg-gray-100 hover:bg-gray-200'>
+                  <RiAddLine />
{t('appDebug.variableConfig.addOption')}
+
+ )} +
+ ) + } + + return ( + { }} + className='!p-6 !mt-14 !max-w-none !w-[640px] !bg-components-panel-bg-blur' + > +
+
{t('appDebug.feature.conversationOpener.title')}
+
+
+
+
+ +
+
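A note on the save path in modal.tsx above: handleSave pulls every {{placeholder}} out of the opening statement via getInputKeys and diffs that set against the declared prompt/workflow variables; only the leftovers open the ConfirmAddVar dialog, and autoAddVar then registers them through onAutoAddPromptVariable. A minimal sketch of the check, collapsing the prompt/workflow branches into one known-keys list (findMissingKeys is illustrative, not part of the diff):

import { getInputKeys } from '@/app/components/base/block-input'

const findMissingKeys = (opening: string, knownKeys: string[]): string[] => {
  const usedKeys = getInputKeys(opening) // extracts {{variable}} names from the statement
  return usedKeys.filter(key => !knownKeys.includes(key)) // referenced but undeclared
}

// findMissingKeys('Hi {{name}}, try {{product}}!', ['name']) => ['product']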
+ + // + // + const { onSend } = useChatContext() + + const getFormValues = (children: any) => { + const formValues: { [key: string]: any } = {} + children.forEach((child: any) => { + if (child.tagName === SUPPORTED_TAGS.INPUT) + formValues[child.properties.name] = child.properties.value + if (child.tagName === SUPPORTED_TAGS.TEXTAREA) + formValues[child.properties.name] = child.properties.value + }) + return formValues + } + const onSubmit = (e: any) => { + e.preventDefault() + const format = node.properties.dataFormat || DATA_FORMAT.TEXT + const result = getFormValues(node.children) + if (format === DATA_FORMAT.JSON) { + onSend?.(JSON.stringify(result)) + } + else { + const textResult = Object.entries(result) + .map(([key, value]) => `${key}: ${value}`) + .join('\n') + onSend?.(textResult) + } + } + return ( +
{ + e.preventDefault() + e.stopPropagation() + }} + > + {node.children.filter((i: any) => i.type === 'element').map((child: any, index: number) => { + if (child.tagName === SUPPORTED_TAGS.LABEL) { + return ( + + ) + } + if (child.tagName === SUPPORTED_TAGS.INPUT) { + if (Object.values(SUPPORTED_TYPES).includes(child.properties.type)) { + return ( + { + e.preventDefault() + child.properties.value = e.target.value + }} + /> + ) + } + else { + return

Unsupported input type: {child.properties.type}

+ } + } + if (child.tagName === SUPPORTED_TAGS.TEXTAREA) { + return ( + + ) + }, +) +Textarea.displayName = 'Textarea' + +export default Textarea +export { Textarea, textareaVariants } diff --git a/web/app/components/base/tooltip-plus/index.tsx b/web/app/components/base/tooltip-plus/index.tsx deleted file mode 100644 index 1f8a091fa5c9d5..00000000000000 --- a/web/app/components/base/tooltip-plus/index.tsx +++ /dev/null @@ -1,109 +0,0 @@ -'use client' -import type { FC } from 'react' -import React, { useEffect, useRef, useState } from 'react' -import { useBoolean } from 'ahooks' -import type { OffsetOptions, Placement } from '@floating-ui/react' -import cn from '@/utils/classnames' -import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' -export type TooltipProps = { - position?: Placement - triggerMethod?: 'hover' | 'click' - disabled?: boolean - popupContent: React.ReactNode - children: React.ReactNode - hideArrow?: boolean - popupClassName?: string - offset?: OffsetOptions - asChild?: boolean -} - -const arrow = ( - -) - -const Tooltip: FC = ({ - position = 'top', - triggerMethod = 'hover', - disabled = false, - popupContent, - children, - hideArrow, - popupClassName, - offset, - asChild, -}) => { - const [open, setOpen] = useState(false) - const [isHoverPopup, { - setTrue: setHoverPopup, - setFalse: setNotHoverPopup, - }] = useBoolean(false) - - const isHoverPopupRef = useRef(isHoverPopup) - useEffect(() => { - isHoverPopupRef.current = isHoverPopup - }, [isHoverPopup]) - - const [isHoverTrigger, { - setTrue: setHoverTrigger, - setFalse: setNotHoverTrigger, - }] = useBoolean(false) - - const isHoverTriggerRef = useRef(isHoverTrigger) - useEffect(() => { - isHoverTriggerRef.current = isHoverTrigger - }, [isHoverTrigger]) - - const handleLeave = (isTrigger: boolean) => { - if (isTrigger) - setNotHoverTrigger() - - else - setNotHoverPopup() - - // give time to move to the popup - setTimeout(() => { - if (!isHoverPopupRef.current && !isHoverTriggerRef.current) - setOpen(false) - }, 500) - } - - return ( - - triggerMethod === 'click' && setOpen(v => !v)} - onMouseEnter={() => { - if (triggerMethod === 'hover') { - setHoverTrigger() - setOpen(true) - } - }} - onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} - asChild={asChild} - > - {children} - - -
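Stepping back to the markdown form component above: onSubmit serializes whatever getFormValues collected from the input/textarea children according to the element's data-format, then hands the string to onSend from useChatContext. A reduced sketch of the two branches (serialize is illustrative, not part of the diff):

const serialize = (values: Record<string, any>, format: 'json' | 'text'): string => {
  if (format === 'json')
    return JSON.stringify(values) // e.g. {"email":"a@b.co"}
  return Object.entries(values)
    .map(([key, value]) => `${key}: ${value}`)
    .join('\n') // e.g. 'email: a@b.co'
}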
triggerMethod === 'hover' && setHoverPopup()} - onMouseLeave={() => triggerMethod === 'hover' && handleLeave(false)} - > - {popupContent} - {!hideArrow && arrow} -
-
-
- ) -} - -export default React.memo(Tooltip) diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index e7795c653793e4..f3b4cff1326998 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -1,52 +1,112 @@ 'use client' import type { FC } from 'react' -import React from 'react' -import { Tooltip as ReactTooltip } from 'react-tooltip' // fixed version to 5.8.3 https://github.com/ReactTooltip/react-tooltip/issues/972 -import classNames from '@/utils/classnames' -import 'react-tooltip/dist/react-tooltip.css' - -type TooltipProps = { - selector: string - content?: string +import React, { useEffect, useRef, useState } from 'react' +import { useBoolean } from 'ahooks' +import type { OffsetOptions, Placement } from '@floating-ui/react' +import { RiQuestionLine } from '@remixicon/react' +import cn from '@/utils/classnames' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +export type TooltipProps = { + position?: Placement + triggerMethod?: 'hover' | 'click' + triggerClassName?: string disabled?: boolean - htmlContent?: React.ReactNode - className?: string // This should use !impornant to override the default styles eg: '!bg-white' - position?: 'top' | 'right' | 'bottom' | 'left' - clickable?: boolean - children: React.ReactNode - noArrow?: boolean + popupContent?: React.ReactNode + children?: React.ReactNode + popupClassName?: string + offset?: OffsetOptions + needsDelay?: boolean + asChild?: boolean } const Tooltip: FC = ({ - selector, - content, - disabled, position = 'top', + triggerMethod = 'hover', + triggerClassName, + disabled = false, + popupContent, children, - htmlContent, - className, - clickable, - noArrow, + popupClassName, + offset, + asChild = true, + needsDelay = false, }) => { + const [open, setOpen] = useState(false) + const [isHoverPopup, { + setTrue: setHoverPopup, + setFalse: setNotHoverPopup, + }] = useBoolean(false) + + const isHoverPopupRef = useRef(isHoverPopup) + useEffect(() => { + isHoverPopupRef.current = isHoverPopup + }, [isHoverPopup]) + + const [isHoverTrigger, { + setTrue: setHoverTrigger, + setFalse: setNotHoverTrigger, + }] = useBoolean(false) + + const isHoverTriggerRef = useRef(isHoverTrigger) + useEffect(() => { + isHoverTriggerRef.current = isHoverTrigger + }, [isHoverTrigger]) + + const handleLeave = (isTrigger: boolean) => { + if (isTrigger) + setNotHoverTrigger() + + else + setNotHoverPopup() + + // give time to move to the popup + if (needsDelay) { + setTimeout(() => { + if (!isHoverPopupRef.current && !isHoverTriggerRef.current) + setOpen(false) + }, 500) + } + else { + setOpen(false) + } + } + return ( -
- {React.cloneElement(children as React.ReactElement, { - 'data-tooltip-id': selector, - }) - } - + triggerMethod === 'click' && setOpen(v => !v)} + onMouseEnter={() => { + if (triggerMethod === 'hover') { + setHoverTrigger() + setOpen(true) + } + }} + onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} + asChild={asChild} + > + {children ||
} +
+ - {htmlContent && htmlContent} -
-
+ {popupContent && (
triggerMethod === 'hover' && setHoverPopup()} + onMouseLeave={() => triggerMethod === 'hover' && handleLeave(false)} + > + {popupContent} +
)} + + ) } -export default Tooltip +export default React.memo(Tooltip) diff --git a/web/app/components/base/topbar/index.tsx b/web/app/components/base/topbar/index.tsx deleted file mode 100644 index cf67456bd3423a..00000000000000 --- a/web/app/components/base/topbar/index.tsx +++ /dev/null @@ -1,16 +0,0 @@ -'use client' - -import { AppProgressBar as ProgressBar } from 'next-nprogress-bar' - -const Topbar = () => { - return ( - <> - - ) -} - -export default Topbar diff --git a/web/app/components/base/video-gallery/VideoPlayer.module.css b/web/app/components/base/video-gallery/VideoPlayer.module.css new file mode 100644 index 00000000000000..04c4a367d62497 --- /dev/null +++ b/web/app/components/base/video-gallery/VideoPlayer.module.css @@ -0,0 +1,188 @@ +.videoPlayer { + position: relative; + width: 100%; + max-width: 800px; + margin: 0 auto; + border-radius: 8px; + overflow: hidden; +} + +.video { + width: 100%; + display: block; +} + +.controls { + position: absolute; + bottom: 0; + left: 0; + right: 0; + width: 100%; + height: 100%; + display: flex; + flex-direction: column; + justify-content: flex-end; + transition: opacity 0.3s ease; +} + +.controls.hidden { + opacity: 0; +} + +.controls.visible { + opacity: 1; +} + +.overlay { + background: linear-gradient(to top, rgba(0, 0, 0, 0.7) 0%, transparent 100%); + padding: 20px; + display: flex; + flex-direction: column; +} + +.progressBarContainer { + width: 100%; + margin-bottom: 10px; +} + +.controlsContent { + display: flex; + justify-content: space-between; + align-items: center; +} + +.leftControls, .rightControls { + display: flex; + align-items: center; +} + +.playPauseButton, .muteButton, .fullscreenButton { + background: none; + border: none; + color: white; + cursor: pointer; + padding: 4px; + margin-right: 10px; + display: flex; + align-items: center; + justify-content: center; +} + +.playPauseButton:hover, .muteButton:hover, .fullscreenButton:hover { + background-color: rgba(255, 255, 255, 0.1); + border-radius: 50%; +} + +.time { + color: white; + font-size: 14px; + margin-left: 8px; +} + +.volumeControl { + display: flex; + align-items: center; + margin-right: 16px; +} + +.volumeSlider { + width: 60px; + height: 4px; + background: rgba(255, 255, 255, 0.3); + border-radius: 2px; + cursor: pointer; + margin-left: 12px; + position: relative; +} + +.volumeLevel { + position: absolute; + top: 0; + left: 0; + height: 100%; + background: #ffffff; + border-radius: 2px; +} + +.progressBar { + position: relative; + width: 100%; + height: 4px; + background: rgba(255, 255, 255, 0.3); + cursor: pointer; + border-radius: 2px; + overflow: visible; + transition: height 0.2s ease; +} + +.progressBar:hover { + height: 6px; +} + +.progress { + height: 100%; + background: #ffffff; + transition: width 0.1s ease-in-out; +} + +.hoverTimeIndicator { + position: absolute; + bottom: 100%; + transform: translateX(-50%); + background-color: rgba(0, 0, 0, 0.7); + color: white; + padding: 4px 8px; + border-radius: 4px; + font-size: 12px; + pointer-events: none; + white-space: nowrap; + margin-bottom: 8px; +} + +.hoverTimeIndicator::after { + content: ''; + position: absolute; + top: 100%; + left: 50%; + margin-left: -4px; + border-width: 4px; + border-style: solid; + border-color: rgba(0, 0, 0, 0.7) transparent transparent transparent; +} + +.controls.smallSize .controlsContent { + justify-content: space-between; +} + +.controls.smallSize .leftControls, +.controls.smallSize .rightControls { + flex: 0 0 auto; + display: flex; + align-items: center; +} + 
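The .progressBar and .hoverTimeIndicator rules above are driven by pointer math in VideoPlayer.tsx (added further down): updateVideoProgress maps the pointer's x offset across the bar onto a playback time, feeding the hover bubble and, while dragging, video.currentTime. A minimal sketch of that mapping (pointerToTime is illustrative; the component ignores out-of-range values instead of clamping):

const pointerToTime = (clientX: number, rect: DOMRect, duration: number): number => {
  const pos = (clientX - rect.left) / rect.width // 0..1 across the bar
  return Math.min(Math.max(pos * duration, 0), duration)
}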
+.controls.smallSize .rightControls { + justify-content: flex-end; +} + +.controls.smallSize .progressBarContainer { + margin-bottom: 4px; +} + +.controls.smallSize .playPauseButton, +.controls.smallSize .muteButton, +.controls.smallSize .fullscreenButton { + padding: 2px; + margin-right: 4px; +} + +.controls.smallSize .playPauseButton svg, +.controls.smallSize .muteButton svg, +.controls.smallSize .fullscreenButton svg { + width: 16px; + height: 16px; +} + +.controls.smallSize .muteButton { + order: -1; +} diff --git a/web/app/components/base/video-gallery/VideoPlayer.tsx b/web/app/components/base/video-gallery/VideoPlayer.tsx new file mode 100644 index 00000000000000..d7c86a1af9702c --- /dev/null +++ b/web/app/components/base/video-gallery/VideoPlayer.tsx @@ -0,0 +1,278 @@ +import React, { useCallback, useEffect, useRef, useState } from 'react' +import styles from './VideoPlayer.module.css' + +type VideoPlayerProps = { + src: string +} + +const PlayIcon = () => ( + + + +) + +const PauseIcon = () => ( + + + +) + +const MuteIcon = () => ( + + + +) + +const UnmuteIcon = () => ( + + + +) + +const FullscreenIcon = () => ( + + + +) + +const VideoPlayer: React.FC = ({ src }) => { + const [isPlaying, setIsPlaying] = useState(false) + const [currentTime, setCurrentTime] = useState(0) + const [duration, setDuration] = useState(0) + const [isMuted, setIsMuted] = useState(false) + const [volume, setVolume] = useState(1) + const [isDragging, setIsDragging] = useState(false) + const [isControlsVisible, setIsControlsVisible] = useState(true) + const [hoverTime, setHoverTime] = useState(null) + const videoRef = useRef(null) + const progressRef = useRef(null) + const volumeRef = useRef(null) + const controlsTimeoutRef = useRef(null) + const [isSmallSize, setIsSmallSize] = useState(false) + const containerRef = useRef(null) + + useEffect(() => { + const video = videoRef.current + if (!video) + return + + const setVideoData = () => { + setDuration(video.duration) + setVolume(video.volume) + } + + const setVideoTime = () => { + setCurrentTime(video.currentTime) + } + + const handleEnded = () => { + setIsPlaying(false) + } + + video.addEventListener('loadedmetadata', setVideoData) + video.addEventListener('timeupdate', setVideoTime) + video.addEventListener('ended', handleEnded) + + return () => { + video.removeEventListener('loadedmetadata', setVideoData) + video.removeEventListener('timeupdate', setVideoTime) + video.removeEventListener('ended', handleEnded) + } + }, [src]) + + useEffect(() => { + return () => { + if (controlsTimeoutRef.current) + clearTimeout(controlsTimeoutRef.current) + } + }, []) + + const showControls = useCallback(() => { + setIsControlsVisible(true) + if (controlsTimeoutRef.current) + clearTimeout(controlsTimeoutRef.current) + + controlsTimeoutRef.current = setTimeout(() => setIsControlsVisible(false), 3000) + }, []) + + const togglePlayPause = useCallback(() => { + const video = videoRef.current + if (video) { + if (isPlaying) + video.pause() + else video.play().catch(error => console.error('Error playing video:', error)) + setIsPlaying(!isPlaying) + } + }, [isPlaying]) + + const toggleMute = useCallback(() => { + const video = videoRef.current + if (video) { + const newMutedState = !video.muted + video.muted = newMutedState + setIsMuted(newMutedState) + setVolume(newMutedState ? 0 : (video.volume > 0 ? video.volume : 1)) + video.volume = newMutedState ? 0 : (video.volume > 0 ? 
video.volume : 1) + } + }, []) + + const toggleFullscreen = useCallback(() => { + const video = videoRef.current + if (video) { + if (document.fullscreenElement) + document.exitFullscreen() + else video.requestFullscreen() + } + }, []) + + const formatTime = (time: number) => { + const minutes = Math.floor(time / 60) + const seconds = Math.floor(time % 60) + return `${minutes.toString().padStart(2, '0')}:${seconds.toString().padStart(2, '0')}` + } + + const updateVideoProgress = useCallback((clientX: number) => { + const progressBar = progressRef.current + const video = videoRef.current + if (progressBar && video) { + const rect = progressBar.getBoundingClientRect() + const pos = (clientX - rect.left) / rect.width + const newTime = pos * video.duration + if (newTime >= 0 && newTime <= video.duration) { + setHoverTime(newTime) + if (isDragging) + video.currentTime = newTime + } + } + }, [isDragging]) + + const handleMouseMove = useCallback((e: React.MouseEvent) => { + updateVideoProgress(e.clientX) + }, [updateVideoProgress]) + + const handleMouseLeave = useCallback(() => { + if (!isDragging) + setHoverTime(null) + }, [isDragging]) + + const handleMouseDown = useCallback((e: React.MouseEvent) => { + e.preventDefault() + setIsDragging(true) + updateVideoProgress(e.clientX) + }, [updateVideoProgress]) + + useEffect(() => { + const handleGlobalMouseMove = (e: MouseEvent) => { + if (isDragging) + updateVideoProgress(e.clientX) + } + + const handleGlobalMouseUp = () => { + setIsDragging(false) + setHoverTime(null) + } + + if (isDragging) { + document.addEventListener('mousemove', handleGlobalMouseMove) + document.addEventListener('mouseup', handleGlobalMouseUp) + } + + return () => { + document.removeEventListener('mousemove', handleGlobalMouseMove) + document.removeEventListener('mouseup', handleGlobalMouseUp) + } + }, [isDragging, updateVideoProgress]) + + const checkSize = useCallback(() => { + if (containerRef.current) + setIsSmallSize(containerRef.current.offsetWidth < 400) + }, []) + + useEffect(() => { + checkSize() + window.addEventListener('resize', checkSize) + return () => window.removeEventListener('resize', checkSize) + }, [checkSize]) + + const handleVolumeChange = useCallback((e: React.MouseEvent) => { + const volumeBar = volumeRef.current + const video = videoRef.current + if (volumeBar && video) { + const rect = volumeBar.getBoundingClientRect() + const newVolume = (e.clientX - rect.left) / rect.width + const clampedVolume = Math.max(0, Math.min(1, newVolume)) + video.volume = clampedVolume + setVolume(clampedVolume) + setIsMuted(clampedVolume === 0) + } + }, []) + + return ( +
+
) diff --git a/web/app/components/billing/pricing/plan-item.tsx b/web/app/components/billing/pricing/plan-item.tsx index 87a20437c3d210..b6ac17472e7b40 100644 --- a/web/app/components/billing/pricing/plan-item.tsx +++ b/web/app/components/billing/pricing/plan-item.tsx @@ -2,14 +2,11 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' import { useContext } from 'use-context-selector' import { Plan } from '../type' import { ALL_PLANS, NUM_INFINITE, contactSalesUrl, contractSales, unAvailable } from '../config' import Toast from '../../base/toast' -import TooltipPlus from '../../base/tooltip-plus' +import Tooltip from '../../base/tooltip' import { PlanRange } from './select-plan-range' import cn from '@/utils/classnames' import { useAppContext } from '@/context/app-context' @@ -30,13 +27,11 @@ const KeyValue = ({ label, value, tooltip }: { label: string; value: string | nu
{label}
{tooltip && ( - {tooltip}
} - > - - + /> )}
{value}
@@ -136,25 +131,21 @@ const PlanItem: FC = ({
+
{t('billing.plansCommon.supportItems.llmLoadingBalancing')}
- {t('billing.plansCommon.supportItems.llmLoadingBalancingTooltip')}
} - > - - + />
+
 {t('billing.plansCommon.supportItems.ragAPIRequest')}
- {t('billing.plansCommon.ragAPIRequestTooltip')}
} - > - - + />
{comingSoon}
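The plan-item hunks above are part of a sweep through this diff: tooltip-plus is deleted and its floating-ui implementation folded into the base Tooltip, so call sites drop the hand-rolled RiQuestionLine trigger and pass popupContent directly. A hedged sketch of the post-migration call shape (the wrapper div and classNames vary per call site):

<Tooltip
  popupContent={
    <div className='w-[200px]'>{tooltipText}</div>
  }
  triggerClassName='ml-0.5 w-3.5 h-3.5'
/>

With no children, the new Tooltip falls back to its built-in question-mark trigger; needsDelay holds the popup open through a 500 ms grace period on mouse-out so the pointer can travel into it.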
diff --git a/web/app/components/billing/priority-label/index.tsx b/web/app/components/billing/priority-label/index.tsx index d8ad27b6e017ce..36338cf4a8e767 100644 --- a/web/app/components/billing/priority-label/index.tsx +++ b/web/app/components/billing/priority-label/index.tsx @@ -9,7 +9,7 @@ import { ZapFast, ZapNarrow, } from '@/app/components/base/icons/src/vender/solid/general' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' const PriorityLabel = () => { const { t } = useTranslation() @@ -27,7 +27,7 @@ const PriorityLabel = () => { }, [plan]) return ( -
{`${t('billing.plansCommon.documentProcessingPriority')}: ${t(`billing.plansCommon.priority.${priority}`)}`}
{ @@ -53,7 +53,7 @@ const PriorityLabel = () => { } {t(`billing.plansCommon.priority.${priority}`)} -
+ ) } diff --git a/web/app/components/billing/usage-info/index.tsx b/web/app/components/billing/usage-info/index.tsx index e92924958442cb..ee41508ea66e1d 100644 --- a/web/app/components/billing/usage-info/index.tsx +++ b/web/app/components/billing/usage-info/index.tsx @@ -2,7 +2,6 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' -import { InfoCircle } from '../../base/icons/src/vender/line/general' import ProgressBar from '../progress-bar' import { NUM_INFINITE } from '../config' import Tooltip from '@/app/components/base/tooltip' @@ -48,11 +47,13 @@ const UsageInfo: FC = ({
{name}
{tooltip && ( - - {tooltip} - } selector='config-var-tooltip'> - - + + {tooltip} + + } + /> )}
diff --git a/web/app/components/browser-initor.tsx b/web/app/components/browser-initor.tsx index 711ff62a94b02e..939ddd567d348f 100644 --- a/web/app/components/browser-initor.tsx +++ b/web/app/components/browser-initor.tsx @@ -43,10 +43,10 @@ Object.defineProperty(globalThis, 'sessionStorage', { value: sessionStorage, }) -const BrowerInitor = ({ +const BrowserInitor = ({ children, }: { children: React.ReactElement }) => { return children } -export default BrowerInitor +export default BrowserInitor diff --git a/web/app/components/custom/custom-app-header-brand/index.tsx b/web/app/components/custom/custom-app-header-brand/index.tsx deleted file mode 100644 index 9564986c286e7a..00000000000000 --- a/web/app/components/custom/custom-app-header-brand/index.tsx +++ /dev/null @@ -1,62 +0,0 @@ -import { useTranslation } from 'react-i18next' -import s from './style.module.css' -import Button from '@/app/components/base/button' -import { Grid01 } from '@/app/components/base/icons/src/vender/solid/layout' -import { Container, Database01 } from '@/app/components/base/icons/src/vender/line/development' -import { ImagePlus } from '@/app/components/base/icons/src/vender/line/images' -import { useProviderContext } from '@/context/provider-context' -import { Plan } from '@/app/components/billing/type' - -const CustomAppHeaderBrand = () => { - const { t } = useTranslation() - const { plan } = useProviderContext() - - return ( -
-
{t('custom.app.title')}
-
-
-
-
-
-
YOUR LOGO
-
-
-
-
-
-
- -
-
-
- -
-
-
- -
-
-
-
-
-
- -
- -
-
{t('custom.app.changeLogoTip')}
-
- ) -} - -export default CustomAppHeaderBrand diff --git a/web/app/components/custom/custom-app-header-brand/style.module.css b/web/app/components/custom/custom-app-header-brand/style.module.css deleted file mode 100644 index 492733ff9f89ae..00000000000000 --- a/web/app/components/custom/custom-app-header-brand/style.module.css +++ /dev/null @@ -1,3 +0,0 @@ -.mask { - background: linear-gradient(95deg, rgba(255, 255, 255, 0.00) 43.9%, rgba(255, 255, 255, 0.80) 95.76%); ; -} \ No newline at end of file diff --git a/web/app/components/custom/custom-page/index.tsx b/web/app/components/custom/custom-page/index.tsx index c3b1e93da3970a..75d592389d283a 100644 --- a/web/app/components/custom/custom-page/index.tsx +++ b/web/app/components/custom/custom-page/index.tsx @@ -1,6 +1,5 @@ import { useTranslation } from 'react-i18next' import CustomWebAppBrand from '../custom-web-app-brand' -import CustomAppHeaderBrand from '../custom-app-header-brand' import s from '../style.module.css' import GridMask from '@/app/components/base/grid-mask' import UpgradeBtn from '@/app/components/billing/upgrade-btn' @@ -13,7 +12,6 @@ const CustomPage = () => { const { plan, enableBilling } = useProviderContext() const showBillingTip = enableBilling && plan.type === Plan.sandbox - const showCustomAppHeaderBrand = enableBilling && plan.type === Plan.sandbox const showContact = enableBilling && (plan.type === Plan.professional || plan.type === Plan.team) return ( @@ -32,14 +30,6 @@ const CustomPage = () => { ) } - { - showCustomAppHeaderBrand && ( - <> -
- - - ) - } { showContact && (
diff --git a/web/app/components/datasets/common/check-rerank-model.ts b/web/app/components/datasets/common/check-rerank-model.ts index 42810e4bf03936..581c2bb69ac8b9 100644 --- a/web/app/components/datasets/common/check-rerank-model.ts +++ b/web/app/components/datasets/common/check-rerank-model.ts @@ -7,13 +7,13 @@ import { RerankingModeEnum } from '@/models/datasets' export const isReRankModelSelected = ({ rerankDefaultModel, - isRerankDefaultModelVaild, + isRerankDefaultModelValid, retrievalConfig, rerankModelList, indexMethod, }: { rerankDefaultModel?: DefaultModelResponse - isRerankDefaultModelVaild: boolean + isRerankDefaultModelValid: boolean retrievalConfig: RetrievalConfig rerankModelList: Model[] indexMethod?: string @@ -25,7 +25,7 @@ export const isReRankModelSelected = ({ return provider?.models.find(({ model }) => model === retrievalConfig.reranking_model?.reranking_model_name) } - if (isRerankDefaultModelVaild) + if (isRerankDefaultModelValid) return !!rerankDefaultModel return false diff --git a/web/app/components/datasets/common/retrieval-method-config/index.tsx b/web/app/components/datasets/common/retrieval-method-config/index.tsx index 1e407b62e16569..20d93568addbb7 100644 --- a/web/app/components/datasets/common/retrieval-method-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-method-config/index.tsx @@ -11,6 +11,11 @@ import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files import { useProviderContext } from '@/context/provider-context' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { + DEFAULT_WEIGHTED_SCORE, + RerankingModeEnum, + WeightedScoreEnum, +} from '@/models/datasets' type Props = { value: RetrievalConfig @@ -32,6 +37,18 @@ const RetrievalMethodConfig: FC = ({ reranking_provider_name: rerankDefaultModel?.provider.provider || '', reranking_model_name: rerankDefaultModel?.model || '', }, + reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? 
RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore), + weights: passValue.weights || { + weight_type: WeightedScoreEnum.Customized, + vector_setting: { + vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, + embedding_provider_name: '', + embedding_model_name: '', + }, + keyword_setting: { + keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, + }, + }, } } return passValue diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 98676f2e83888b..9d48d56a8dc511 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -1,19 +1,17 @@ 'use client' import type { FC } from 'react' -import React from 'react' +import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import { - RiQuestionLine, -} from '@remixicon/react' + import cn from '@/utils/classnames' import TopKItem from '@/app/components/base/param-item/top-k-item' import ScoreThresholdItem from '@/app/components/base/param-item/score-threshold-item' import { RETRIEVE_METHOD } from '@/types/app' import Switch from '@/app/components/base/switch' -import Tooltip from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import type { RetrievalConfig } from '@/types/app' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { DEFAULT_WEIGHTED_SCORE, @@ -21,6 +19,7 @@ import { WeightedScoreEnum, } from '@/models/datasets' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' +import Toast from '@/app/components/base/toast' type Props = { type: RETRIEVE_METHOD @@ -40,6 +39,24 @@ const RetrievalParamConfig: FC = ({ defaultModel: rerankDefaultModel, modelList: rerankModelList, } = useModelListAndDefaultModel(ModelTypeEnum.rerank) + + const { + currentModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? { + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) + + const handleDisabledSwitchClick = useCallback(() => { + if (!currentModel) + Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) + }, [currentModel, rerankDefaultModel, t]) + const isHybridSearch = type === RETRIEVE_METHOD.hybrid const rerankModel = (() => { @@ -101,22 +118,30 @@ const RetrievalParamConfig: FC = ({
{canToggleRerankModalEnable && ( - { - onChange({ - ...value, - reranking_enable: v, - }) - }} - /> +
+ { + onChange({ + ...value, + reranking_enable: v, + }) + }} + disabled={!currentModel} + /> +
)}
{t('common.modelProvider.rerankModel.key')} - {t('common.modelProvider.rerankModel.tip')}
}> - - + {t('common.modelProvider.rerankModel.tip')}
+ } + />
= ({
{option.label}
{option.tips}
} - hideArrow - > - -
+ triggerClassName='ml-0.5 w-3.5 h-3.5' + /> )) } diff --git a/web/app/components/datasets/create/assets/jina.png b/web/app/components/datasets/create/assets/jina.png new file mode 100644 index 00000000000000..b4beeafdfb1271 Binary files /dev/null and b/web/app/components/datasets/create/assets/jina.png differ diff --git a/web/app/components/datasets/create/assets/unknow.svg b/web/app/components/datasets/create/assets/unknown.svg similarity index 100% rename from web/app/components/datasets/create/assets/unknow.svg rename to web/app/components/datasets/create/assets/unknown.svg diff --git a/web/app/components/datasets/create/embedding-process/index.module.css b/web/app/components/datasets/create/embedding-process/index.module.css index a15b1310b4c856..1ebb006b543ac2 100644 --- a/web/app/components/datasets/create/embedding-process/index.module.css +++ b/web/app/components/datasets/create/embedding-process/index.module.css @@ -83,7 +83,7 @@ .fileIcon { @apply w-4 h-4 mr-1 bg-center bg-no-repeat; - background-image: url(../assets/unknow.svg); + background-image: url(../assets/unknown.svg); background-size: 16px; } .fileIcon.csv { diff --git a/web/app/components/datasets/create/embedding-process/index.tsx b/web/app/components/datasets/create/embedding-process/index.tsx index 1e340d692f290a..7786582085c16d 100644 --- a/web/app/components/datasets/create/embedding-process/index.tsx +++ b/web/app/components/datasets/create/embedding-process/index.tsx @@ -13,8 +13,7 @@ import cn from '@/utils/classnames' import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata' import Button from '@/app/components/base/button' import type { FullDocumentDetail, IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets' -import { formatNumber } from '@/utils/format' -import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchIndexingEstimateBatch, fetchProcessRule } from '@/service/datasets' +import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchProcessRule } from '@/service/datasets' import { DataSourceType } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import PriorityLabel from '@/app/components/billing/priority-label' @@ -22,7 +21,7 @@ import { Plan } from '@/app/components/billing/type' import { ZapFast } from '@/app/components/base/icons/src/vender/solid/general' import UpgradeBtn from '@/app/components/billing/upgrade-btn' import { useProviderContext } from '@/context/provider-context' -import TooltipPlus from '@/app/components/base/tooltip-plus' +import Tooltip from '@/app/components/base/tooltip' import { sleep } from '@/utils' type Props = { @@ -142,14 +141,6 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index }, apiParams => fetchProcessRule(omit(apiParams, 'action')), { revalidateOnFocus: false, }) - // get cost - const { data: indexingEstimateDetail } = useSWR({ - action: 'fetchIndexingEstimateBatch', - datasetId, - batchId, - }, apiParams => fetchIndexingEstimateBatch(omit(apiParams, 'action')), { - revalidateOnFocus: false, - }) const router = useRouter() const navToDocumentList = () => { @@ -190,28 +181,11 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index return ( <> -
+
{isEmbedding && t('datasetDocuments.embedding.processing')} {isEmbeddingCompleted && t('datasetDocuments.embedding.completed')}
-
- {indexingType === 'high_quality' && ( -
-
- {t('datasetDocuments.embedding.highQuality')} · {t('datasetDocuments.embedding.estimate')} - {formatNumber(indexingEstimateDetail?.tokens || 0)}tokens - (${formatNumber(indexingEstimateDetail?.total_price || 0)}) -
- )} - {indexingType === 'economy' && ( -
-
- {t('datasetDocuments.embedding.economy')} · {t('datasetDocuments.embedding.estimate')} - 0tokens -
- )} -
{ enableBilling && plan.type !== Plan.team && ( @@ -259,16 +233,18 @@ const EmbeddingProcess: FC = ({ datasetId, batchId, documents = [], index
{`${getSourcePercent(indexingStatusDetail)}%`}
)} {indexingStatusDetail.indexing_status === 'error' && indexingStatusDetail.error && ( - - {indexingStatusDetail.error} -
- )}> + + {indexingStatusDetail.error} +
+ )} + >
Error
- + )} {indexingStatusDetail.indexing_status === 'error' && !indexingStatusDetail.error && (
diff --git a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx index e9247c49dff132..7702a70d3f54ee 100644 --- a/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx +++ b/web/app/components/datasets/create/empty-dataset-creation-modal/index.tsx @@ -32,7 +32,7 @@ const EmptyDatasetCreationModal = ({ return } if (inputValue.length > 40) { - notify({ type: 'error', message: t('datasetCreation.stepOne.modal.nameLengthInvaild') }) + notify({ type: 'error', message: t('datasetCreation.stepOne.modal.nameLengthInvalid') }) return } try { @@ -58,7 +58,7 @@ const EmptyDatasetCreationModal = ({
{t('datasetCreation.stepOne.modal.tip')}
{t('datasetCreation.stepOne.modal.input')}
- + setInputValue(e.target.value)} />
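The hunk above only fixes the key's spelling (nameLengthInvaild to nameLengthInvalid); the underlying check rejects dataset names longer than 40 characters before calling the API. A sketch of that guard pulled out as a helper (validateName and the nameNotEmpty key are assumptions, not part of the diff):

const validateName = (name: string): string | null => {
  if (!name.trim())
    return t('datasetCreation.stepOne.modal.nameNotEmpty') // assumed key
  if (name.length > 40)
    return t('datasetCreation.stepOne.modal.nameLengthInvalid') // key from the hunk above
  return null // valid; proceed with creation
}

A non-null result is surfaced through notify({ type: 'error', message }) and the submit is aborted.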
diff --git a/web/app/components/datasets/create/file-uploader/index.module.css b/web/app/components/datasets/create/file-uploader/index.module.css index d141815c5a5785..bf5b7dcaf5b9b7 100644 --- a/web/app/components/datasets/create/file-uploader/index.module.css +++ b/web/app/components/datasets/create/file-uploader/index.module.css @@ -104,7 +104,7 @@ .fileIcon { @apply shrink-0 w-6 h-6 mr-2 bg-center bg-no-repeat; - background-image: url(../assets/unknow.svg); + background-image: url(../assets/unknown.svg); background-size: 24px; } diff --git a/web/app/components/datasets/create/index.tsx b/web/app/components/datasets/create/index.tsx index 12c6284d882c5b..98098445c7695c 100644 --- a/web/app/components/datasets/create/index.tsx +++ b/web/app/components/datasets/create/index.tsx @@ -11,7 +11,7 @@ import { DataSourceType } from '@/models/datasets' import type { CrawlOptions, CrawlResultItem, DataSet, FileItem, createDocumentResponse } from '@/models/datasets' import { fetchDataSource } from '@/service/common' import { fetchDatasetDetail } from '@/service/datasets' -import type { NotionPage } from '@/models/common' +import { DataSourceProvider, type NotionPage } from '@/models/common' import { useModalContext } from '@/context/modal-context' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' @@ -26,6 +26,7 @@ const DEFAULT_CRAWL_OPTIONS: CrawlOptions = { excludes: '', limit: 10, max_depth: '', + use_sitemap: true, } const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { @@ -51,7 +52,8 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { const updateFileList = (preparedFiles: FileItem[]) => { setFiles(preparedFiles) } - const [fireCrawlJobId, setFireCrawlJobId] = useState('') + const [websiteCrawlProvider, setWebsiteCrawlProvider] = useState(DataSourceProvider.fireCrawl) + const [websiteCrawlJobId, setWebsiteCrawlJobId] = useState('') const updateFile = (fileItem: FileItem, progress: number, list: FileItem[]) => { const targetIndex = list.findIndex(file => file.fileID === fileItem.fileID) @@ -137,7 +139,8 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { onStepChange={nextStep} websitePages={websitePages} updateWebsitePages={setWebsitePages} - onFireCrawlJobIdChange={setFireCrawlJobId} + onWebsiteCrawlProviderChange={setWebsiteCrawlProvider} + onWebsiteCrawlJobIdChange={setWebsiteCrawlJobId} crawlOptions={crawlOptions} onCrawlOptionsChange={setCrawlOptions} /> @@ -151,7 +154,8 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => { files={fileList.map(file => file.file)} notionPages={notionPages} websitePages={websitePages} - fireCrawlJobId={fireCrawlJobId} + websiteCrawlProvider={websiteCrawlProvider} + websiteCrawlJobId={websiteCrawlJobId} onStepChange={changeStep} updateIndexingTypeCache={updateIndexingTypeCache} updateResultCache={updateResultCache} diff --git a/web/app/components/datasets/create/step-one/index.tsx b/web/app/components/datasets/create/step-one/index.tsx index c2d77f4cecdcc8..643932e9ae21d5 100644 --- a/web/app/components/datasets/create/step-one/index.tsx +++ b/web/app/components/datasets/create/step-one/index.tsx @@ -10,7 +10,7 @@ import WebsitePreview from '../website/preview' import s from './index.module.css' import cn from '@/utils/classnames' import type { CrawlOptions, CrawlResultItem, FileItem } from '@/models/datasets' -import type { NotionPage } from '@/models/common' +import type { DataSourceProvider, NotionPage } from 
'@/models/common' import { DataSourceType } from '@/models/datasets' import Button from '@/app/components/base/button' import { NotionPageSelector } from '@/app/components/base/notion-page-selector' @@ -33,7 +33,8 @@ type IStepOneProps = { changeType: (type: DataSourceType) => void websitePages?: CrawlResultItem[] updateWebsitePages: (value: CrawlResultItem[]) => void - onFireCrawlJobIdChange: (jobId: string) => void + onWebsiteCrawlProviderChange: (provider: DataSourceProvider) => void + onWebsiteCrawlJobIdChange: (jobId: string) => void crawlOptions: CrawlOptions onCrawlOptionsChange: (payload: CrawlOptions) => void } @@ -69,7 +70,8 @@ const StepOne = ({ updateNotionPages, websitePages = [], updateWebsitePages, - onFireCrawlJobIdChange, + onWebsiteCrawlProviderChange, + onWebsiteCrawlJobIdChange, crawlOptions, onCrawlOptionsChange, }: IStepOneProps) => { @@ -229,7 +231,8 @@ const StepOne = ({ onPreview={setCurrentWebsite} checkedCrawlResult={websitePages} onCheckedCrawlResultChange={updateWebsitePages} - onJobIdChange={onFireCrawlJobIdChange} + onCrawlProviderChange={onWebsiteCrawlProviderChange} + onJobIdChange={onWebsiteCrawlJobIdChange} crawlOptions={crawlOptions} onCrawlOptionsChange={onCrawlOptionsChange} /> diff --git a/web/app/components/datasets/create/step-two/escape.ts b/web/app/components/datasets/create/step-two/escape.ts new file mode 100644 index 00000000000000..2e1c3a9d736463 --- /dev/null +++ b/web/app/components/datasets/create/step-two/escape.ts @@ -0,0 +1,18 @@ +function escape(input: string): string { + if (!input || typeof input !== 'string') + return '' + + const res = input + // .replaceAll('\\', '\\\\') // This would add too many backslashes + .replaceAll('\0', '\\0') + .replaceAll('\b', '\\b') + .replaceAll('\f', '\\f') + .replaceAll('\n', '\\n') + .replaceAll('\r', '\\r') + .replaceAll('\t', '\\t') + .replaceAll('\v', '\\v') + .replaceAll('\'', '\\\'') + return res +} + +export default escape diff --git a/web/app/components/datasets/create/step-two/index.module.css b/web/app/components/datasets/create/step-two/index.module.css index 24a62c8e3c9054..f89d6d67ea7088 100644 --- a/web/app/components/datasets/create/step-two/index.module.css +++ b/web/app/components/datasets/create/step-two/index.module.css @@ -30,7 +30,7 @@ } .indexItem { - min-height: 146px; + min-height: 126px; } .indexItem .disableMask { @@ -121,10 +121,6 @@ @apply pb-1; } -.radioItem.indexItem .typeHeader .tip { - @apply pb-3; -} - .radioItem .typeIcon { position: absolute; top: 18px; @@ -264,7 +260,7 @@ } .input { - @apply inline-flex h-9 w-full py-1 px-2 rounded-lg text-xs leading-normal; + @apply inline-flex h-9 w-full py-1 px-2 pr-14 rounded-lg text-xs leading-normal; @apply bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-white placeholder:text-gray-400; } diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 3849f817d613f2..634f031134f78f 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -1,5 +1,5 @@ 'use client' -import React, { useEffect, useLayoutEffect, useRef, useState } from 'react' +import React, { useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import { useBoolean } from 'ahooks' @@ -7,16 +7,16 @@ import { XMarkIcon } from 
'@heroicons/react/20/solid' import { RocketLaunchIcon } from '@heroicons/react/24/outline' import { RiCloseLine, - RiQuestionLine, } from '@remixicon/react' import Link from 'next/link' import { groupBy } from 'lodash-es' -import RetrievalMethodInfo from '../../common/retrieval-method-info' import PreviewItem, { PreviewType } from './preview-item' import LanguageSelect from './language-select' import s from './index.module.css' +import unescape from './unescape' +import escape from './escape' import cn from '@/utils/classnames' -import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, IndexingEstimateResponse, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' +import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' import { createDocument, createFirstDocument, @@ -24,6 +24,7 @@ import { fetchDefaultProcessRule, } from '@/service/datasets' import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' import Loading from '@/app/components/base/loading' import FloatRightContainer from '@/app/components/base/float-right-container' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' @@ -33,6 +34,7 @@ import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/componen import Toast from '@/app/components/base/toast' import { formatNumber } from '@/utils/format' import type { NotionPage } from '@/models/common' +import { DataSourceProvider } from '@/models/common' import { DataSourceType, DocForm } from '@/models/datasets' import NotionIcon from '@/app/components/base/notion-icon' import Switch from '@/app/components/base/switch' @@ -43,9 +45,10 @@ import { IS_CE_EDITION } from '@/config' import { RETRIEVE_METHOD } from '@/types/app' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import Tooltip from '@/app/components/base/tooltip' -import TooltipPlus from '@/app/components/base/tooltip-plus' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useDefaultModel, useModelList, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { LanguagesSupported } from '@/i18n/language' +import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' +import type { DefaultModel } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { Globe01 } from '@/app/components/base/icons/src/vender/line/mapsAndTravel' @@ -62,7 +65,8 @@ type StepTwoProps = { notionPages?: NotionPage[] websitePages?: CrawlResultItem[] crawlOptions?: CrawlOptions - fireCrawlJobId?: string + websiteCrawlProvider?: DataSourceProvider + websiteCrawlJobId?: string onStepChange?: (delta: number) => void updateIndexingTypeCache?: (type: string) => void updateResultCache?: (res: createDocumentResponse) => void @@ -79,6 +83,8 @@ enum IndexingType { ECONOMICAL = 'economy', } +const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n' + const StepTwo = ({ isSetting, 
documentDetail, @@ -91,7 +97,8 @@ const StepTwo = ({ notionPages = [], websitePages = [], crawlOptions, - fireCrawlJobId = '', + websiteCrawlProvider = DataSourceProvider.fireCrawl, + websiteCrawlJobId = '', onStepChange, updateIndexingTypeCache, updateResultCache, @@ -111,8 +118,11 @@ const StepTwo = ({ const previewScrollRef = useRef(null) const [previewScrolled, setPreviewScrolled] = useState(false) const [segmentationType, setSegmentationType] = useState(SegmentType.AUTO) - const [segmentIdentifier, setSegmentIdentifier] = useState('\\n') - const [max, setMax] = useState(500) + const [segmentIdentifier, doSetSegmentIdentifier] = useState(DEFAULT_SEGMENT_IDENTIFIER) + const setSegmentIdentifier = useCallback((value: string) => { + doSetSegmentIdentifier(value ? escape(value) : DEFAULT_SEGMENT_IDENTIFIER) + }, []) + const [max, setMax] = useState(4000) // default chunk length const [overlap, setOverlap] = useState(50) const [rules, setRules] = useState([]) const [defaultConfig, setDefaultConfig] = useState() @@ -123,16 +133,18 @@ const StepTwo = ({ ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL, ) + const [isLanguageSelectDisabled, setIsLanguageSelectDisabled] = useState(false) const [docForm, setDocForm] = useState( (datasetId && documentDetail) ? documentDetail.doc_form : DocForm.TEXT, ) - const [docLanguage, setDocLanguage] = useState(locale !== LanguagesSupported[1] ? 'English' : 'Chinese') + const [docLanguage, setDocLanguage] = useState( + (datasetId && documentDetail) ? documentDetail.doc_language : (locale !== LanguagesSupported[1] ? 'English' : 'Chinese'), + ) const [QATipHide, setQATipHide] = useState(false) const [previewSwitched, setPreviewSwitched] = useState(false) const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean() const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState(null) const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState(null) - const [estimateTokes, setEstimateTokes] = useState | null>(null) const fileIndexingEstimate = (() => { return segmentationType === SegmentType.AUTO ? automaticFileIndexingEstimate : customFileIndexingEstimate @@ -183,26 +195,27 @@ const StepTwo = ({ } const resetRules = () => { if (defaultConfig) { - setSegmentIdentifier((defaultConfig.segmentation.separator === '\n' ? '\\n' : defaultConfig.segmentation.separator) || '\\n') + setSegmentIdentifier(defaultConfig.segmentation.separator) setMax(defaultConfig.segmentation.max_tokens) setOverlap(defaultConfig.segmentation.chunk_overlap) setRules(defaultConfig.pre_processing_rules) } } - const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT) => { + const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT, language?: string) => { // eslint-disable-next-line @typescript-eslint/no-use-before-define - const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm)!) - if (segmentationType === SegmentType.CUSTOM) { + const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm, language)!) 
+ if (segmentationType === SegmentType.CUSTOM) setCustomFileIndexingEstimate(res) - } - else { + else setAutomaticFileIndexingEstimate(res) - indexType === IndexingType.QUALIFIED && setEstimateTokes({ tokens: res.tokens, total_price: res.total_price }) - } } const confirmChangeCustomConfig = () => { + if (segmentationType === SegmentType.CUSTOM && max > 4000) { + Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) + return + } setCustomFileIndexingEstimate(null) setShowPreview() fetchFileIndexingEstimate() @@ -220,7 +233,7 @@ const StepTwo = ({ const ruleObj = { pre_processing_rules: rules, segmentation: { - separator: segmentIdentifier === '\\n' ? '\n' : segmentIdentifier, + separator: unescape(segmentIdentifier), max_tokens: max, chunk_overlap: overlap, }, @@ -256,14 +269,14 @@ const StepTwo = ({ const getWebsiteInfo = () => { return { - provider: 'firecrawl', - job_id: fireCrawlJobId, + provider: websiteCrawlProvider, + job_id: websiteCrawlJobId, urls: websitePages.map(page => page.source_url), only_main_content: crawlOptions?.only_main_content, } } - const getFileIndexingEstimateParams = (docForm: DocForm): IndexingEstimateParams | undefined => { + const getFileIndexingEstimateParams = (docForm: DocForm, language?: string): IndexingEstimateParams | undefined => { if (dataSourceType === DataSourceType.FILE) { return { info_list: { @@ -275,7 +288,7 @@ const StepTwo = ({ indexing_technique: getIndexing_technique() as string, process_rule: getProcessRule(), doc_form: docForm, - doc_language: docLanguage, + doc_language: language || docLanguage, dataset_id: datasetId as string, } } @@ -288,7 +301,7 @@ const StepTwo = ({ indexing_technique: getIndexing_technique() as string, process_rule: getProcessRule(), doc_form: docForm, - doc_language: docLanguage, + doc_language: language || docLanguage, dataset_id: datasetId as string, } } @@ -301,7 +314,7 @@ const StepTwo = ({ indexing_technique: getIndexing_technique() as string, process_rule: getProcessRule(), doc_form: docForm, - doc_language: docLanguage, + doc_language: language || docLanguage, dataset_id: datasetId as string, } } @@ -309,14 +322,31 @@ const StepTwo = ({ const { modelList: rerankModelList, defaultModel: rerankDefaultModel, - currentModel: isRerankDefaultModelVaild, + currentModel: isRerankDefaultModelValid, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding) + const { data: defaultEmbeddingModel } = useDefaultModel(ModelTypeEnum.textEmbedding) + const [embeddingModel, setEmbeddingModel] = useState( + currentDataset?.embedding_model + ? { + provider: currentDataset.embedding_model_provider, + model: currentDataset.embedding_model, + } + : { + provider: defaultEmbeddingModel?.provider.provider || '', + model: defaultEmbeddingModel?.model || '', + }, + ) const getCreationParams = () => { let params if (segmentationType === SegmentType.CUSTOM && overlap > max) { Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.overlapCheck') }) return } + if (segmentationType === SegmentType.CUSTOM && max > 4000) { + Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) + return + } if (isSetting) { params = { original_document_id: documentDetail?.id, @@ -325,6 +355,8 @@ const StepTwo = ({ process_rule: getProcessRule(), // eslint-disable-next-line @typescript-eslint/no-use-before-define retrieval_model: retrievalConfig, // Readonly. 
If want to changed, just go to settings page. + embedding_model: embeddingModel.model, // Readonly + embedding_model_provider: embeddingModel.provider, // Readonly } as CreateDocumentReq } else { // create @@ -332,7 +364,7 @@ const StepTwo = ({ if ( !isReRankModelSelected({ rerankDefaultModel, - isRerankDefaultModelVaild: !!isRerankDefaultModelVaild, + isRerankDefaultModelValid: !!isRerankDefaultModelValid, rerankModelList, // eslint-disable-next-line @typescript-eslint/no-use-before-define retrievalConfig, @@ -361,6 +393,8 @@ const StepTwo = ({ doc_language: docLanguage, retrieval_model: postRetrievalConfig, + embedding_model: embeddingModel.model, + embedding_model_provider: embeddingModel.provider, } as CreateDocumentReq if (dataSourceType === DataSourceType.FILE) { params.data_source.info_list.file_info_list = { @@ -380,7 +414,7 @@ const StepTwo = ({ try { const res = await fetchDefaultProcessRule({ url: '/datasets/process-rule' }) const separator = res.rules.segmentation.separator - setSegmentIdentifier((separator === '\n' ? '\\n' : separator) || '\\n') + setSegmentIdentifier(separator) setMax(res.rules.segmentation.max_tokens) setOverlap(res.rules.segmentation.chunk_overlap) setRules(res.rules.pre_processing_rules) @@ -397,7 +431,7 @@ const StepTwo = ({ const separator = rules.segmentation.separator const max = rules.segmentation.max_tokens const overlap = rules.segmentation.chunk_overlap - setSegmentIdentifier((separator === '\n' ? '\\n' : separator) || '\\n') + setSegmentIdentifier(separator) setMax(max) setOverlap(overlap) setRules(rules.pre_processing_rules) @@ -459,8 +493,26 @@ const StepTwo = ({ setDocForm(DocForm.TEXT) } + const previewSwitch = async (language?: string) => { + setPreviewSwitched(true) + setIsLanguageSelectDisabled(true) + if (segmentationType === SegmentType.AUTO) + setAutomaticFileIndexingEstimate(null) + else + setCustomFileIndexingEstimate(null) + try { + await fetchFileIndexingEstimate(DocForm.QA, language) + } + finally { + setIsLanguageSelectDisabled(false) + } + } + const handleSelect = (language: string) => { setDocLanguage(language) + // Switch language, re-cutter + if (docForm === DocForm.QA && previewSwitched) + previewSwitch(language) } const changeToEconomicalType = () => { @@ -470,15 +522,6 @@ const StepTwo = ({ } } - const previewSwitch = async () => { - setPreviewSwitched(true) - if (segmentationType === SegmentType.AUTO) - setAutomaticFileIndexingEstimate(null) - else - setCustomFileIndexingEstimate(null) - await fetchFileIndexingEstimate(DocForm.QA) - } - useEffect(() => { // fetch rules if (!isSetting) { @@ -551,12 +594,12 @@ const StepTwo = ({
{t('datasetCreation.steps.two')} - {isMobile && ( + {(isMobile || !showPreview) && (
@@ -768,14 +809,34 @@ const StepTwo = ({ )}
)} + {/* Embedding model */} + {indexType === IndexingType.QUALIFIED && ( +
+
{t('datasetSettings.form.embeddingModel')}
+ { + setEmbeddingModel(model) + }} + /> + {!!datasetId && ( +
+ {t('datasetCreation.stepTwo.indexSettingTip')} + {t('datasetCreation.stepTwo.datasetSettingLink')} +
+ )} +
+ )} {/* Retrieval Method Config */}
{!datasetId ? (
- {t('datasetSettings.form.retrievalSetting.title')} +
{t('datasetSettings.form.retrievalSetting.title')}
@@ -787,34 +848,21 @@ const StepTwo = ({ )}
- {!datasetId - ? (<> - {getIndexing_technique() === IndexingType.QUALIFIED - ? ( - - ) - : ( - - )} - ) - : ( -
- -
- {t('datasetCreation.stepTwo.retrivalSettedTip')} - {t('datasetCreation.stepTwo.datasetSettingLink')} -
-
- )} - + ) + : ( + + ) + }
@@ -875,7 +923,7 @@ const StepTwo = ({
-
{t('datasetCreation.stepTwo.emstimateSegment')}
+
{t('datasetCreation.stepTwo.estimateSegment')}
{ fileIndexingEstimate @@ -913,7 +961,7 @@ const StepTwo = ({
{t('datasetCreation.stepTwo.previewTitle')}
{docForm === DocForm.QA && !previewSwitched && ( - + )}
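step-two now stores the segment separator in escaped form (DEFAULT_SEGMENT_IDENTIFIER is '\\n\\n'): setSegmentIdentifier escapes whatever the user or the process-rule API supplies, and getProcessRule unescapes it again before it is posted as process_rule.segmentation.separator. A worked example of the pair:

import escape from './escape'
import unescape from './unescape'

const displayed = escape('\n\n') // -> '\\n\\n', what the separator input shows
const sent = unescape('\\n\\n')  // -> '\n\n', what the API receives

unescape(escape(s)) === s holds for the control characters escape handles; backslash itself is deliberately not escaped (see the commented-out replaceAll in escape.ts above).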
diff --git a/web/app/components/datasets/create/step-two/language-select/index.tsx b/web/app/components/datasets/create/step-two/language-select/index.tsx index f8709c89f3a6bb..41f3e0abb55b6e 100644 --- a/web/app/components/datasets/create/step-two/language-select/index.tsx +++ b/web/app/components/datasets/create/step-two/language-select/index.tsx @@ -9,19 +9,22 @@ import { languages } from '@/i18n/language' export type ILanguageSelectProps = { currentLanguage: string onSelect: (language: string) => void + disabled?: boolean } const LanguageSelect: FC = ({ currentLanguage, onSelect, + disabled, }) => { return ( - {languages.filter(language => language.supported).map(({ prompt_name, name }) => ( + {languages.filter(language => language.supported).map(({ prompt_name }) => (
= ({ const charNums = type === PreviewType.TEXT ? (content || '').length : (qa?.answer || '').length + (qa?.question || '').length - const formatedIndex = (() => String(index).padStart(3, '0'))() + const formattedIndex = (() => String(index).padStart(3, '0'))() return (
{sharpIcon} - {formatedIndex} + {formattedIndex}
{textIcon} diff --git a/web/app/components/datasets/create/step-two/unescape.ts b/web/app/components/datasets/create/step-two/unescape.ts new file mode 100644 index 00000000000000..5c0f9e426a2332 --- /dev/null +++ b/web/app/components/datasets/create/step-two/unescape.ts @@ -0,0 +1,54 @@ +// https://github.com/iamakulov/unescape-js/blob/master/src/index.js + +/** + * \\ - matches the backslash which indicates the beginning of an escape sequence + * ( + * u\{([0-9A-Fa-f]+)\} - first alternative; matches the variable-length hexadecimal escape sequence (\u{ABCD0}) + * | + * u([0-9A-Fa-f]{4}) - second alternative; matches the 4-digit hexadecimal escape sequence (\uABCD) + * | + * x([0-9A-Fa-f]{2}) - third alternative; matches the 2-digit hexadecimal escape sequence (\xA5) + * | + * ([1-7][0-7]{0,2}|[0-7]{2,3}) - fourth alternative; matches the up-to-3-digit octal escape sequence (\5 or \512) + * | + * (['"tbrnfv0\\]) - fifth alternative; matches the special escape characters (\t, \n and so on) + * | + * \U([0-9A-Fa-f]+) - sixth alternative; matches the 8-digit hexadecimal escape sequence used by python (\U0001F3B5) + * ) + */ +const jsEscapeRegex = /\\(u\{([0-9A-Fa-f]+)\}|u([0-9A-Fa-f]{4})|x([0-9A-Fa-f]{2})|([1-7][0-7]{0,2}|[0-7]{2,3})|(['"tbrnfv0\\]))|\\U([0-9A-Fa-f]{8})/g + +const usualEscapeSequences: Record = { + '0': '\0', + 'b': '\b', + 'f': '\f', + 'n': '\n', + 'r': '\r', + 't': '\t', + 'v': '\v', + '\'': '\'', + '"': '"', + '\\': '\\', +} + +const fromHex = (str: string) => String.fromCodePoint(parseInt(str, 16)) +const fromOct = (str: string) => String.fromCodePoint(parseInt(str, 8)) + +const unescape = (str: string) => { + return str.replace(jsEscapeRegex, (_, __, varHex, longHex, shortHex, octal, specialCharacter, python) => { + if (varHex !== undefined) + return fromHex(varHex) + else if (longHex !== undefined) + return fromHex(longHex) + else if (shortHex !== undefined) + return fromHex(shortHex) + else if (octal !== undefined) + return fromOct(octal) + else if (python !== undefined) + return fromHex(python) + else + return usualEscapeSequences[specialCharacter] + }) +} + +export default unescape diff --git a/web/app/components/datasets/create/steps-nav-bar/index.tsx b/web/app/components/datasets/create/steps-nav-bar/index.tsx index 70724a308c0669..b676f3ace4bf78 100644 --- a/web/app/components/datasets/create/steps-nav-bar/index.tsx +++ b/web/app/components/datasets/create/steps-nav-bar/index.tsx @@ -49,7 +49,7 @@ const StepsNavBar = ({ key={item} className={cn(s.stepItem, s[`step${item}`], step === item && s.active, step > item && s.done, isMobile && 'px-0')} > -
{item}
+
{step > item ? '' : item}
{isMobile ? '' : t(STEP_T_MAP[item])}
))} diff --git a/web/app/components/datasets/create/website/base/checkbox-with-label.tsx b/web/app/components/datasets/create/website/base/checkbox-with-label.tsx new file mode 100644 index 00000000000000..25d40fe0763dab --- /dev/null +++ b/web/app/components/datasets/create/website/base/checkbox-with-label.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import cn from '@/utils/classnames' +import Checkbox from '@/app/components/base/checkbox' +import Tooltip from '@/app/components/base/tooltip' + +type Props = { + className?: string + isChecked: boolean + onChange: (isChecked: boolean) => void + label: string + labelClassName?: string + tooltip?: string +} + +const CheckboxWithLabel: FC = ({ + className = '', + isChecked, + onChange, + label, + labelClassName, + tooltip, +}) => { + return ( +
+ } + triggerClassName='ml-0.5 w-4 h-4' + /> + )} + + ) +} +export default React.memo(CheckboxWithLabel) diff --git a/web/app/components/datasets/create/website/firecrawl/crawled-result-item.tsx b/web/app/components/datasets/create/website/base/crawled-result-item.tsx similarity index 100% rename from web/app/components/datasets/create/website/firecrawl/crawled-result-item.tsx rename to web/app/components/datasets/create/website/base/crawled-result-item.tsx diff --git a/web/app/components/datasets/create/website/base/crawled-result.tsx b/web/app/components/datasets/create/website/base/crawled-result.tsx new file mode 100644 index 00000000000000..d5c8d1b80a5a19 --- /dev/null +++ b/web/app/components/datasets/create/website/base/crawled-result.tsx @@ -0,0 +1,87 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import CheckboxWithLabel from './checkbox-with-label' +import CrawledResultItem from './crawled-result-item' +import cn from '@/utils/classnames' +import type { CrawlResultItem } from '@/models/datasets' + +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + className?: string + list: CrawlResultItem[] + checkedList: CrawlResultItem[] + onSelectedChange: (selected: CrawlResultItem[]) => void + onPreview: (payload: CrawlResultItem) => void + usedTime: number +} + +const CrawledResult: FC = ({ + className = '', + list, + checkedList, + onSelectedChange, + onPreview, + usedTime, +}) => { + const { t } = useTranslation() + + const isCheckAll = checkedList.length === list.length + + const handleCheckedAll = useCallback(() => { + if (!isCheckAll) + onSelectedChange(list) + + else + onSelectedChange([]) + }, [isCheckAll, list, onSelectedChange]) + + const handleItemCheckChange = useCallback((item: CrawlResultItem) => { + return (checked: boolean) => { + if (checked) + onSelectedChange([...checkedList, item]) + + else + onSelectedChange(checkedList.filter(checkedItem => checkedItem.source_url !== item.source_url)) + } + }, [checkedList, onSelectedChange]) + + const [previewIndex, setPreviewIndex] = React.useState(-1) + const handlePreview = useCallback((index: number) => { + return () => { + setPreviewIndex(index) + onPreview(list[index]) + } + }, [list, onPreview]) + + return ( +
+
+ +
{t(`${I18N_PREFIX}.scrapTimeInfo`, { + total: list.length, + time: usedTime.toFixed(1), + })}
+
+
+ {list.map((item, index) => ( + checkedItem.source_url === item.source_url)} + onCheckChange={handleItemCheckChange(item)} + /> + ))} +
+
+ ) +} +export default React.memo(CrawledResult) diff --git a/web/app/components/datasets/create/website/firecrawl/crawling.tsx b/web/app/components/datasets/create/website/base/crawling.tsx similarity index 100% rename from web/app/components/datasets/create/website/firecrawl/crawling.tsx rename to web/app/components/datasets/create/website/base/crawling.tsx diff --git a/web/app/components/datasets/create/website/firecrawl/base/error-message.tsx b/web/app/components/datasets/create/website/base/error-message.tsx similarity index 100% rename from web/app/components/datasets/create/website/firecrawl/base/error-message.tsx rename to web/app/components/datasets/create/website/base/error-message.tsx diff --git a/web/app/components/datasets/create/website/base/field.tsx b/web/app/components/datasets/create/website/base/field.tsx new file mode 100644 index 00000000000000..5b5ca90c5dd31d --- /dev/null +++ b/web/app/components/datasets/create/website/base/field.tsx @@ -0,0 +1,54 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import Input from './input' +import cn from '@/utils/classnames' +import Tooltip from '@/app/components/base/tooltip' + +type Props = { + className?: string + label: string + labelClassName?: string + value: string | number + onChange: (value: string | number) => void + isRequired?: boolean + placeholder?: string + isNumber?: boolean + tooltip?: string +} + +const Field: FC = ({ + className, + label, + labelClassName, + value, + onChange, + isRequired = false, + placeholder = '', + isNumber = false, + tooltip, +}) => { + return ( +
+
+
{label}
+ {isRequired && *} + {tooltip && ( + {tooltip}
+ } + triggerClassName='ml-0.5 w-4 h-4' + /> + )} +
+ +
+ ) +} +export default React.memo(Field) diff --git a/web/app/components/datasets/create/website/base/input.tsx b/web/app/components/datasets/create/website/base/input.tsx new file mode 100644 index 00000000000000..7d2d2b609f0594 --- /dev/null +++ b/web/app/components/datasets/create/website/base/input.tsx @@ -0,0 +1,58 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' + +type Props = { + value: string | number + onChange: (value: string | number) => void + placeholder?: string + isNumber?: boolean +} + +const MIN_VALUE = 0 + +const Input: FC = ({ + value, + onChange, + placeholder = '', + isNumber = false, +}) => { + const handleChange = useCallback((e: React.ChangeEvent) => { + const value = e.target.value + if (isNumber) { + let numberValue = parseInt(value, 10) // integer only + if (isNaN(numberValue)) { + onChange('') + return + } + if (numberValue < MIN_VALUE) + numberValue = MIN_VALUE + + onChange(numberValue) + return + } + onChange(value) + }, [isNumber, onChange]) + + const otherOption = (() => { + if (isNumber) { + return { + min: MIN_VALUE, + } + } + return { + + } + })() + return ( + + ) +} +export default React.memo(Input) diff --git a/web/app/components/datasets/create/website/firecrawl/mock-crawl-result.ts b/web/app/components/datasets/create/website/base/mock-crawl-result.ts similarity index 100% rename from web/app/components/datasets/create/website/firecrawl/mock-crawl-result.ts rename to web/app/components/datasets/create/website/base/mock-crawl-result.ts diff --git a/web/app/components/datasets/create/website/firecrawl/base/options-wrap.tsx b/web/app/components/datasets/create/website/base/options-wrap.tsx similarity index 100% rename from web/app/components/datasets/create/website/firecrawl/base/options-wrap.tsx rename to web/app/components/datasets/create/website/base/options-wrap.tsx diff --git a/web/app/components/datasets/create/website/firecrawl/base/url-input.tsx b/web/app/components/datasets/create/website/base/url-input.tsx similarity index 100% rename from web/app/components/datasets/create/website/firecrawl/base/url-input.tsx rename to web/app/components/datasets/create/website/base/url-input.tsx diff --git a/web/app/components/datasets/create/website/firecrawl/base/checkbox-with-label.tsx b/web/app/components/datasets/create/website/firecrawl/base/checkbox-with-label.tsx deleted file mode 100644 index 5c574ebe3e6195..00000000000000 --- a/web/app/components/datasets/create/website/firecrawl/base/checkbox-with-label.tsx +++ /dev/null @@ -1,29 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import cn from '@/utils/classnames' -import Checkbox from '@/app/components/base/checkbox' - -type Props = { - className?: string - isChecked: boolean - onChange: (isChecked: boolean) => void - label: string - labelClassName?: string -} - -const CheckboxWithLabel: FC = ({ - className = '', - isChecked, - onChange, - label, - labelClassName, -}) => { - return ( - - ) -} -export default React.memo(CheckboxWithLabel) diff --git a/web/app/components/datasets/create/website/firecrawl/base/field.tsx b/web/app/components/datasets/create/website/firecrawl/base/field.tsx deleted file mode 100644 index b1b7858d78e4c1..00000000000000 --- a/web/app/components/datasets/create/website/firecrawl/base/field.tsx +++ /dev/null @@ -1,56 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import { - RiQuestionLine, -} from '@remixicon/react' -import Input from './input' -import cn from 
'@/utils/classnames' -import TooltipPlus from '@/app/components/base/tooltip-plus' - -type Props = { - className?: string - label: string - labelClassName?: string - value: string | number - onChange: (value: string | number) => void - isRequired?: boolean - placeholder?: string - isNumber?: boolean - tooltip?: string -} - -const Field: FC = ({ - className, - label, - labelClassName, - value, - onChange, - isRequired = false, - placeholder = '', - isNumber = false, - tooltip, -}) => { - return ( -
-
-
{label}
- {isRequired && *} - {tooltip && ( - {tooltip}
- }> - - - )} -
- -
- ) -} -export default React.memo(Field) diff --git a/web/app/components/datasets/create/website/firecrawl/base/input.tsx b/web/app/components/datasets/create/website/firecrawl/base/input.tsx deleted file mode 100644 index 06249f57e7469a..00000000000000 --- a/web/app/components/datasets/create/website/firecrawl/base/input.tsx +++ /dev/null @@ -1,58 +0,0 @@ -'use client' -import type { FC } from 'react' -import React, { useCallback } from 'react' - -type Props = { - value: string | number - onChange: (value: string | number) => void - placeholder?: string - isNumber?: boolean -} - -const MIN_VALUE = 1 - -const Input: FC = ({ - value, - onChange, - placeholder = '', - isNumber = false, -}) => { - const handleChange = useCallback((e: React.ChangeEvent) => { - const value = e.target.value - if (isNumber) { - let numberValue = parseInt(value, 10) // integer only - if (isNaN(numberValue)) { - onChange('') - return - } - if (numberValue < MIN_VALUE) - numberValue = MIN_VALUE - - onChange(numberValue) - return - } - onChange(value) - }, [isNumber, onChange]) - - const otherOption = (() => { - if (isNumber) { - return { - min: MIN_VALUE, - } - } - return { - - } - })() - return ( - - ) -} -export default React.memo(Input) diff --git a/web/app/components/datasets/create/website/firecrawl/index.tsx b/web/app/components/datasets/create/website/firecrawl/index.tsx index de4f8bb1293447..aa4dffc174315f 100644 --- a/web/app/components/datasets/create/website/firecrawl/index.tsx +++ b/web/app/components/datasets/create/website/firecrawl/index.tsx @@ -2,13 +2,13 @@ import type { FC } from 'react' import React, { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' +import UrlInput from '../base/url-input' +import OptionsWrap from '../base/options-wrap' +import CrawledResult from '../base/crawled-result' +import Crawling from '../base/crawling' +import ErrorMessage from '../base/error-message' import Header from './header' -import UrlInput from './base/url-input' -import OptionsWrap from './base/options-wrap' import Options from './options' -import CrawledResult from './crawled-result' -import Crawling from './crawling' -import ErrorMessage from './base/error-message' import cn from '@/utils/classnames' import { useModalContext } from '@/context/modal-context' import type { CrawlOptions, CrawlResultItem } from '@/models/datasets' diff --git a/web/app/components/datasets/create/website/firecrawl/options.tsx b/web/app/components/datasets/create/website/firecrawl/options.tsx index 20cc4f073fe43b..8cc2c6757c9615 100644 --- a/web/app/components/datasets/create/website/firecrawl/options.tsx +++ b/web/app/components/datasets/create/website/firecrawl/options.tsx @@ -2,8 +2,8 @@ import type { FC } from 'react' import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' -import CheckboxWithLabel from './base/checkbox-with-label' -import Field from './base/field' +import CheckboxWithLabel from '../base/checkbox-with-label' +import Field from '../base/field' import cn from '@/utils/classnames' import type { CrawlOptions } from '@/models/datasets' diff --git a/web/app/components/datasets/create/website/index.module.css b/web/app/components/datasets/create/website/index.module.css new file mode 100644 index 00000000000000..abaab4bea4b7a1 --- /dev/null +++ b/web/app/components/datasets/create/website/index.module.css @@ -0,0 +1,6 @@ +.jinaLogo { + @apply w-4 h-4 bg-center bg-no-repeat inline-block; + background-color: #F5FAFF; + background-image: 
url(../assets/jina.png); + background-size: 16px; +} diff --git a/web/app/components/datasets/create/website/index.tsx b/web/app/components/datasets/create/website/index.tsx index e06fbb4a1210b6..58b7f5f2fd77bd 100644 --- a/web/app/components/datasets/create/website/index.tsx +++ b/web/app/components/datasets/create/website/index.tsx @@ -1,8 +1,12 @@ 'use client' import type { FC } from 'react' import React, { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import s from './index.module.css' import NoData from './no-data' import Firecrawl from './firecrawl' +import JinaReader from './jina-reader' +import cn from '@/utils/classnames' import { useModalContext } from '@/context/modal-context' import type { CrawlOptions, CrawlResultItem } from '@/models/datasets' import { fetchDataSources } from '@/service/datasets' @@ -12,6 +16,7 @@ type Props = { onPreview: (payload: CrawlResultItem) => void checkedCrawlResult: CrawlResultItem[] onCheckedCrawlResultChange: (payload: CrawlResultItem[]) => void + onCrawlProviderChange: (provider: DataSourceProvider) => void onJobIdChange: (jobId: string) => void crawlOptions: CrawlOptions onCrawlOptionsChange: (payload: CrawlOptions) => void @@ -21,17 +26,32 @@ const Website: FC = ({ onPreview, checkedCrawlResult, onCheckedCrawlResultChange, + onCrawlProviderChange, onJobIdChange, crawlOptions, onCrawlOptionsChange, }) => { + const { t } = useTranslation() const { setShowAccountSettingModal } = useModalContext() const [isLoaded, setIsLoaded] = useState(false) - const [isSetFirecrawlApiKey, setIsSetFirecrawlApiKey] = useState(false) + const [selectedProvider, setSelectedProvider] = useState(DataSourceProvider.jinaReader) + const [sources, setSources] = useState([]) + + useEffect(() => { + onCrawlProviderChange(selectedProvider) + }, [selectedProvider, onCrawlProviderChange]) + const checkSetApiKey = useCallback(async () => { const res = await fetchDataSources() as any - const isFirecrawlSet = res.sources.some((item: DataSourceItem) => item.provider === DataSourceProvider.fireCrawl) - setIsSetFirecrawlApiKey(isFirecrawlSet) + setSources(res.sources) + + // If users have configured one of the providers, select it. + const availableProviders = res.sources.filter((item: DataSourceItem) => + [DataSourceProvider.jinaReader, DataSourceProvider.fireCrawl].includes(item.provider), + ) + + if (availableProviders.length > 0) + setSelectedProvider(availableProviders[0].provider) }, []) useEffect(() => { @@ -52,20 +72,66 @@ const Website: FC = ({ return (
- {isSetFirecrawlApiKey - ? ( - - ) - : ( - - )} +
+
+ {t('datasetCreation.stepOne.website.chooseProvider')} +
+
+ + +
+
+ + { + selectedProvider === DataSourceProvider.fireCrawl + ? sources.find(source => source.provider === DataSourceProvider.fireCrawl) + ? ( + + ) + : ( + + ) + : sources.find(source => source.provider === DataSourceProvider.jinaReader) + ? ( + + ) + : ( + + ) + }
) } diff --git a/web/app/components/datasets/create/website/jina-reader/base/checkbox-with-label.tsx b/web/app/components/datasets/create/website/jina-reader/base/checkbox-with-label.tsx new file mode 100644 index 00000000000000..25d40fe0763dab --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/checkbox-with-label.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import cn from '@/utils/classnames' +import Checkbox from '@/app/components/base/checkbox' +import Tooltip from '@/app/components/base/tooltip' + +type Props = { + className?: string + isChecked: boolean + onChange: (isChecked: boolean) => void + label: string + labelClassName?: string + tooltip?: string +} + +const CheckboxWithLabel: FC = ({ + className = '', + isChecked, + onChange, + label, + labelClassName, + tooltip, +}) => { + return ( +
+ } + triggerClassName='ml-0.5 w-4 h-4' + /> + )} + + ) +} +export default React.memo(CheckboxWithLabel) diff --git a/web/app/components/datasets/create/website/jina-reader/base/error-message.tsx b/web/app/components/datasets/create/website/jina-reader/base/error-message.tsx new file mode 100644 index 00000000000000..aa337ec4bf5323 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/error-message.tsx @@ -0,0 +1,30 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import cn from '@/utils/classnames' +import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' + +type Props = { + className?: string + title: string + errorMsg?: string +} + +const ErrorMessage: FC = ({ + className, + title, + errorMsg, +}) => { + return ( +
+
+ +
{title}
+
+ {errorMsg && ( +
{errorMsg}
+ )} +
+ ) +} +export default React.memo(ErrorMessage) diff --git a/web/app/components/datasets/create/website/jina-reader/base/field.tsx b/web/app/components/datasets/create/website/jina-reader/base/field.tsx new file mode 100644 index 00000000000000..5b5ca90c5dd31d --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/field.tsx @@ -0,0 +1,54 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import Input from './input' +import cn from '@/utils/classnames' +import Tooltip from '@/app/components/base/tooltip' + +type Props = { + className?: string + label: string + labelClassName?: string + value: string | number + onChange: (value: string | number) => void + isRequired?: boolean + placeholder?: string + isNumber?: boolean + tooltip?: string +} + +const Field: FC = ({ + className, + label, + labelClassName, + value, + onChange, + isRequired = false, + placeholder = '', + isNumber = false, + tooltip, +}) => { + return ( +
+
+
{label}
+ {isRequired && *} + {tooltip && ( + {tooltip}
+ } + triggerClassName='ml-0.5 w-4 h-4' + /> + )} +
+ +
+ ) +} +export default React.memo(Field) diff --git a/web/app/components/datasets/create/website/jina-reader/base/input.tsx b/web/app/components/datasets/create/website/jina-reader/base/input.tsx new file mode 100644 index 00000000000000..7d2d2b609f0594 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/input.tsx @@ -0,0 +1,58 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' + +type Props = { + value: string | number + onChange: (value: string | number) => void + placeholder?: string + isNumber?: boolean +} + +const MIN_VALUE = 0 + +const Input: FC = ({ + value, + onChange, + placeholder = '', + isNumber = false, +}) => { + const handleChange = useCallback((e: React.ChangeEvent) => { + const value = e.target.value + if (isNumber) { + let numberValue = parseInt(value, 10) // integer only + if (isNaN(numberValue)) { + onChange('') + return + } + if (numberValue < MIN_VALUE) + numberValue = MIN_VALUE + + onChange(numberValue) + return + } + onChange(value) + }, [isNumber, onChange]) + + const otherOption = (() => { + if (isNumber) { + return { + min: MIN_VALUE, + } + } + return { + + } + })() + return ( + + ) +} +export default React.memo(Input) diff --git a/web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx b/web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx new file mode 100644 index 00000000000000..652401a20f866b --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/options-wrap.tsx @@ -0,0 +1,55 @@ +'use client' +import { useBoolean } from 'ahooks' +import type { FC } from 'react' +import React, { useEffect } from 'react' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import { Settings04 } from '@/app/components/base/icons/src/vender/line/general' +import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows' +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + className?: string + children: React.ReactNode + controlFoldOptions?: number +} + +const OptionsWrap: FC = ({ + className = '', + children, + controlFoldOptions, +}) => { + const { t } = useTranslation() + + const [fold, { + toggle: foldToggle, + setTrue: foldHide, + }] = useBoolean(false) + + useEffect(() => { + if (controlFoldOptions) + foldHide() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [controlFoldOptions]) + return ( +
+
+
+ +
{t(`${I18N_PREFIX}.options`)}
+
+ +
+ {!fold && ( +
+ {children} +
+ )} + +
+ ) +} +export default React.memo(OptionsWrap) diff --git a/web/app/components/datasets/create/website/jina-reader/base/url-input.tsx b/web/app/components/datasets/create/website/jina-reader/base/url-input.tsx new file mode 100644 index 00000000000000..e6b04758746e1f --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/base/url-input.tsx @@ -0,0 +1,48 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Input from './input' +import Button from '@/app/components/base/button' + +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + isRunning: boolean + onRun: (url: string) => void +} + +const UrlInput: FC = ({ + isRunning, + onRun, +}) => { + const { t } = useTranslation() + const [url, setUrl] = useState('') + const handleUrlChange = useCallback((url: string | number) => { + setUrl(url as string) + }, []) + const handleOnRun = useCallback(() => { + if (isRunning) + return + onRun(url) + }, [isRunning, onRun, url]) + + return ( +
+ + +
+ ) +} +export default React.memo(UrlInput) diff --git a/web/app/components/datasets/create/website/jina-reader/crawled-result-item.tsx b/web/app/components/datasets/create/website/jina-reader/crawled-result-item.tsx new file mode 100644 index 00000000000000..5531d3e140dc0f --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/crawled-result-item.tsx @@ -0,0 +1,40 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import type { CrawlResultItem as CrawlResultItemType } from '@/models/datasets' +import Checkbox from '@/app/components/base/checkbox' + +type Props = { + payload: CrawlResultItemType + isChecked: boolean + isPreview: boolean + onCheckChange: (checked: boolean) => void + onPreview: () => void +} + +const CrawledResultItem: FC = ({ + isPreview, + payload, + isChecked, + onCheckChange, + onPreview, +}) => { + const { t } = useTranslation() + + const handleCheckChange = useCallback(() => { + onCheckChange(!isChecked) + }, [isChecked, onCheckChange]) + return ( +
+
+ +
{payload.title}
+
{t('datasetCreation.stepOne.website.preview')}
+
+
{payload.source_url}
+
+ ) +} +export default React.memo(CrawledResultItem) diff --git a/web/app/components/datasets/create/website/firecrawl/crawled-result.tsx b/web/app/components/datasets/create/website/jina-reader/crawled-result.tsx similarity index 100% rename from web/app/components/datasets/create/website/firecrawl/crawled-result.tsx rename to web/app/components/datasets/create/website/jina-reader/crawled-result.tsx diff --git a/web/app/components/datasets/create/website/jina-reader/crawling.tsx b/web/app/components/datasets/create/website/jina-reader/crawling.tsx new file mode 100644 index 00000000000000..ee26e7671a4b26 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/crawling.tsx @@ -0,0 +1,37 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import cn from '@/utils/classnames' +import { RowStruct } from '@/app/components/base/icons/src/public/other' + +type Props = { + className?: string + crawledNum: number + totalNum: number +} + +const Crawling: FC = ({ + className = '', + crawledNum, + totalNum, +}) => { + const { t } = useTranslation() + + return ( +
+
+ {t('datasetCreation.stepOne.website.totalPageScraped')} {crawledNum}/{totalNum} +
+ +
+ {['', '', '', ''].map((item, index) => ( +
+ +
+ ))} +
+
+ ) +} +export default React.memo(Crawling) diff --git a/web/app/components/datasets/create/website/jina-reader/header.tsx b/web/app/components/datasets/create/website/jina-reader/header.tsx new file mode 100644 index 00000000000000..85014a30ee2b12 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/header.tsx @@ -0,0 +1,42 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { Settings01 } from '@/app/components/base/icons/src/vender/line/general' +import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education' + +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + onSetting: () => void +} + +const Header: FC = ({ + onSetting, +}) => { + const { t } = useTranslation() + + return ( +
+
+
{t(`${I18N_PREFIX}.jinaReaderTitle`)}
+
+
+ +
+
+ + + {t(`${I18N_PREFIX}.jinaReaderDoc`)} + +
+ ) +} +export default React.memo(Header) diff --git a/web/app/components/datasets/create/website/jina-reader/index.tsx new file mode 100644 index 00000000000000..51d77d712140b7 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/index.tsx @@ -0,0 +1,232 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import UrlInput from '../base/url-input' +import OptionsWrap from '../base/options-wrap' +import CrawledResult from '../base/crawled-result' +import Crawling from '../base/crawling' +import ErrorMessage from '../base/error-message' +import Header from './header' +import Options from './options' +import cn from '@/utils/classnames' +import { useModalContext } from '@/context/modal-context' +import Toast from '@/app/components/base/toast' +import { checkJinaReaderTaskStatus, createJinaReaderTask } from '@/service/datasets' +import { sleep } from '@/utils' +import type { CrawlOptions, CrawlResultItem } from '@/models/datasets' + +const ERROR_I18N_PREFIX = 'common.errorMsg' +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + onPreview: (payload: CrawlResultItem) => void + checkedCrawlResult: CrawlResultItem[] + onCheckedCrawlResultChange: (payload: CrawlResultItem[]) => void + onJobIdChange: (jobId: string) => void + crawlOptions: CrawlOptions + onCrawlOptionsChange: (payload: CrawlOptions) => void +} + +enum Step { + init = 'init', + running = 'running', + finished = 'finished', +} + +const JinaReader: FC = ({ + onPreview, + checkedCrawlResult, + onCheckedCrawlResultChange, + onJobIdChange, + crawlOptions, + onCrawlOptionsChange, +}) => { + const { t } = useTranslation() + const [step, setStep] = useState(Step.init) + const [controlFoldOptions, setControlFoldOptions] = useState(0) + useEffect(() => { + if (step !== Step.init) + setControlFoldOptions(Date.now()) + }, [step]) + const { setShowAccountSettingModal } = useModalContext() + const handleSetting = useCallback(() => { + setShowAccountSettingModal({ + payload: 'data-source', + }) + }, [setShowAccountSettingModal]) + + const checkValid = useCallback((url: string) => { + let errorMsg = '' + if (!url) { + errorMsg = t(`${ERROR_I18N_PREFIX}.fieldRequired`, { + field: 'url', + }) + } + + if (!errorMsg && !((url.startsWith('http://') || url.startsWith('https://')))) + errorMsg = t(`${ERROR_I18N_PREFIX}.urlError`) + + if (!errorMsg && (crawlOptions.limit === null || crawlOptions.limit === undefined || crawlOptions.limit === '')) { + errorMsg = t(`${ERROR_I18N_PREFIX}.fieldRequired`, { + field: t(`${I18N_PREFIX}.limit`), + }) + } + + return { + isValid: !errorMsg, + errorMsg, + } + }, [crawlOptions, t]) + + const isInit = step === Step.init + const isCrawlFinished = step === Step.finished + const isRunning = step === Step.running + const [crawlResult, setCrawlResult] = useState<{ + current: number + total: number + data: CrawlResultItem[] + time_consuming: number | string + } | undefined>(undefined) + const [crawlErrorMessage, setCrawlErrorMessage] = useState('') + const showError = isCrawlFinished && crawlErrorMessage + + const waitForCrawlFinished = useCallback(async (jobId: string) => { + try { + const res = await checkJinaReaderTaskStatus(jobId) as any + if (res.status === 'completed') { + return { + isError: false, + data: { + ...res, + total: Math.min(res.total, parseFloat(crawlOptions.limit
as string)), + }, + } + if (res.status === 'failed' || !res.status) { + return { + isError: true, + errorMessage: res.message, + data: { + data: [], + }, + } + } + // update the progress + setCrawlResult({ + ...res, + total: Math.min(res.total, parseFloat(crawlOptions.limit as string)), + }) + onCheckedCrawlResultChange(res.data || []) // select the crawled pages by default + await sleep(2500) + return await waitForCrawlFinished(jobId) + } + catch (e: any) { + const errorBody = await e.json() + return { + isError: true, + errorMessage: errorBody.message, + data: { + data: [], + }, + } + } + }, [crawlOptions.limit, onCheckedCrawlResultChange]) + + const handleRun = useCallback(async (url: string) => { + const { isValid, errorMsg } = checkValid(url) + if (!isValid) { + Toast.notify({ + message: errorMsg!, + type: 'error', + }) + return + } + setStep(Step.running) + try { + const startTime = Date.now() + const res = await createJinaReaderTask({ + url, + options: crawlOptions, + }) as any + + if (res.data) { + const data = { + current: 1, + total: 1, + data: [{ + title: res.data.title, + markdown: res.data.content, + description: res.data.description, + source_url: res.data.url, + }], + time_consuming: (Date.now() - startTime) / 1000, + } + setCrawlResult(data) + onCheckedCrawlResultChange(data.data || []) + setCrawlErrorMessage('') + } + else if (res.job_id) { + const jobId = res.job_id + onJobIdChange(jobId) + const { isError, data, errorMessage } = await waitForCrawlFinished(jobId) + if (isError) { + setCrawlErrorMessage(errorMessage || t(`${I18N_PREFIX}.unknownError`)) + } + else { + setCrawlResult(data) + onCheckedCrawlResultChange(data.data || []) // select the crawled pages by default + setCrawlErrorMessage('') + } + } + } + catch (e) { + setCrawlErrorMessage(t(`${I18N_PREFIX}.unknownError`)!) + console.error(e) + } + finally { + setStep(Step.finished) + } + }, [checkValid, crawlOptions, onJobIdChange, t, waitForCrawlFinished]) + + return ( +
+
+
+ + + + + + {!isInit && ( +
+ {isRunning + && } + {showError && ( + + )} + {isCrawlFinished && !showError + && + } +
+ )} +
+
+ ) +} +export default React.memo(JinaReader) diff --git a/web/app/components/datasets/create/website/jina-reader/mock-crawl-result.ts b/web/app/components/datasets/create/website/jina-reader/mock-crawl-result.ts new file mode 100644 index 00000000000000..8fd5e6636f1617 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/mock-crawl-result.ts @@ -0,0 +1,24 @@ +import type { CrawlResultItem } from '@/models/datasets' + +const result: CrawlResultItem[] = [ + { + title: 'Start the frontend Docker container separately', + markdown: 'Markdown 1', + description: 'Description 1', + source_url: 'https://example.com/1', + }, + { + title: 'Advanced Tool Integration', + markdown: 'Markdown 2', + description: 'Description 2', + source_url: 'https://example.com/2', + }, + { + title: 'Local Source Code Start | English | Dify', + markdown: 'Markdown 3', + description: 'Description 3', + source_url: 'https://example.com/3', + }, +] + +export default result diff --git a/web/app/components/datasets/create/website/jina-reader/options.tsx b/web/app/components/datasets/create/website/jina-reader/options.tsx new file mode 100644 index 00000000000000..52cfaa8b3b40f3 --- /dev/null +++ b/web/app/components/datasets/create/website/jina-reader/options.tsx @@ -0,0 +1,59 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import CheckboxWithLabel from '../base/checkbox-with-label' +import Field from '../base/field' +import cn from '@/utils/classnames' +import type { CrawlOptions } from '@/models/datasets' + +const I18N_PREFIX = 'datasetCreation.stepOne.website' + +type Props = { + className?: string + payload: CrawlOptions + onChange: (payload: CrawlOptions) => void +} + +const Options: FC = ({ + className = '', + payload, + onChange, +}) => { + const { t } = useTranslation() + + const handleChange = useCallback((key: keyof CrawlOptions) => { + return (value: any) => { + onChange({ + ...payload, + [key]: value, + }) + } + }, [payload, onChange]) + return ( +
+ + +
+ +
+
+ ) +} +export default React.memo(Options) diff --git a/web/app/components/datasets/create/website/no-data.tsx b/web/app/components/datasets/create/website/no-data.tsx index 13e5ee7dfbd508..8a508a48c6bb8e 100644 --- a/web/app/components/datasets/create/website/no-data.tsx +++ b/web/app/components/datasets/create/website/no-data.tsx @@ -2,35 +2,56 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' +import s from './index.module.css' import { Icon3Dots } from '@/app/components/base/icons/src/vender/line/others' import Button from '@/app/components/base/button' +import { DataSourceProvider } from '@/models/common' const I18N_PREFIX = 'datasetCreation.stepOne.website' type Props = { onConfig: () => void + provider: DataSourceProvider } const NoData: FC = ({ onConfig, + provider, }) => { const { t } = useTranslation() + const providerConfig = { + [DataSourceProvider.jinaReader]: { + emoji: , + title: t(`${I18N_PREFIX}.jinaReaderNotConfigured`), + description: t(`${I18N_PREFIX}.jinaReaderNotConfiguredDescription`), + }, + [DataSourceProvider.fireCrawl]: { + emoji: '🔥', + title: t(`${I18N_PREFIX}.fireCrawlNotConfigured`), + description: t(`${I18N_PREFIX}.fireCrawlNotConfiguredDescription`), + }, + } + + const currentProvider = providerConfig[provider] + return ( -
-
- 🔥 -
-
- {t(`${I18N_PREFIX}.fireCrawlNotConfigured`)} -
- {t(`${I18N_PREFIX}.fireCrawlNotConfiguredDescription`)} + <> +
+
+ {currentProvider.emoji} +
+
+ {currentProvider.title} +
+ {currentProvider.description} +
+
- -
+ ) } export default React.memo(NoData) diff --git a/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx b/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx index c65b244f6d315b..5b76acc9360c69 100644 --- a/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx +++ b/web/app/components/datasets/documents/detail/completed/SegmentCard.tsx @@ -36,6 +36,12 @@ export type UsageScene = 'doc' | 'hitTesting' type ISegmentCardProps = { loading: boolean detail?: SegmentDetailModel & { document: { name: string } } + contentExternal?: string + refSource?: { + title: string + uri: string + } + isExternal?: boolean score?: number onClick?: () => void onChangeSwitch?: (segId: string, enabled: boolean) => Promise @@ -48,6 +54,9 @@ type ISegmentCardProps = { const SegmentCard: FC = ({ detail = {}, + contentExternal, + isExternal, + refSource, score, onClick, onChangeSwitch, @@ -88,6 +97,9 @@ const SegmentCard: FC = ({ ) } + if (contentExternal) + return contentExternal + return content } @@ -199,16 +211,16 @@ const SegmentCard: FC = ({
-
+
- {t('datasetHitTesting.viewChart')} + {isExternal ? t('datasetHitTesting.viewDetail') : t('datasetHitTesting.viewChart')}
diff --git a/web/app/components/datasets/documents/detail/completed/index.tsx b/web/app/components/datasets/documents/detail/completed/index.tsx index f2addac2e21dfc..2c9e6ca2ea7136 100644 --- a/web/app/components/datasets/documents/detail/completed/index.tsx +++ b/web/app/components/datasets/documents/detail/completed/index.tsx @@ -1,10 +1,11 @@ 'use client' import type { FC } from 'react' import React, { memo, useEffect, useMemo, useState } from 'react' +import { useDebounceFn } from 'ahooks' import { HashtagIcon } from '@heroicons/react/24/solid' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' -import { debounce, isNil, omitBy } from 'lodash-es' +import { isNil, omitBy } from 'lodash-es' import { RiCloseLine, RiEditLine, @@ -24,7 +25,7 @@ import { ToastContext } from '@/app/components/base/toast' import type { Item } from '@/app/components/base/select' import { SimpleSelect } from '@/app/components/base/select' import { deleteSegment, disableSegment, enableSegment, fetchSegments, updateSegment } from '@/service/datasets' -import type { SegmentDetailModel, SegmentUpdator, SegmentsQuery, SegmentsResponse } from '@/models/datasets' +import type { SegmentDetailModel, SegmentUpdater, SegmentsQuery, SegmentsResponse } from '@/models/datasets' import { asyncRunSafe } from '@/utils' import type { CommonResponse } from '@/models/common' import AutoHeightTextarea from '@/app/components/base/auto-height-textarea/common' @@ -241,7 +242,8 @@ const Completed: FC = ({ // the current segment id and whether to show the modal const [currSegment, setCurrSegment] = useState<{ segInfo?: SegmentDetailModel; showModal: boolean }>({ showModal: false }) - const [searchValue, setSearchValue] = useState() // the search value + const [inputValue, setInputValue] = useState('') // the input value + const [searchValue, setSearchValue] = useState('') // the search value const [selectedStatus, setSelectedStatus] = useState('all') // the selected status, enabled/disabled/undefined const [lastSegmentsRes, setLastSegmentsRes] = useState(undefined) @@ -250,6 +252,15 @@ const Completed: FC = ({ const [total, setTotal] = useState() const { eventEmitter } = useEventEmitterContextContext() + const { run: handleSearch } = useDebounceFn(() => { + setSearchValue(inputValue) + }, { wait: 500 }) + + const handleInputChange = (value: string) => { + setInputValue(value) + handleSearch() + } + const onChangeStatus = ({ value }: Item) => { setSelectedStatus(value === 'all' ? 'all' : !!value) } @@ -322,7 +333,7 @@ const Completed: FC = ({ } const handleUpdateSegment = async (segmentId: string, question: string, answer: string, keywords: string[]) => { - const params: SegmentUpdator = { content: '' } + const params: SegmentUpdater = { content: '' } if (docForm === 'qa_model') { if (!question.trim()) return notify({ type: 'error', message: t('datasetDocuments.segment.questionEmpty') }) @@ -391,7 +402,14 @@ const Completed: FC = ({ defaultValue={'all'} className={s.select} wrapperClassName='h-fit w-[120px] mr-2' /> - + handleInputChange(e.target.value)} + onClear={() => handleInputChange('')} + />
= (
} -const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: dstId, documentId: docId, indexingType, detailUpdate }) => { +const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: dstId, documentId: docId, detailUpdate }) => { const onTop = stopPosition === 'top' const { t } = useTranslation() const { notify } = useContext(ToastContext) const { datasetId = '', documentId = '' } = useContext(DocumentContext) - const { indexingTechnique } = useContext(DatasetDetailContext) const localDatasetId = dstId ?? datasetId const localDocumentId = docId ?? documentId - const localIndexingTechnique = indexingType ?? indexingTechnique const [indexingStatusDetail, setIndexingStatusDetail] = useState(null) const fetchIndexingStatus = async () => { @@ -160,14 +156,6 @@ const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: d } }, [startQueryStatus, stopQueryStatus]) - const { data: indexingEstimateDetail, error: indexingEstimateErr } = useSWR({ - action: 'fetchIndexingEstimate', - datasetId: localDatasetId, - documentId: localDocumentId, - }, apiParams => fetchIndexingEstimate(omit(apiParams, 'action')), { - revalidateOnFocus: false, - }) - const { data: ruleDetail, error: ruleError } = useSWR({ action: 'fetchProcessRule', params: { documentId: localDocumentId }, @@ -250,21 +238,6 @@ const EmbeddingDetail: FC = ({ detail, stopPosition = 'top', datasetId: d
{t('datasetDocuments.embedding.segments')} {indexingStatusDetail?.completed_segments}/{indexingStatusDetail?.total_segments} · {percent}%
- {localIndexingTechnique === 'high_quaility' && ( -
-
- {t('datasetDocuments.embedding.highQuality')} · {t('datasetDocuments.embedding.estimate')} - {formatNumber(indexingEstimateDetail?.tokens || 0)}tokens - (${formatNumber(indexingEstimateDetail?.total_price || 0)}) -
- )} - {localIndexingTechnique === 'economy' && ( -
-
- {t('datasetDocuments.embedding.economy')} · {t('datasetDocuments.embedding.estimate')} - 0tokens -
- )}
{!onTop && ( diff --git a/web/app/components/datasets/documents/detail/embedding/style.module.css b/web/app/components/datasets/documents/detail/embedding/style.module.css index 6dc1a5e80b84f9..c24444ac121446 100644 --- a/web/app/components/datasets/documents/detail/embedding/style.module.css +++ b/web/app/components/datasets/documents/detail/embedding/style.module.css @@ -31,7 +31,7 @@ @apply rounded-r-md; } .progressData { - @apply w-full flex justify-between items-center text-xs text-gray-700; + @apply w-full flex items-center text-xs text-gray-700; } .previewTip { @apply pb-1 pt-12 text-gray-900 text-sm font-medium; diff --git a/web/app/components/datasets/documents/detail/metadata/index.tsx b/web/app/components/datasets/documents/detail/metadata/index.tsx index 92109260227a08..9990ff7404f2a2 100644 --- a/web/app/components/datasets/documents/detail/metadata/index.tsx +++ b/web/app/components/datasets/documents/detail/metadata/index.tsx @@ -78,8 +78,7 @@ export const FieldInfo: FC = ({ placeholder={`${t('datasetDocuments.metadata.placeholder.add')}${label}`} /> : onUpdate?.(e.target.value)} value={value} defaultValue={defaultValue} placeholder={`${t('datasetDocuments.metadata.placeholder.add')}${label}`} @@ -102,7 +101,9 @@ const IconButton: FC<{ const metadataMap = useMetadataMap() return ( - +